Unverified Commit b5f05663 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[chat] add distributed PPO trainer (#3740)



* Detached ppo (#9)

* run the base

* working on dist ppo

* sync

* detached trainer

* update detached trainer. no maker update function

* facing init problem

* 1 maker 1 trainer detached run. but no model update

* facing cuda problem

* fix save functions

* verified maker update

* nothing

* add ignore

* analyize loss issue

* remove some debug codes

* facing 2m1t stuck issue

* 2m1t verified

* do not use torchrun

* working on 2m2t

* working on 2m2t

* initialize strategy in ray actor env

* facing actor's init order issue

* facing ddp model update issue (need unwarp ddp)

* unwrap ddp actor

* checking 1m2t stuck problem

* nothing

* set timeout for trainer choosing. It solves the stuck problem!

* delete some debug output

* rename to sync with upstream

* rename to sync with upstream

* coati rename

* nothing

* I am going to detach the replaybuffer from trainer and make it a Ray Actor. Two benefits: 1. support TP trainer. 2. asynchronized buffer operations

* experience_maker_holder performs target-revolving _send_experience() instead of length comparison.

* move code to ray subfolder

* working on pipeline inference

* apply comments

* working on pipeline strategy. in progress.

* remove pipeline code. clean this branch

* update remote parameters by state_dict. no test

* nothing

* state_dict sharding transfer

* merge debug branch

* gemini _unwrap_model fix

* simplify code

* simplify code & fix LoRALinear AttributeError

* critic unwrapped state_dict

---------
Co-authored-by: default avatarcsric <richcsr256@gmail.com>

* [chat] add perfomance evaluator and fix bugs (#10)

* [chat] add performance evaluator for ray

* [chat] refactor debug arg

* [chat] support hf config

* [chat] fix generation

* [chat] add 1mmt dummy example

* [chat] fix gemini ckpt

* split experience to send (#11)
Co-authored-by: default avatarcsric <richcsr256@gmail.com>

* [chat] refactor trainer and maker (#12)

* [chat] refactor experience maker holder

* [chat] refactor model init

* [chat] refactor trainer args

* [chat] refactor model init

* [chat] refactor trainer

* [chat] refactor experience sending logic and training loop args (#13)

* [chat] refactor experience send logic

* [chat] refactor trainer

* [chat] refactor trainer

* [chat] refactor experience maker

* [chat] refactor pbar

* [chat] refactor example folder (#14)

* [chat] support quant (#15)

* [chat] add quant

* [chat] add quant example

* prompt example (#16)

* prompt example

* prompt load csv data

* remove legacy try

---------
Co-authored-by: default avatarcsric <richcsr256@gmail.com>

* [chat] add mmmt dummy example and refactor experience sending (#17)

* [chat] add mmmt dummy example

* [chat] refactor naive strategy

* [chat] fix struck problem

* [chat] fix naive strategy

* [chat] optimize experience maker sending logic

* [chat] refactor sending assignment

* [chat] refactor performance evaluator (#18)

* Prompt Example & requires_grad state_dict & sharding state_dict (#19)

* prompt example

* prompt load csv data

* remove legacy try

* maker models require_grad set to False

* working on zero redundancy update

* mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad.

* remove legacy examples

* remove legacy examples

* remove replay buffer tp state. bad design

---------
Co-authored-by: default avatarcsric <richcsr256@gmail.com>

* state_dict sending adapts to new unwrap function (#20)

* prompt example

* prompt load csv data

* remove legacy try

* maker models require_grad set to False

* working on zero redundancy update

* mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad.

* remove legacy examples

* remove legacy examples

* remove replay buffer tp state. bad design

* opt benchmark

* better script

* nothing

* [chat] strategy refactor unwrap model

* [chat] strategy refactor save model

* [chat] add docstr

* [chat] refactor trainer save model

* [chat] fix strategy typing

* [chat] refactor trainer save model

* [chat] update readme

* [chat] fix unit test

* working on lora reconstruction

* state_dict sending adapts to new unwrap function

* remove comments

---------
Co-authored-by: default avatarcsric <richcsr256@gmail.com>
Co-authored-by: default avatarver217 <lhx0217@gmail.com>

* [chat-ray] add readme (#21)

* add readme

* transparent graph

* add note background

---------
Co-authored-by: default avatarcsric <richcsr256@gmail.com>

* [chat] get images from url (#22)

* Refactor/chat ray (#23)

* [chat] lora add todo

* [chat] remove unused pipeline strategy

* [chat] refactor example structure

* [chat] setup ci for ray

* [chat-ray] Support LoRA trainer. LoRA weights reconstruction. (#24)

* lora support prototype

* lora support

* 1mmt lora & remove useless code

---------
Co-authored-by: default avatarcsric <richcsr256@gmail.com>

* [chat] fix test ci for ray

* [chat] fix test ci requirements for ray

* [chat] fix ray runtime env

* [chat] fix ray runtime env

* [chat] fix example ci docker args

* [chat] add debug info in trainer

* [chat] add nccl debug info

* [chat] skip ray test

* [doc] fix typo

---------
Co-authored-by: default avatarcsric <59389055+CsRic@users.noreply.github.com>
Co-authored-by: default avatarcsric <richcsr256@gmail.com>
parent 41fb7236
......@@ -20,7 +20,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat
options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb
timeout-minutes: 30
defaults:
run:
......
import argparse
import os
import socket
from functools import partial
import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
from coati.ray.utils import (
get_actor_from_args,
get_critic_from_args,
get_receivers_per_sender,
get_reward_model_from_args,
get_strategy_from_args,
)
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_trainers),
'master_port': trainer_port,
'master_addr': master_addr
} for rank in range(args.num_trainers)]
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
'local_rank': '0',
'rank': '0',
'world_size': '1',
'master_port': maker_port,
'master_addr': master_addr
}
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
def model_fn():
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.critic_model,
config=critic_cfg).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
args.quant_group_size).cuda().requires_grad_(False)
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
kl_coef=0.1,
debug=args.debug,
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
eval_performance=True,
use_cache=True,
)
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
critic = get_critic_from_args(args.critic_model,
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
return actor, critic
# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=[
f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
env_info=env_info_trainer,
train_batch_size=args.train_batch_size,
buffer_limit=16,
eval_performance=True,
debug=args.debug,
) for i, env_info_trainer in enumerate(env_info_trainers)
]
dataset_size = args.experience_batch_size * 4
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
return {'input_ids': input_ids, 'attention_mask': attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
return dataloader
# uncomment this function if sync_models_from_trainers is True
# ray.get([
# trainer_ref.sync_models_to_remote_makers.remote()
# for trainer_ref in trainer_refs
# ])
wait_tasks = []
wait_tasks.append(
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
num_steps=args.experience_steps))
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='naive')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--train_epochs', type=int, default=1)
parser.add_argument('--update_steps', type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4)
parser.add_argument('--quant_group_size', type=int, default=128)
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
import argparse
import os
import socket
from functools import partial
import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
from coati.ray.utils import (
get_actor_from_args,
get_critic_from_args,
get_receivers_per_sender,
get_reward_model_from_args,
get_strategy_from_args,
)
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_trainers),
'master_port': trainer_port,
'master_addr': master_addr
} for rank in range(args.num_trainers)]
# maker_env_info
maker_port = str(get_free_port())
env_info_makers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_makers),
'master_port': maker_port,
'master_addr': master_addr
} for rank in range(args.num_makers)]
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
def model_fn():
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.critic_model,
config=critic_cfg).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
args.quant_group_size).cuda().requires_grad_(False)
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
f'trainer{x}'
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
kl_coef=0.1,
debug=args.debug,
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
eval_performance=True,
use_cache=True,
)
for i, env_info_maker in enumerate(env_info_makers)
]
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
critic = get_critic_from_args(args.critic_model,
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
return actor, critic
# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=[
f"maker{x}"
for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
env_info=env_info_trainer,
train_batch_size=args.train_batch_size,
buffer_limit=16,
eval_performance=True,
debug=args.debug,
)
for i, env_info_trainer in enumerate(env_info_trainers)
]
dataset_size = args.experience_batch_size * 4
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
return {'input_ids': input_ids, 'attention_mask': attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
return dataloader
# uncomment this function if sync_models_from_trainers is True
# ray.get([
# trainer_ref.sync_models_to_remote_makers.remote()
# for trainer_ref in trainer_refs
# ])
wait_tasks = []
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
num_steps=args.experience_steps))
total_steps = args.experience_batch_size * args.experience_steps * \
args.num_makers // (args.num_trainers * args.train_batch_size)
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num_makers', type=int, default=1)
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='naive')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--train_epochs', type=int, default=1)
parser.add_argument('--update_steps', type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4)
parser.add_argument('--quant_group_size', type=int, default=128)
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
......@@ -61,6 +61,12 @@ class LoraLinear(lora.LoRALayer, nn.Module):
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
# FIXME(csric): temporary fix
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
self.reset_parameters()
else:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
......
from .llama_gptq import load_quant as llama_load_quant
from .utils import low_resource_init
__all__ = [
'llama_load_quant',
'low_resource_init',
]
from .loader import load_quant
__all__ = [
'load_quant',
]
import torch
import torch.nn as nn
from .model_utils import find_layers
from .quant import make_quant
def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
model = model.eval()
layers = find_layers(model)
# ignore lm head
layers = find_layers(model)
for name in ['lm_head']:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize)
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
return model
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
import torch
import torch.nn as nn
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
return res
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/quant.py
import math
import numpy as np
import torch
import torch.nn as nn
def quantize(x, scale, zero, maxq):
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer('maxq', torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape))
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
def find_params(self, x, weight=False):
dev = x.device
self.maxq = self.maxq.to(dev)
shape = x.shape
if self.perchannel:
if weight:
x = x.flatten(1)
else:
if len(shape) == 4:
x = x.permute([1, 0, 2, 3])
x = x.flatten(1)
if len(shape) == 3:
x = x.reshape((-1, shape[-1])).t()
if len(shape) == 2:
x = x.t()
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
self.scale = (xmax - xmin) / self.maxq
if self.sym:
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
else:
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float('inf'), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
best[tmp] = err[tmp]
self.scale[tmp] = scale1[tmp]
self.zero[tmp] = zero1[tmp]
if not self.perchannel:
if weight:
tmp = shape[0]
else:
tmp = shape[1] if len(shape) != 3 else shape[2]
self.scale = self.scale.repeat(tmp)
self.zero = self.zero.repeat(tmp)
if weight:
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
self.zero = self.zero.reshape((1, -1, 1, 1))
if len(shape) == 3:
self.scale = self.scale.reshape((1, 1, -1))
self.zero = self.zero.reshape((1, 1, -1))
if len(shape) == 2:
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)
def quantize(self, x):
if self.ready():
return quantize(x, self.scale, self.zero, self.maxq)
return x
def enabled(self):
return self.maxq > 0
def ready(self):
return torch.all(self.scale != 0)
try:
import quant_cuda
except:
print('CUDA extension not installed.')
# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures):
super().__init__()
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))):
raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
groupsize = groupsize if groupsize != -1 else infeatures
self.groupsize = groupsize
self.register_buffer(
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
dtype=torch.int))
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
self.register_buffer('bias', torch.zeros(outfeatures))
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
self._initialized_quant_state = False
def pack(self, linear, scales, zeros):
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone()
if linear.bias is not None:
self.bias = linear.bias.clone()
intweight = []
for idx in range(self.infeatures):
g_idx = idx // self.groupsize
intweight.append(
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros((intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
elif self.bits == 3:
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i))
i += 10
qweight[row] |= intweight[i] << 30
row += 1
qweight[row] |= (intweight[i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 1)
i += 10
qweight[row] |= intweight[i] << 31
row += 1
qweight[row] |= (intweight[i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 2)
i += 10
row += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
elif self.bits == 3:
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
i += 10
qzeros[:, col] |= zeros[:, i] << 30
col += 1
qzeros[:, col] |= (zeros[:, i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
i += 10
qzeros[:, col] |= zeros[:, i] << 31
col += 1
qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
i += 10
col += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
intermediate_dtype = torch.float32
if not self._initialized_quant_state:
# Do we even have a bias? Check for at least one non-zero element.
if self.bias is not None and bool(torch.any(self.bias != 0)):
# Then make sure it's the right type.
self.bias.data = self.bias.data.to(intermediate_dtype)
else:
self.bias = None
outshape = list(x.shape)
outshape[-1] = self.outfeatures
x = x.reshape(-1, x.shape[-1])
if self.bias is None:
y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
else:
y = self.bias.clone().repeat(x.shape[0], 1)
output_dtype = x.dtype
x = x.to(intermediate_dtype)
if self.bits == 2:
quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
elif self.bits == 3:
quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
elif self.bits == 4:
quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
elif self.bits == 8:
quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
y = y.to(output_dtype)
return y.reshape(outshape)
def make_quant(module, names, bits, groupsize, name=''):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr
if name1 in names:
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
for name1, child in module.named_children():
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
from contextlib import contextmanager
import torch
def _noop(*args, **kwargs):
pass
@contextmanager
def low_resource_init():
"""This context manager disables weight initialization and sets the default float dtype to half.
"""
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
old_uniform_ = torch.nn.init.uniform_
old_normal_ = torch.nn.init.normal_
dtype = torch.get_default_dtype()
try:
torch.nn.init.kaiming_uniform_ = _noop
torch.nn.init.uniform_ = _noop
torch.nn.init.normal_ = _noop
torch.set_default_dtype(torch.half)
yield
finally:
torch.nn.init.kaiming_uniform_ = old_kaiming_uniform_
torch.nn.init.uniform_ = old_uniform_
torch.nn.init.normal_ = old_normal_
torch.set_default_dtype(dtype)
# Distributed PPO Training on Stage 3
## Detach Experience Makers and Trainers
We can completely separate the trainers and makers.
<p align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/basic_structure.png?raw=true" width=600/>
</p>
- The experience maker performs inference, produces experience, and remotely delivers it to the trainer (1).
- The trainer consumes experience to train models, and periodically transmits new model parameters to the maker (2.1, 2.2).
- Using an experience buffer to overlap transmission and computing.
In this manner, each node will work continuously without model idle time, and different optimization strategies can be applied for inference and training to meet the needs of speed or storage. It is also helpful for scalability.
`DetachedPPOTrainer` and `ExperienceMakerHolder` are Ray Actors (distinguished from Actor Model), representing Trainer and Experience Maker on the graph above, respectively.
[More about Ray Core](https://docs.ray.io/en/latest/ray-core/walkthrough.html)
## Usage
See examples at `ColossalAI/application/Chat/examples/ray`
### Setup Makers
- define makers' environment variables :
```python
env_info_makers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(num_makers),
'master_port': maker_port,
'master_addr': master_addr
} for rank in range(num_makers)]
```
- define maker models :
```python
def model_fn():
actor = get_actor_from_args(...)
critic = get_critic_from_args(...)
reward_model = get_reward_model_from_args(...)
initial_model = get_actor_from_args(...)
return actor, critic, reward_model, initial_model
```
- set experience_holder_refs :
```python
experience_holder_refs = [
ExperienceMakerHolder.options(
name=f"maker_{i}",
num_gpus=1,
max_concurrency=2
).remote(
detached_trainer_name_list=[f"trainer_{x}" for x in target_trainers(...)],
model_fn=model_fn,
...)
for i, env_info_maker in enumerate(env_info_makers)
]
```
The names in the `detached_trainer_name_list` refer to the target trainers that the maker should send experience to.
We set a trainer's name the same as a maker, by `.options(name="str")`. See below.
### Setup Trainers
- define trainers' environment variables :
```python
env_info_trainers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(num_trainers),
'master_port': trainer_port,
'master_addr': master_addr
} for rank in range(num_trainers)]
```
- define trainer models :
```python
def trainer_model_fn():
actor = get_actor_from_args(...)
critic = get_critic_from_args(...)
return actor, critic
```
- set trainer_refs :
```python
trainer_refs = [
DetachedPPOTrainer.options(
name=f"trainer{i}",
num_gpus=1,
max_concurrency=2
).remote(
experience_maker_holder_name_list=[f"maker{x}" for x in target_makers(...)],
model_fn = trainer_model_fn(),
...)
for i, env_info_trainer in enumerate(env_info_trainers)
]
```
The names in `experience_maker_holder_name_list` refer to the target makers that the trainer should send updated models to.
By setting `detached_trainer_name_list` and `experience_maker_holder_name_list`, we can customize the transmission graph.
### Launch Jobs
- define data_loader :
```python
def data_loader_fn():
return = torch.utils.data.DataLoader(dataset=dataset)
```
- launch makers :
```python
wait_tasks = []
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(
experience_holder_ref.workingloop.remote(data_loader_fn(),
num_steps=experience_steps))
```
- launch trainers :
```python
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, update_steps, train_epochs))
```
- wait for done :
```python
ray.get(wait_tasks)
```
## Flexible Structure
We can deploy different strategies to makers and trainers. Here are some notions.
### 2 Makers 1 Trainer
<p align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/2m1t.png?raw=true" width=600/>
</p>
### 2 Makers 2 Trainer
<p align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/2m2t.png?raw=true" width=600/>
</p>
### Maker Inference Quantization
<p align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/2m2t_quantize.png?raw=true" width=600/>
</p>
### Tensor Parallel
<p align="center">
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/tp_ddp_hybrid.png?raw=true" width=600/>
</p>
## TODO
- [ ] Support LoRA
- [ ] Support TP & PP
from .src.detached_replay_buffer import DetachedReplayBuffer
from .src.detached_trainer_ppo import DetachedPPOTrainer
from .base import MakerCallback, TrainerCallback
from .performance_evaluator import ExperienceMakerPerformanceEvaluator, TrainerPerformanceEvaluator
__all__ = [
"TrainerCallback",
"MakerCallback",
"ExperienceMakerPerformanceEvaluator",
"TrainerPerformanceEvaluator",
]
from abc import ABC
from coati.experience_maker import Experience
class TrainerCallback(ABC):
"""
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
pass
def on_fit_end(self) -> None:
pass
def on_episode_start(self, episode: int) -> None:
pass
def on_episode_end(self, episode: int) -> None:
pass
def on_epoch_start(self, epoch: int) -> None:
pass
def on_epoch_end(self, epoch: int) -> None:
pass
def on_batch_start(self) -> None:
pass
def on_batch_end(self, metrics: dict, experience: Experience) -> None:
pass
def on_update_start(self) -> None:
pass
def on_update_end(self) -> None:
pass
class MakerCallback(ABC):
def on_loop_start(self) -> None:
pass
def on_loop_end(self) -> None:
pass
def on_make_experience_start(self) -> None:
pass
def on_make_experience_end(self, experience: Experience) -> None:
pass
def on_send_start(self) -> None:
pass
def on_send_end(self) -> None:
pass
def on_batch_start(self) -> None:
pass
def on_batch_end(self) -> None:
pass
from time import time
from typing import Optional
import torch
import torch.distributed as dist
from coati.experience_maker import Experience
from .base import MakerCallback, TrainerCallback
def get_world_size() -> int:
if dist.is_initialized():
return dist.get_world_size()
return 1
def print_rank_0(*args, **kwargs) -> None:
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
def start(self) -> None:
self.start_time = time()
def end(self) -> None:
self.duration += time() - self.start_time
def reset(self) -> None:
self.duration = 0.
class ExperienceMakerPerformanceEvaluator(MakerCallback):
def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
reward_model_num_params: int) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
self.critic_num_params = critic_num_params
self.initial_model_num_params = initial_model_num_params
self.reward_model_num_params = reward_model_num_params
self.batch_timer = Timer()
self.send_timer = Timer()
self.make_experience_timer = Timer()
self.total_samples: int = 0
self.make_experience_flop: int = 0
print_rank_0(
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
)
def on_make_experience_start(self) -> None:
self.make_experience_timer.start()
def on_make_experience_end(self, experience: Experience) -> None:
self.make_experience_timer.end()
batch_size, seq_len = experience.sequences.shape
self.total_samples += batch_size
# actor generate
num_actions = experience.action_mask.size(1)
input_len = seq_len - num_actions
total_seq_len = (input_len + seq_len - 1) * num_actions / 2
self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
# actor forward
self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
# critic forward
self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
# initial model forward
self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
# reward model forward
self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
def on_send_start(self) -> None:
self.send_timer.start()
def on_send_end(self) -> None:
self.send_timer.end()
def on_batch_start(self) -> None:
self.batch_timer.start()
def on_batch_end(self) -> None:
self.batch_timer.end()
def on_loop_end(self) -> None:
avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_send_duration = all_reduce_mean(self.send_timer.duration, self.world_size)
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
(self.total_samples * self.world_size)
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
)
class TrainerPerformanceEvaluator(TrainerCallback):
def __init__(self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
self.critic_num_params = critic_num_params
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_first_episodes = ignore_first_episodes
self.ignore_this_episode = False
self.episode_timer = Timer()
self.batch_timer = Timer()
self.update_timer = Timer()
self.total_samples: int = 0
self.learn_flop: int = 0
print_rank_0(
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
)
def on_episode_start(self, episodes: int) -> None:
self.ignore_this_episode = episodes < self.ignore_first_episodes
if self.ignore_this_episode:
return
self.episode_timer.start()
def on_episode_end(self, episodes: int) -> None:
if self.ignore_this_episode:
return
self.episode_timer.end()
def on_batch_start(self) -> None:
if self.ignore_this_episode:
return
self.batch_timer.start()
def on_batch_end(self, metrics: dict, experience: Experience) -> None:
if self.ignore_this_episode:
return
self.batch_timer.end()
batch_size, seq_len = experience.sequences.shape
self.total_samples += batch_size
# actor forward-backward, 3 means forward(1) + backward(2)
self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
# critic forward-backward
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
def on_update_start(self) -> None:
if self.ignore_this_episode:
return
self.update_timer.start()
def on_update_end(self) -> None:
if self.ignore_this_episode:
return
self.update_timer.end()
def on_fit_end(self) -> None:
if self.total_samples == 0:
print_rank_0('No samples are collected, skip trainer performance evaluation')
return
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
avg_episode_duration = all_reduce_mean(self.episode_timer.duration, self.world_size)
avg_throughput = self.total_samples * self.world_size / (avg_episode_duration + 1e-12)
avg_learn_tflops = self.learn_flop / 1e12 / (avg_train_duration + 1e-12)
avg_time_per_sample = (avg_episode_duration + 1e-12) / (self.total_samples * self.world_size)
avg_train_time_per_sample = (avg_train_duration + 1e-12) / (self.total_samples * self.world_size)
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
)
import torch
import asyncio
import copy
import random
from typing import List, Any
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
from threading import Lock
from typing import Any, List
import ray
import asyncio
import torch
from coati.experience_maker.base import Experience
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.replay_buffer import ReplayBuffer
from threading import Lock
import copy
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
class DetachedReplayBuffer:
'''
......@@ -24,31 +26,25 @@ class DetachedReplayBuffer:
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
'''
def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit : int = 0, cpu_offload: bool = True) -> None:
self.cpu_offload = cpu_offload
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
self.sample_batch_size = sample_batch_size
self.limit = limit
self.items = Queue(self.limit, actor_options={"num_cpus":1})
self.batch_collector : List[BufferItem] = []
self.items = Queue(self.limit, actor_options={"num_cpus": 1})
self.batch_collector: List[BufferItem] = []
@torch.no_grad()
def append(self, experience: Experience) -> None:
'''
Workers in the same tp group share this buffer and need same sample for one step.
Therefore a held_sample should be returned tp_world_size times before it could be dropped.
worker_state records whether a worker got the held_sample
Expected to be called remotely.
'''
self.tp_world_size = tp_world_size
self.worker_state = [False] * self.tp_world_size
self.held_sample = None
self._worker_state_lock = Lock()
items = split_experience_batch(experience)
self.extend(items)
@torch.no_grad()
def append(self, experience: Experience) -> None:
def extend(self, items: List[BufferItem]) -> None:
'''
Expected to be called remotely.
'''
if self.cpu_offload:
experience.to_device(torch.device('cpu'))
items = split_experience_batch(experience)
self.batch_collector.extend(items)
while len(self.batch_collector) >= self.sample_batch_size:
items = self.batch_collector[:self.sample_batch_size]
......@@ -64,17 +60,8 @@ class DetachedReplayBuffer:
self.batch_collector = []
@torch.no_grad()
def sample(self, worker_rank = 0, to_device = "cpu") -> Experience:
self._worker_state_lock.acquire()
if not any(self.worker_state):
self.held_sample = self._sample_and_erase()
self.worker_state[worker_rank] = True
if all(self.worker_state):
self.worker_state = [False] * self.tp_world_size
ret = self.held_sample
else:
ret = copy.deepcopy(self.held_sample)
self._worker_state_lock.release()
def sample(self, worker_rank=0, to_device="cpu") -> Experience:
ret = self._sample_and_erase()
ret.to_device(to_device)
return ret
......
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from tqdm import tqdm
from coati.trainer.callbacks import Callback
from coati.experience_maker import Experience
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import ray
import os
import torch
from coati.experience_maker import Experience
from coati.replay_buffer.utils import BufferItem
from torch.utils.data import DataLoader
from tqdm import tqdm
from .callbacks import TrainerCallback
from .detached_replay_buffer import DetachedReplayBuffer
from .utils import is_rank_0
class DetachedTrainer(ABC):
'''
Base class for detached rlhf trainers.
......@@ -19,87 +24,116 @@ class DetachedTrainer(ABC):
Args:
detached_strategy (DetachedStrategy): the strategy to use for training
detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
def __init__(self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
experience_batch_size: int = 8,
max_epochs: int = 1,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
callbacks: List[TrainerCallback] = [],
debug: bool = False) -> None:
super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit, cpu_offload=buffer_cpu_offload)
self.experience_batch_size = experience_batch_size
self.max_epochs = max_epochs
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
self.generate_kwargs = generate_kwargs
self.target_holder_name_list = experience_maker_holder_name_list
self.target_holder_list = []
self._is_target_holder_initialized = False
self._debug = debug
def update_target_holder_list(self, experience_maker_holder_name_list):
self.target_holder_name_list = experience_maker_holder_name_list
self.target_holder_list = []
def update_target_holder_list(self):
# as the length of target_holder_list may be zero, we need to check it by a bool flag
if not self._is_target_holder_initialized:
for name in self.target_holder_name_list:
self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
self._is_target_holder_initialized = True
@abstractmethod
def _update_remote_makers(self):
def _update_remote_makers(self, fully_update: bool = False, **kwargs):
pass
def sync_models_to_remote_makers(self, **kwargs):
self._update_remote_makers(fully_update=True, **kwargs)
@abstractmethod
def training_step(self, experience: Experience) -> Dict[str, Any]:
pass
def _learn(self):
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for _ in pbar:
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print("[trainer] sampling exp")
experience = self._buffer_sample()
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
def _learn(self, update_steps: int, train_epochs: int) -> None:
data = []
# warmup
pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
self._on_epoch_start(0)
self._learn_epoch(pbar, data)
self._on_epoch_end(0)
# item is already a batch
dataloader = DataLoader(data,
batch_size=1,
shuffle=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=lambda x: x[0])
for epoch in range(1, train_epochs):
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
self._on_epoch_start(epoch)
self._learn_epoch(pbar, data)
self._on_epoch_end(epoch)
def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
is_warmup = len(data) == 0
for x in pbar:
if self._debug:
print("[trainer] training step")
# sample a batch and then train to avoid waiting
experience = x if not is_warmup else self._buffer_sample()
experience.to_device(torch.cuda.current_device())
self._on_batch_start()
metrics = self.training_step(experience)
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
self._on_batch_end(metrics, experience)
if self._debug:
print("[trainer] step over")
experience.to_device("cpu")
if is_warmup:
data.append(experience)
pbar.set_postfix(metrics)
def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
self._on_fit_start()
for episode in range(num_episodes):
self._on_episode_start(episode)
for timestep in tqdm(range(max_timesteps // update_timesteps),
desc=f'Episode [{episode+1}/{num_episodes}]',
disable=not is_rank_0()):
self._learn()
for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
self._on_episode_start(i)
self._learn(update_steps, train_epochs)
self._on_update_start()
self._update_remote_makers()
self._on_episode_end(episode)
self._on_update_end()
self._on_episode_end(i)
self._on_fit_end()
@ray.method(concurrency_group="buffer_length")
def buffer_get_length(self):
# called by ExperienceMakerHolder
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
if self._debug:
print("[trainer] telling length")
return self.detached_replay_buffer.get_length()
@ray.method(concurrency_group="buffer_append")
def buffer_append(self, experience: Experience):
# called by ExperienceMakerHolder
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
# print(f"[trainer] receiving exp. Current buffer length: {self.detached_replay_buffer.get_length()}")
if self._debug:
print(f"[trainer] receiving exp.")
self.detached_replay_buffer.append(experience)
@ray.method(concurrency_group="buffer_append")
def buffer_extend(self, items: List[BufferItem]):
# called by ExperienceMakerHolder
if self._debug:
print(f"[trainer] receiving exp.")
self.detached_replay_buffer.extend(items)
@ray.method(concurrency_group="buffer_sample")
def _buffer_sample(self):
return self.detached_replay_buffer.sample()
......@@ -119,3 +153,27 @@ class DetachedTrainer(ABC):
def _on_episode_end(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_end(episode)
def _on_epoch_start(self, epoch: int) -> None:
for callback in self.callbacks:
callback.on_epoch_start(epoch)
def _on_epoch_end(self, epoch: int) -> None:
for callback in self.callbacks:
callback.on_epoch_end(epoch)
def _on_batch_start(self) -> None:
for callback in self.callbacks:
callback.on_batch_start()
def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_batch_end(metrics, experience)
def _on_update_start(self) -> None:
for callback in self.callbacks:
callback.on_update_start()
def _on_update_end(self) -> None:
for callback in self.callbacks:
callback.on_update_end()
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.optim import Adam
from typing import Any, Callable, Dict, List, Optional, Tuple
import ray
import torch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.generation_utils import update_model_kwargs_fn
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
from torch.optim import Adam
from colossalai.nn.optimizer import HybridAdam
import ray
from .utils import is_rank_0, get_cuda_actor_critic_from_args, get_strategy_from_args, set_dist_env
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append":1, "buffer_sample":1,"model_io": 1, "compute": 1})
from .lora_constructor import LoRAConstructor
from .utils import (
get_actor_from_args,
get_critic_from_args,
get_model_numel,
get_rank,
get_strategy_from_args,
is_rank_0,
set_dist_env,
state_dict_to,
)
@ray.remote(concurrency_groups={
"buffer_length": 1,
"buffer_append": 1,
"buffer_sample": 1,
"model_io": 1,
"compute": 1
})
class DetachedPPOTrainer(DetachedTrainer):
'''
Detached Trainer for PPO algorithm
......@@ -40,86 +54,102 @@ class DetachedPPOTrainer(DetachedTrainer):
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
def __init__(self,
def __init__(
self,
experience_maker_holder_name_list: List[str],
strategy: str,
model: str,
strategy_fn: Callable[[], Strategy],
model_fn: Callable[[], Tuple[Actor, Critic]],
env_info: Dict[str, str] = None,
pretrained: str = None,
lora_rank: int = 0,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
value_clip: float = 0.4,
experience_batch_size: int = 8,
max_epochs: int = 10,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
callbacks: List[TrainerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
) -> None:
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
# configure strategy
self.strategy = get_strategy_from_args(strategy)
self.strategy = strategy_fn()
# configure models, loss and optimizers
with self.strategy.model_init_context():
self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank)
self.actor, self.critic = model_fn()
if strategy != 'colossalai_gemini':
self.actor.to(torch.float16).to(torch.cuda.current_device())
self.critic.to(torch.float16).to(torch.cuda.current_device())
if eval_performance:
actor_numel = get_model_numel(self.actor)
critic_numel = get_model_numel(self.critic)
evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
callbacks = callbacks + [evaluator]
if strategy.startswith('colossalai'):
self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6)
self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6)
if isinstance(self.strategy, ColossalAIStrategy):
self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
else:
self.actor_optim = Adam(self.actor.parameters(), lr=5e-6)
self.critic_optim = Adam(self.critic.parameters(), lr=5e-6)
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
# configure trainer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
super().__init__(experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
buffer_cpu_offload=buffer_cpu_offload,
experience_batch_size=experience_batch_size,
max_epochs=max_epochs,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
**generate_kwargs)
debug=debug)
if self._debug:
print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
self._update_lora_weights = update_lora_weights
@ray.method(concurrency_group="model_io")
def _update_remote_makers(self):
@torch.no_grad()
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if is_rank_0():
self.update_target_holder_list(self.target_holder_name_list)
if not fully_update:
config['requires_grad_only'] = True
self.update_target_holder_list()
# mark start, ensure order
tasks = []
for target_holder in self.target_holder_list:
# TODO: reduce malloc
with torch.no_grad():
ray.get(target_holder.update_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
ray.get(tasks)
# sending loop
tasks = []
@ray.method(concurrency_group="model_io")
def initialize_remote_makers(self):
# TODO: balance duties
if is_rank_0():
self.update_target_holder_list(self.target_holder_name_list)
for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config):
for target_holder in self.target_holder_list:
# TODO: reduce malloc
with torch.no_grad():
ray.get(target_holder.initialize_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
tasks.append(
target_holder.update_experience_maker.remote(
new_actor_state_dict=state_dict_shard,
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
fully_update=fully_update))
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
for target_holder in self.target_holder_list:
tasks.append(
target_holder.update_experience_maker.remote(
new_critic_state_dict=state_dict_shard,
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
fully_update=fully_update))
ray.get(tasks)
# mark end
for target_holder in self.target_holder_list:
target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)
@ray.method(concurrency_group="compute")
def training_step(self, experience: Experience) -> Dict[str, float]:
self.actor.train()
self.critic.train()
experience.to_device(torch.cuda.current_device())
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs,
......@@ -155,38 +185,16 @@ class DetachedPPOTrainer(DetachedTrainer):
def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_optimizer(self.critic_optim, path, only_rank0)
def _get_unwrapped_actor(self):
if False:
pass
elif isinstance(self.strategy, ColossalAIStrategy):
ret = Actor(self.strategy._unwrap_model(self.actor))
return ret
elif isinstance(self.strategy, DDPStrategy):
return Actor(self.strategy._unwrap_actor(self.actor))
elif isinstance(self.strategy, NaiveStrategy):
return self.actor
def _get_unwrapped_critic(self):
if False:
pass
elif isinstance(self.strategy, ColossalAIStrategy):
ret = self.strategy._unwrap_model(self.critic)
return ret
elif isinstance(self.strategy, DDPStrategy):
return self.critic.module
elif isinstance(self.strategy, NaiveStrategy):
return self.critic
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy._unwrap_actor(actor)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs:
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
return new_kwargs
\ No newline at end of file
def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update=False, **config):
for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
if not self._update_lora_weights or fully_update:
yield state_dict_to(state_dict)
else:
state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict)
yield state_dict_to(state_dict_lora)
def _get_model_lora_config_dict(self, model: torch.nn.Module):
if not self._update_lora_weights:
return None
unwrapped_model = self.strategy.unwrap_model(model)
return LoRAConstructor.extract_lora_config(unwrapped_model)
import argparse
from copy import deepcopy
import pandas as pd
import torch
from coati.trainer import PPOTrainer
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.experience_maker import NaiveExperienceMaker
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
import ray
import os
import socket
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainer = {'local_rank' : '0',
'rank' : '0',
'world_size' : '1',
'master_port' : trainer_port,
'master_addr' : master_addr}
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {'local_rank' : '0',
'rank' : '0',
'world_size' : '1',
'master_port' : maker_port,
'master_addr' : master_addr}
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure Trainer
trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=["maker1"],
strategy=args.trainer_strategy,
model=args.model,
env_info = env_info_trainer,
pretrained=args.pretrain,
lora_rank=args.lora_rank,
train_batch_size=args.train_batch_size,
buffer_limit=16,
experience_batch_size=args.experience_batch_size,
max_epochs=args.max_epochs,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=["trainer1"],
strategy=args.maker_strategy,
env_info = env_info_maker,
experience_batch_size=args.experience_batch_size,
kl_coef=0.1,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
# trainer send its actor and critic to experience holders.
ray.get(trainer_ref.initialize_remote_makers.remote())
# configure sampler
dataset = pd.read_csv(args.prompt_path)['prompt']
def tokenize_fn(texts):
# MUST padding to max length to ensure inputs of all ranks have the same length
# Different length may lead to hang when using gemini, as different generation steps
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
return {k: v.cuda() for k, v in batch.items()}
trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance
maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
ray.get([trainer_done_ref, maker_done_ref])
# save model checkpoint after fitting
trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('prompt_path')
parser.add_argument('--trainer_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--maker_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"])
main(args)
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
export RAY_NAMESPACE="admin"
python 1m1t.py "/path/to/prompts.csv" \
--trainer_strategy colossalai_zero2 --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \
--num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
--max_epochs 10 --debug
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment