"lib/runtime/src/vscode:/vscode.git/clone" did not exist on "096699c4b6049f7181c025f002ee9974dc998b7d"
Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
import argparse import argparse
import os import os
import loralib as lora
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMLM
from coati.models.gpt import GPTLM
from coati.models.llama import LlamaLM
from coati.models.opt import OPTLM
from coati.trainer import SFTTrainer from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from easy_dataset import EasyDataset from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam from torch.optim import Adam
...@@ -29,75 +21,76 @@ from colossalai.tensor import ColoParameter ...@@ -29,75 +21,76 @@ from colossalai.tensor import ColoParameter
def train(args): def train(args):
# configure strategy # configure strategy
if args.strategy == 'ddp': if args.strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini': elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy='cuda') strategy = GeminiStrategy(placement_policy="cuda")
elif args.strategy == 'colossalai_zero2': elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else: else:
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model # configure model
with strategy.model_init_context(): with strategy.model_init_context():
print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested') print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested")
model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device()) model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
# if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json # if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \ if (
and os.path.exists(args.save_path + '/adapter_model.bin'): os.path.exists(args.save_path)
and os.path.exists(args.save_path + "/adapter_config.json")
and os.path.exists(args.save_path + "/adapter_model.bin")
):
print("loading from saved peft model ", args.save_path) print("loading from saved peft model ", args.save_path)
model = PeftModel.from_pretrained(model, args.save_path) model = PeftModel.from_pretrained(model, args.save_path)
else: else:
# we'll use peft lora library to do the lora # we'll use peft lora library to do the lora
lora_rank = args.lora_rank if args.lora_rank > 0 else 32 lora_rank = args.lora_rank if args.lora_rank > 0 else 32
# config lora with rank of lora_rank # config lora with rank of lora_rank
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, lora_config = LoraConfig(
inference_mode=False, task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
r=lora_rank, )
lora_alpha=32,
lora_dropout=0.1)
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)
model.print_trainable_parameters() model.print_trainable_parameters()
# configure tokenizer # configure tokenizer
if args.model == 'gpt2': if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom': elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt': elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama': elif args.model == "llama":
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.pretrain, args.pretrain,
padding_side="right", padding_side="right",
use_fast=False, use_fast=False,
) )
tokenizer.eos_token = '<\s>' tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
if args.model == 'llama' and args.strategy == 'colossalai_gemini': if args.model == "llama" and args.strategy == "colossalai_gemini":
# this is a hack to deal with the resized embedding # this is a hack to deal with the resized embedding
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if not isinstance(param, ColoParameter): if not isinstance(param, ColoParameter):
sub_module_name = '.'.join(name.split('.')[:-1]) sub_module_name = ".".join(name.split(".")[:-1])
weight_name = name.split('.')[-1] weight_name = name.split(".")[-1]
sub_module = model.get_submodule(sub_module_name) sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param)) setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer # configure optimizer
if args.strategy.startswith('colossalai'): if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else: else:
optim = Adam(model.parameters(), lr=args.lr) optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger() logger = get_dist_logger()
logger.set_level('WARNING') logger.set_level("WARNING")
# configure dataset # configure dataset
law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
...@@ -108,47 +101,57 @@ def train(args): ...@@ -108,47 +101,57 @@ def train(args):
eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
data_collator = default_collate data_collator = default_collate
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset, train_sampler = DistributedSampler(
shuffle=True, train_dataset,
seed=42, shuffle=True,
drop_last=True, seed=42,
rank=dist.get_rank(), drop_last=True,
num_replicas=dist.get_world_size()) rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
if eval_dataset is not None: if eval_dataset is not None:
eval_sampler = DistributedSampler(eval_dataset, eval_sampler = DistributedSampler(
shuffle=False, eval_dataset,
seed=42, shuffle=False,
drop_last=False, seed=42,
rank=dist.get_rank(), drop_last=False,
num_replicas=dist.get_world_size()) rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
else: else:
train_sampler = None train_sampler = None
eval_sampler = None eval_sampler = None
train_dataloader = DataLoader(train_dataset, train_dataloader = DataLoader(
shuffle=(train_sampler is None), train_dataset,
sampler=train_sampler, shuffle=(train_sampler is None),
batch_size=args.batch_size, sampler=train_sampler,
collate_fn=data_collator, batch_size=args.batch_size,
pin_memory=True) collate_fn=data_collator,
pin_memory=True,
)
if eval_dataset is not None: if eval_dataset is not None:
eval_dataloader = DataLoader(eval_dataset, eval_dataloader = DataLoader(
shuffle=(eval_sampler is None), eval_dataset,
sampler=eval_sampler, shuffle=(eval_sampler is None),
batch_size=args.batch_size, sampler=eval_sampler,
collate_fn=data_collator, batch_size=args.batch_size,
pin_memory=True) collate_fn=data_collator,
pin_memory=True,
)
else: else:
eval_dataloader = None eval_dataloader = None
trainer = SFTTrainer(model=model, trainer = SFTTrainer(
strategy=strategy, model=model,
optim=optim, strategy=strategy,
train_dataloader=train_dataloader, optim=optim,
eval_dataloader=eval_dataloader, train_dataloader=train_dataloader,
batch_size=args.batch_size, eval_dataloader=eval_dataloader,
max_epochs=args.max_epochs, batch_size=args.batch_size,
accumulation_steps=args.accumulation_steps) max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
)
trainer.fit(logger=logger, log_interval=args.log_interval) trainer.fit(logger=logger, log_interval=args.log_interval)
...@@ -156,29 +159,27 @@ def train(args): ...@@ -156,29 +159,27 @@ def train(args):
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer) trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
if args.need_optim_ckpt: if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer, strategy.save_optimizer(
'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
only_rank0=False) )
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--strategy', parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
default='ddp') parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') parser.add_argument("--dataset", type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--eval_dataset", type=str, default=None)
parser.add_argument('--dataset', type=str, default=None) parser.add_argument("--save_path", type=str, default="output")
parser.add_argument('--eval_dataset', type=str, default=None) parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument('--save_path', type=str, default='output') parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument('--max_epochs', type=int, default=3) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--batch_size', type=int, default=4) parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument('--lr', type=float, default=5e-6) parser.add_argument("--enable_peft_lora", action="store_true", default=False)
parser.add_argument('--accumulation_steps', type=int, default=8) parser.add_argument("--is_short_text", action="store_true", default=False)
parser.add_argument('--enable_peft_lora', action='store_true', default=False)
parser.add_argument("--is_short_text", action='store_true', default=False)
args = parser.parse_args() args = parser.parse_args()
train(args) train(args)
...@@ -6,16 +6,25 @@ from ray.job_submission import JobSubmissionClient ...@@ -6,16 +6,25 @@ from ray.job_submission import JobSubmissionClient
def main(api_server_endpoint="http://127.0.0.1:8265"): def main(api_server_endpoint="http://127.0.0.1:8265"):
client = JobSubmissionClient(api_server_endpoint) client = JobSubmissionClient(api_server_endpoint)
client.submit_job( client.submit_job(
entrypoint= entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
"python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
runtime_env={ runtime_env={
"working_dir": "working_dir": "applications/Chat",
"applications/Chat",
"pip": [ "pip": [
"torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain", "torch==1.13.1",
"tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat" "transformers>=4.20.1",
] "datasets",
}) "loralib",
"colossalai>=0.2.4",
"langchain",
"tokenizers",
"fastapi",
"sse_starlette",
"wandb",
"sentencepiece",
"gpustat",
],
},
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -26,9 +26,14 @@ from colossalai.nn.optimizer import HybridAdam ...@@ -26,9 +26,14 @@ from colossalai.nn.optimizer import HybridAdam
class ExperienceCompositionRefs: class ExperienceCompositionRefs:
def __init__(
def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef, self,
base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None: sequences_attention_mask_action_mask_ref: ray.ObjectRef,
action_log_probs_ref: ray.ObjectRef,
base_action_log_probs_ref: ray.ObjectRef,
value_ref: ray.ObjectRef,
r_ref: ray.ObjectRef,
) -> None:
self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref
self.action_log_probs_ref = action_log_probs_ref self.action_log_probs_ref = action_log_probs_ref
self.base_action_log_probs_ref = base_action_log_probs_ref self.base_action_log_probs_ref = base_action_log_probs_ref
...@@ -37,14 +42,14 @@ class ExperienceCompositionRefs: ...@@ -37,14 +42,14 @@ class ExperienceCompositionRefs:
class ExperienceMaker: class ExperienceMaker:
def __init__(self, kl_coef) -> None: def __init__(self, kl_coef) -> None:
self.kl_coef = kl_coef self.kl_coef = kl_coef
@torch.no_grad() @torch.no_grad()
def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs): def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs):
sequences, attention_mask, action_mask = ray.get( sequences, attention_mask, action_mask = ray.get(
experiment_computation_refs.sequences_attention_mask_action_mask_ref) experiment_computation_refs.sequences_attention_mask_action_mask_ref
)
action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref) action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref)
base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref) base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref)
r = ray.get(experiment_computation_refs.r_ref) r = ray.get(experiment_computation_refs.r_ref)
...@@ -58,11 +63,10 @@ class ExperienceMaker: ...@@ -58,11 +63,10 @@ class ExperienceMaker:
class DistributedTorchRayActor: class DistributedTorchRayActor:
def __init__(self, world_size, rank, local_rank, master_addr, master_port): def __init__(self, world_size, rank, local_rank, master_addr, master_port):
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', logging.basicConfig(
level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
datefmt='%Y-%m-%d %H:%M:%S') )
self._model = None self._model = None
self._world_size = world_size self._world_size = world_size
self._rank = rank self._rank = rank
...@@ -82,7 +86,7 @@ class DistributedTorchRayActor: ...@@ -82,7 +86,7 @@ class DistributedTorchRayActor:
@staticmethod @staticmethod
def _get_free_port(): def _get_free_port():
with socket.socket() as sock: with socket.socket() as sock:
sock.bind(('', 0)) sock.bind(("", 0))
return sock.getsockname()[1] return sock.getsockname()[1]
def get_master_addr_port(self): def get_master_addr_port(self):
...@@ -90,7 +94,6 @@ class DistributedTorchRayActor: ...@@ -90,7 +94,6 @@ class DistributedTorchRayActor:
class BasePPORole(DistributedTorchRayActor): class BasePPORole(DistributedTorchRayActor):
def add_experience_maker(self, kl_coef: float = 0.1): def add_experience_maker(self, kl_coef: float = 0.1):
self._experience_maker = ExperienceMaker(kl_coef) self._experience_maker = ExperienceMaker(kl_coef)
...@@ -99,12 +102,12 @@ class BasePPORole(DistributedTorchRayActor): ...@@ -99,12 +102,12 @@ class BasePPORole(DistributedTorchRayActor):
def _init_strategy(self, strategy: str): def _init_strategy(self, strategy: str):
# configure strategy # configure strategy
if strategy == 'ddp': if strategy == "ddp":
self._strategy = DDPStrategy() self._strategy = DDPStrategy()
elif strategy == 'colossalai_gemini': elif strategy == "colossalai_gemini":
self._strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif strategy == 'colossalai_zero2': elif strategy == "colossalai_zero2":
self._strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else: else:
raise ValueError(f'Unsupported strategy "{strategy}"') raise ValueError(f'Unsupported strategy "{strategy}"')
...@@ -124,11 +127,9 @@ class BasePPORole(DistributedTorchRayActor): ...@@ -124,11 +127,9 @@ class BasePPORole(DistributedTorchRayActor):
def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str): def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str):
raise NotImplementedError() raise NotImplementedError()
def init_model_from_pretrained(self, def init_model_from_pretrained(
strategy: str, self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False
model_class: Type[LoRAModule], ):
pretrain: str,
has_optimizer=False):
self._init_strategy(strategy) self._init_strategy(strategy)
self._load_model_from_pretrained(model_class, pretrain) self._load_model_from_pretrained(model_class, pretrain)
self._prepare_model_with_strategy(has_optimizer) self._prepare_model_with_strategy(has_optimizer)
...@@ -138,7 +139,6 @@ class BasePPORole(DistributedTorchRayActor): ...@@ -138,7 +139,6 @@ class BasePPORole(DistributedTorchRayActor):
class TrainablePPORole(BasePPORole): class TrainablePPORole(BasePPORole):
def _load_model_from_pretrained(self, model_class, pretrain): def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context(): with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device()) self._model = model_class(pretrain).to(torch.cuda.current_device())
...@@ -161,38 +161,39 @@ class TrainablePPORole(BasePPORole): ...@@ -161,38 +161,39 @@ class TrainablePPORole(BasePPORole):
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
class RayPPOActor(TrainablePPORole): class RayPPOActor(TrainablePPORole):
def set_loss_function(self, eps_clip: float): def set_loss_function(self, eps_clip: float):
self._actor_loss_fn = PolicyLoss(eps_clip) self._actor_loss_fn = PolicyLoss(eps_clip)
def load_tokenizer_from_pretrained(self, model_type: str, pretrained): def load_tokenizer_from_pretrained(self, model_type: str, pretrained):
if model_type == 'gpt2': if model_type == "gpt2":
self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained) self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
elif model_type == 'bloom': elif model_type == "bloom":
self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained) self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
elif model_type == 'opt': elif model_type == "opt":
self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained) self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained)
else: else:
raise ValueError(f'Unsupported model "{model_type}"') raise ValueError(f'Unsupported model "{model_type}"')
# Set tokenize function for sequence generation # Set tokenize function for sequence generation
def _text_input_tokenize_fn(texts): def _text_input_tokenize_fn(texts):
batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True) batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True)
return {k: v.cuda() for k, v in batch.items()} return {k: v.cuda() for k, v in batch.items()}
self._sample_tokenize_function = _text_input_tokenize_fn self._sample_tokenize_function = _text_input_tokenize_fn
def setup_generate_kwargs(self, generate_kwargs: dict): def setup_generate_kwargs(self, generate_kwargs: dict):
from coati.trainer.ppo import _set_default_generate_kwargs from coati.trainer.ppo import _set_default_generate_kwargs
self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model) self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model)
self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id
self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id
def load_csv_prompt_file_from_url_to_sampler(self, prompt_url): def load_csv_prompt_file_from_url_to_sampler(self, prompt_url):
import pandas as pd import pandas as pd
prompts = pd.read_csv(prompt_url)['prompt']
prompts = pd.read_csv(prompt_url)["prompt"]
self._sampler = self._strategy.setup_sampler(prompts) self._sampler = self._strategy.setup_sampler(prompts)
def _generate(self, input_ids, **generate_kwargs): def _generate(self, input_ids, **generate_kwargs):
...@@ -214,10 +215,9 @@ class RayPPOActor(TrainablePPORole): ...@@ -214,10 +215,9 @@ class RayPPOActor(TrainablePPORole):
def _training_step(self, experience): def _training_step(self, experience):
num_actions = experience.action_mask.size(1) num_actions = experience.action_mask.size(1)
action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask) action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self._actor_loss_fn(action_log_probs, actor_loss = self._actor_loss_fn(
experience.action_log_probs, action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
experience.advantages, )
action_mask=experience.action_mask)
self._strategy.backward(actor_loss, self._model, self._optimizer) self._strategy.backward(actor_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer) self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad() self._optimizer.zero_grad()
...@@ -229,17 +229,18 @@ class RayPPOActor(TrainablePPORole): ...@@ -229,17 +229,18 @@ class RayPPOActor(TrainablePPORole):
self._strategy.save_model(self._model, save_path, only_rank0=True) self._strategy.save_model(self._model, save_path, only_rank0=True)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
if should_save_optimizer: if should_save_optimizer:
self._strategy.save_optimizer(self._optimizer, self._strategy.save_optimizer(
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), self._optimizer,
only_rank0=False) "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()),
only_rank0=False,
)
def generate_answer(self, prompt, max_length=30, num_return_sequences=5): def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
encoded_input = self._model_tokenizer(prompt, return_tensors='pt') encoded_input = self._model_tokenizer(prompt, return_tensors="pt")
input_ids = {k: v.cuda() for k, v in encoded_input.items()} input_ids = {k: v.cuda() for k, v in encoded_input.items()}
sequence, _ = self._model.generate(**input_ids, sequence, _ = self._model.generate(
max_length=max_length, **input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences
return_action_mask=False, )
num_return_sequences=num_return_sequences)
token_list = list(sequence.data[0]) token_list = list(sequence.data[0])
output = " ".join([self._model_tokenizer.decode(token) for token in token_list]) output = " ".join([self._model_tokenizer.decode(token) for token in token_list])
return output return output
...@@ -247,18 +248,16 @@ class RayPPOActor(TrainablePPORole): ...@@ -247,18 +248,16 @@ class RayPPOActor(TrainablePPORole):
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
class RayPPOCritic(TrainablePPORole): class RayPPOCritic(TrainablePPORole):
def set_loss_function(self, value_clip: float): def set_loss_function(self, value_clip: float):
self._critic_loss_fn = ValueLoss(value_clip) self._critic_loss_fn = ValueLoss(value_clip)
def _training_step(self, experience): def _training_step(self, experience):
values = self._model(experience.sequences, values = self._model(
action_mask=experience.action_mask, experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
attention_mask=experience.attention_mask) )
critic_loss = self._critic_loss_fn(values, critic_loss = self._critic_loss_fn(
experience.values, values, experience.values, experience.reward, action_mask=experience.action_mask
experience.reward, )
action_mask=experience.action_mask)
self._strategy.backward(critic_loss, self._model, self._optimizer) self._strategy.backward(critic_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer) self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad() self._optimizer.zero_grad()
...@@ -272,12 +271,12 @@ class RayPPOCritic(TrainablePPORole): ...@@ -272,12 +271,12 @@ class RayPPOCritic(TrainablePPORole):
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
class RayPPORewardModel(BasePPORole): class RayPPORewardModel(BasePPORole):
def _load_model_from_pretrained(self, model_class, pretrain): def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context(): with self._strategy.model_init_context():
critic = model_class(pretrained=pretrain).to(torch.cuda.current_device()) critic = model_class(pretrained=pretrain).to(torch.cuda.current_device())
self._model = RewardModel(deepcopy(critic.model), self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(
deepcopy(critic.value_head)).to(torch.cuda.current_device()) torch.cuda.current_device()
)
@torch.no_grad() @torch.no_grad()
def calculate_r(self, sequence_attention_action_mask): def calculate_r(self, sequence_attention_action_mask):
...@@ -287,7 +286,6 @@ class RayPPORewardModel(BasePPORole): ...@@ -287,7 +286,6 @@ class RayPPORewardModel(BasePPORole):
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
class RayPPOInitialModel(BasePPORole): class RayPPOInitialModel(BasePPORole):
def _load_model_from_pretrained(self, model_class, pretrain): def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context(): with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device()) self._model = model_class(pretrain).to(torch.cuda.current_device())
...@@ -300,8 +298,8 @@ class RayPPOInitialModel(BasePPORole): ...@@ -300,8 +298,8 @@ class RayPPOInitialModel(BasePPORole):
class PPORayActorGroup: class PPORayActorGroup:
""" """
A group of ray actors A group of ray actors
Functions start with 'async' should return list of object refs Functions start with 'async' should return list of object refs
""" """
def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None: def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None:
...@@ -319,8 +317,9 @@ class PPORayActorGroup: ...@@ -319,8 +317,9 @@ class PPORayActorGroup:
pg = placement_group(bundles, strategy="STRICT_SPREAD") pg = placement_group(bundles, strategy="STRICT_SPREAD")
ray.get(pg.ready()) ray.get(pg.ready())
if pg: if pg:
master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( master_actor = self.ray_actor_type.options(
placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None) scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
).remote(world_size, 0, 0, None, None)
else: else:
master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None) master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None)
self._actor_handlers = [master_actor] self._actor_handlers = [master_actor]
...@@ -331,16 +330,20 @@ class PPORayActorGroup: ...@@ -331,16 +330,20 @@ class PPORayActorGroup:
for rank in range(1, world_size): for rank in range(1, world_size):
local_rank = rank % self._num_gpus_per_node local_rank = rank % self._num_gpus_per_node
if pg: if pg:
worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( worker_actor = self.ray_actor_type.options(
placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote( scheduling_strategy=PlacementGroupSchedulingStrategy(
world_size, rank, local_rank, master_addr, master_port) placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node
)
).remote(world_size, rank, local_rank, master_addr, master_port)
else: else:
worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank, worker_actor = self.ray_actor_type.options(num_gpus=1).remote(
master_addr, master_port) world_size, rank, local_rank, master_addr, master_port
)
self._actor_handlers.append(worker_actor) self._actor_handlers.append(worker_actor)
def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str, def async_init_model_from_pretrained(
has_optimizer: bool): self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool
):
return [ return [
actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer) actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer)
for actor in self._actor_handlers for actor in self._actor_handlers
...@@ -348,7 +351,6 @@ class PPORayActorGroup: ...@@ -348,7 +351,6 @@ class PPORayActorGroup:
class TrainableModelRayActorGroup(PPORayActorGroup): class TrainableModelRayActorGroup(PPORayActorGroup):
def async_learn_on_experiences(self, experience_refs): def async_learn_on_experiences(self, experience_refs):
num_actors = len(self._actor_handlers) num_actors = len(self._actor_handlers)
learn_result_refs = [] learn_result_refs = []
...@@ -359,7 +361,6 @@ class TrainableModelRayActorGroup(PPORayActorGroup): ...@@ -359,7 +361,6 @@ class TrainableModelRayActorGroup(PPORayActorGroup):
class PPOActorRayActorGroup(TrainableModelRayActorGroup): class PPOActorRayActorGroup(TrainableModelRayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None: def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOActor) super().__init__(num_nodes, num_gpus_per_node, RayPPOActor)
...@@ -381,7 +382,8 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup): ...@@ -381,7 +382,8 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):
action_log_probs_refs = [] action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)): for i in range(len(sequences_attention_mask_action_mask_refs)):
action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote( action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote(
sequences_attention_mask_action_mask_refs[i]) sequences_attention_mask_action_mask_refs[i]
)
action_log_probs_refs.append(action_log_probs_ref) action_log_probs_refs.append(action_log_probs_ref)
return action_log_probs_refs return action_log_probs_refs
...@@ -393,7 +395,6 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup): ...@@ -393,7 +395,6 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):
class PPOCriticRayActorGroup(TrainableModelRayActorGroup): class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None: def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic) super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic)
...@@ -402,7 +403,8 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup): ...@@ -402,7 +403,8 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
value_refs = [] value_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)): for i in range(len(sequences_attention_mask_action_mask_refs)):
value_ref = self._actor_handlers[i % num_actors].calculate_value.remote( value_ref = self._actor_handlers[i % num_actors].calculate_value.remote(
sequences_attention_mask_action_mask_refs[i]) sequences_attention_mask_action_mask_refs[i]
)
value_refs.append(value_ref) value_refs.append(value_ref)
return value_refs return value_refs
...@@ -411,7 +413,6 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup): ...@@ -411,7 +413,6 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
class PPOInitialRayActorGroup(PPORayActorGroup): class PPOInitialRayActorGroup(PPORayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None: def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel) super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel)
...@@ -420,13 +421,13 @@ class PPOInitialRayActorGroup(PPORayActorGroup): ...@@ -420,13 +421,13 @@ class PPOInitialRayActorGroup(PPORayActorGroup):
base_action_log_probs_refs = [] base_action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)): for i in range(len(sequences_attention_mask_action_mask_refs)):
base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote( base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote(
sequences_attention_mask_action_mask_refs[i]) sequences_attention_mask_action_mask_refs[i]
)
base_action_log_probs_refs.append(base_action_log_probs_ref) base_action_log_probs_refs.append(base_action_log_probs_ref)
return base_action_log_probs_refs return base_action_log_probs_refs
class PPORewardRayActorGroup(PPORayActorGroup): class PPORewardRayActorGroup(PPORayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None: def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel) super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel)
...@@ -435,20 +436,21 @@ class PPORewardRayActorGroup(PPORayActorGroup): ...@@ -435,20 +436,21 @@ class PPORewardRayActorGroup(PPORayActorGroup):
r_refs = [] r_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)): for i in range(len(sequences_attention_mask_action_mask_refs)):
r_ref = self._actor_handlers[i % num_actors].calculate_r.remote( r_ref = self._actor_handlers[i % num_actors].calculate_r.remote(
sequences_attention_mask_action_mask_refs[i]) sequences_attention_mask_action_mask_refs[i]
)
r_refs.append(r_ref) r_refs.append(r_ref)
return r_refs return r_refs
def main(args): def main(args):
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', logging.basicConfig(
level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
datefmt='%Y-%m-%d %H:%M:%S') )
if args.model == 'gpt2': if args.model == "gpt2":
actor_model_class, critic_model_class = GPTActor, GPTCritic actor_model_class, critic_model_class = GPTActor, GPTCritic
elif args.model == 'bloom': elif args.model == "bloom":
actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic
elif args.model == 'opt': elif args.model == "opt":
actor_model_class, critic_model_class = OPTActor, OPTCritic actor_model_class, critic_model_class = OPTActor, OPTCritic
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
...@@ -462,13 +464,14 @@ def main(args): ...@@ -462,13 +464,14 @@ def main(args):
logging.info("Actors created") logging.info("Actors created")
# Prepare model for training # Prepare model for training
generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50} generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50}
ray.get( ray.get(
actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) + actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True)
critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) + + critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True)
initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) + + initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False)
reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) + + reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False)
actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)) + actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)
)
logging.info("Models prepared for training") logging.info("Models prepared for training")
# Prepare models for training # Prepare models for training
...@@ -483,8 +486,12 @@ def main(args): ...@@ -483,8 +486,12 @@ def main(args):
# Start training # Start training
logging.info("Training start") logging.info("Training start")
# Set all models to eval and add experience maker # Set all models to eval and add experience maker
all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \ all_ray_actors = (
initial_group._actor_handlers + reward_group._actor_handlers actor_group._actor_handlers
+ critic_group._actor_handlers
+ initial_group._actor_handlers
+ reward_group._actor_handlers
)
num_ray_actors = len(all_ray_actors) num_ray_actors = len(all_ray_actors)
ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors]) ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors])
ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors]) ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors])
...@@ -497,18 +504,28 @@ def main(args): ...@@ -497,18 +504,28 @@ def main(args):
time += 1 time += 1
# Experience queueing stage # Experience queueing stage
sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence( sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence(
experience_batch_size) experience_batch_size
)
base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs( base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs(
sequences_attention_mask_action_mask_refs) sequences_attention_mask_action_mask_refs
)
values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs) values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs)
r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs) r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs)
action_log_probs_refs = actor_group.async_calculate_action_log_probs( action_log_probs_refs = actor_group.async_calculate_action_log_probs(
sequences_attention_mask_action_mask_refs) sequences_attention_mask_action_mask_refs
experience_composition_refs.extend([ )
ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i], experience_composition_refs.extend(
base_action_log_probs_refs[i], values_refs[i], r_refs[i]) [
for i in range(len(sequences_attention_mask_action_mask_refs)) ExperienceCompositionRefs(
]) sequences_attention_mask_action_mask_refs[i],
action_log_probs_refs[i],
base_action_log_probs_refs[i],
values_refs[i],
r_refs[i],
)
for i in range(len(sequences_attention_mask_action_mask_refs))
]
)
# Learning stage # Learning stage
if time % update_timesteps == 0: if time % update_timesteps == 0:
experience_refs = [] experience_refs = []
...@@ -519,8 +536,9 @@ def main(args): ...@@ -519,8 +536,9 @@ def main(args):
experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref)) experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref))
# backward # backward
ray.get( ray.get(
actor_group.async_learn_on_experiences(experience_refs) + actor_group.async_learn_on_experiences(experience_refs)
critic_group.async_learn_on_experiences(experience_refs)) + critic_group.async_learn_on_experiences(experience_refs)
)
# clear refs queue # clear refs queue
experience_composition_refs.clear() experience_composition_refs.clear()
logging.info("Training finished") logging.info("Training finished")
...@@ -528,26 +546,24 @@ def main(args): ...@@ -528,26 +546,24 @@ def main(args):
actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt) actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--prompt_csv_url', type=str) parser.add_argument("--prompt_csv_url", type=str)
parser.add_argument('--strategy', parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"])
default='ddp') parser.add_argument("--pretrain", type=str, default="gpt2")
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt")
parser.add_argument('--pretrain', type=str, default='gpt2') parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument("--max_timesteps", type=int, default=10)
parser.add_argument('--num_episodes', type=int, default=10) parser.add_argument("--update_timesteps", type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10) parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument('--update_timesteps', type=int, default=10) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1)
parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1) parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1)
parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1) parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1)
parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1) parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1)
parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1)
parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1)
args = parser.parse_args() args = parser.parse_args()
ray.init() ray.init()
main(args) main(args)
...@@ -22,7 +22,7 @@ class HFRepoFiles: ...@@ -22,7 +22,7 @@ class HFRepoFiles:
file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path) file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path)
def download_all(self): def download_all(self):
file_path = snapshot_download(self.repo_id) snapshot_download(self.repo_id)
def test_init(model: str, dir_path: str): def test_init(model: str, dir_path: str):
...@@ -31,19 +31,19 @@ def test_init(model: str, dir_path: str): ...@@ -31,19 +31,19 @@ def test_init(model: str, dir_path: str):
actor = GPTActor(config=config) actor = GPTActor(config=config)
critic = GPTCritic(config=config) critic = GPTCritic(config=config)
reward_model = GPTRM(config=config) reward_model = GPTRM(config=config)
tokenizer = GPT2Tokenizer.from_pretrained(dir_path) GPT2Tokenizer.from_pretrained(dir_path)
elif model == "bloom": elif model == "bloom":
config = BloomConfig.from_pretrained(dir_path) config = BloomConfig.from_pretrained(dir_path)
actor = BLOOMActor(config=config) actor = BLOOMActor(config=config)
critic = BLOOMCritic(config=config) critic = BLOOMCritic(config=config)
reward_model = BLOOMRM(config=config) reward_model = BLOOMRM(config=config)
tokenizer = BloomTokenizerFast.from_pretrained(dir_path) BloomTokenizerFast.from_pretrained(dir_path)
elif model == "opt": elif model == "opt":
config = AutoConfig.from_pretrained(dir_path) config = AutoConfig.from_pretrained(dir_path)
actor = OPTActor(config=config) actor = OPTActor(config=config)
critic = OPTCritic(config=config) critic = OPTCritic(config=config)
reward_model = OPTRM(config=config) reward_model = OPTRM(config=config)
tokenizer = AutoTokenizer.from_pretrained(dir_path) AutoTokenizer.from_pretrained(dir_path)
else: else:
raise NotImplementedError(f"Model {model} not implemented") raise NotImplementedError(f"Model {model} not implemented")
...@@ -59,17 +59,12 @@ if __name__ == "__main__": ...@@ -59,17 +59,12 @@ if __name__ == "__main__":
exit(0) exit(0)
repo_list = { repo_list = {
"gpt2": HFRepoFiles( "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]),
repo_id="gpt2",
files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]
),
"bloom": HFRepoFiles( "bloom": HFRepoFiles(
repo_id="bigscience/bloom-560m", repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"]
files=["config.json", "tokenizer.json", "tokenizer_config.json"]
), ),
"opt": HFRepoFiles( "opt": HFRepoFiles(
repo_id="facebook/opt-350m", repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
), ),
} }
......
...@@ -31,9 +31,11 @@ def generate_alpaca(): ...@@ -31,9 +31,11 @@ def generate_alpaca():
def generate_sharegpt(): def generate_sharegpt():
# ShareGPT data requires less processing. # ShareGPT data requires less processing.
conversation_dataset = [] conversation_dataset = []
dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered", dataset = load_dataset(
data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json", "anon8231489123/ShareGPT_Vicuna_unfiltered",
split="train") data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
split="train",
)
conversations = dataset["conversations"] conversations = dataset["conversations"]
...@@ -43,23 +45,24 @@ def generate_sharegpt(): ...@@ -43,23 +45,24 @@ def generate_sharegpt():
del conv["markdown"] del conv["markdown"]
del conv["text"] del conv["text"]
conversation = dict(type="conversation", conversation = dict(
language="Multilingual", type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx]
dataset="ShareGPT", )
conversations=conversations[idx])
conversation_dataset.append(conversation) conversation_dataset.append(conversation)
return conversation_dataset return conversation_dataset
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--dataset', parser.add_argument(
type=str, "--dataset",
default="All", type=str,
choices=["Alpaca", "ShareGPT", "All"], default="All",
help="which dataset to convert, All will combine Alpaca and ShareGPT") choices=["Alpaca", "ShareGPT", "All"],
parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset") help="which dataset to convert, All will combine Alpaca and ShareGPT",
)
parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")
args = parser.parse_args() args = parser.parse_args()
conversation_dataset = [] conversation_dataset = []
...@@ -75,5 +78,5 @@ if __name__ == '__main__': ...@@ -75,5 +78,5 @@ if __name__ == '__main__':
for idx, sample in enumerate(conversation_dataset): for idx, sample in enumerate(conversation_dataset):
sample["id"] = idx + 1 sample["id"] = idx + 1
with open(args.save_path, mode='w') as f: with open(args.save_path, mode="w") as f:
json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False) json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
...@@ -6,7 +6,7 @@ random.seed(42) ...@@ -6,7 +6,7 @@ random.seed(42)
def sample(args): def sample(args):
with open(args.dataset_path, mode='r') as f: with open(args.dataset_path, mode="r") as f:
dataset_list = json.load(f) dataset_list = json.load(f)
sampled_dataset = [ sampled_dataset = [
...@@ -14,18 +14,14 @@ def sample(args): ...@@ -14,18 +14,14 @@ def sample(args):
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size)) for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
] ]
with open(args.save_path, mode='w') as f: with open(args.save_path, mode="w") as f:
json.dump(sampled_dataset, f, indent=4, json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False)
default=str, ensure_ascii=False)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', type=str, default=None, parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset")
required=True, help="path to the pretrain dataset") parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
parser.add_argument('--save_path', type=str, default='prompt.json', parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
help="path to save the prompt dataset")
parser.add_argument('--sample_size', type=int,
default=16384, help="size of the prompt dataset")
args = parser.parse_args() args = parser.parse_args()
sample(args) sample(args)
...@@ -11,13 +11,13 @@ from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, Llama ...@@ -11,13 +11,13 @@ from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, Llama
def eval(args): def eval(args):
# configure model # configure model
if args.model == 'gpt2': if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain) actor = GPTActor(pretrained=args.pretrain)
elif args.model == 'bloom': elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain) actor = BLOOMActor(pretrained=args.pretrain)
elif args.model == 'opt': elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain) actor = OPTActor(pretrained=args.pretrain)
elif args.model == 'llama': elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain) actor = LlamaActor(pretrained=args.pretrain)
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
...@@ -28,45 +28,38 @@ def eval(args): ...@@ -28,45 +28,38 @@ def eval(args):
actor.load_state_dict(state_dict) actor.load_state_dict(state_dict)
# configure tokenizer # configure tokenizer
if args.model == 'gpt2': if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom': elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt': elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama': elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.eos_token = '<\s>' tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
actor.eval() actor.eval()
input_ids = tokenizer.encode(args.input, input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
return_tensors='pt')\ outputs = generate(
.to(torch.cuda.current_device()) actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1
outputs = generate(actor, )
input_ids, output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
max_length=args.max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1)
output = tokenizer.batch_decode(outputs[0],
skip_special_tokens=True)
print(f"[Output]: {''.join(output)}") print(f"[Output]: {''.join(output)}")
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
# We suggest to use the pretrained model from HuggingFace, use pretrain to configure model # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--model_path', type=str, default=None) parser.add_argument("--model_path", type=str, default=None)
parser.add_argument('--input', type=str, default='Question: How are you ? Answer:') parser.add_argument("--input", type=str, default="Question: How are you ? Answer:")
parser.add_argument('--max_length', type=int, default=100) parser.add_argument("--max_length", type=int, default=100)
args = parser.parse_args() args = parser.parse_args()
eval(args) eval(args)
...@@ -5,7 +5,6 @@ from functools import partial ...@@ -5,7 +5,6 @@ from functools import partial
import pandas as pd import pandas as pd
import ray import ray
import torch
from coati.quant import llama_load_quant, low_resource_init from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder from coati.ray.experience_maker_holder import ExperienceMakerHolder
...@@ -23,13 +22,13 @@ from transformers.modeling_utils import no_init_weights ...@@ -23,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
def get_free_port(): def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0)) s.bind(("", 0))
return s.getsockname()[1] return s.getsockname()[1]
def get_local_ip(): def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80)) s.connect(("8.8.8.8", 80))
return s.getsockname()[0] return s.getsockname()[0]
...@@ -37,22 +36,25 @@ def main(args): ...@@ -37,22 +36,25 @@ def main(args):
master_addr = str(get_local_ip()) master_addr = str(get_local_ip())
# trainer_env_info # trainer_env_info
trainer_port = str(get_free_port()) trainer_port = str(get_free_port())
env_info_trainers = [{ env_info_trainers = [
'local_rank': '0', {
'rank': str(rank), "local_rank": "0",
'world_size': str(args.num_trainers), "rank": str(rank),
'master_port': trainer_port, "world_size": str(args.num_trainers),
'master_addr': master_addr "master_port": trainer_port,
} for rank in range(args.num_trainers)] "master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info # maker_env_info
maker_port = str(get_free_port()) maker_port = str(get_free_port())
env_info_maker = { env_info_maker = {
'local_rank': '0', "local_rank": "0",
'rank': '0', "rank": "0",
'world_size': '1', "world_size": "1",
'master_port': maker_port, "master_port": maker_port,
'master_addr': master_addr "master_addr": master_addr,
} }
# configure tokenizer # configure tokenizer
...@@ -75,27 +77,33 @@ def main(args): ...@@ -75,27 +77,33 @@ def main(args):
eval_performance=True, eval_performance=True,
debug=args.debug, debug=args.debug,
update_lora_weights=not (args.lora_rank == 0), update_lora_weights=not (args.lora_rank == 0),
) for i, env_info_trainer in enumerate(env_info_trainers) )
for i, env_info_trainer in enumerate(env_info_trainers)
] ]
def model_fn(): def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama': if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model # quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain) actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights(): with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg) initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, initial_model.model = (
args.quant_group_size).cuda().requires_grad_(False) llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else: else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model return actor, critic, reward_model, initial_model
# configure Experience Maker # configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)], detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy), strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn, model_fn=model_fn,
env_info=env_info_maker, env_info=env_info_maker,
...@@ -130,12 +138,11 @@ def main(args): ...@@ -130,12 +138,11 @@ def main(args):
dataset_size = args.experience_batch_size * 4 dataset_size = args.experience_batch_size * 4
def build_dataloader(): def build_dataloader():
def tokenize_fn(texts): def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()} return {k: v.cuda() for k, v in batch.items()}
dataset = pd.read_csv(args.prompt_path)['prompt'] dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn) dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader return dataloader
...@@ -144,32 +151,31 @@ def main(args): ...@@ -144,32 +151,31 @@ def main(args):
ray.get(wait_tasks) ray.get(wait_tasks)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--prompt_path', type=str, default=None) parser.add_argument("--prompt_path", type=str, default=None)
parser.add_argument('--num_trainers', type=int, default=1) parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument('--trainer_strategy', parser.add_argument(
choices=[ "--trainer_strategy",
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
'colossalai_zero2_cpu' default="ddp",
], )
default='ddp') parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument('--maker_strategy', choices=['naive'], default='naive') parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None) parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument('--experience_steps', type=int, default=4) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument('--train_epochs', type=int, default=1) parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument('--update_steps', type=int, default=2) parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument('--quant_bits', type=int, default=4) parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument('--quant_group_size', type=int, default=128) parser.add_argument("--debug", action="store_true")
parser.add_argument('--debug', action='store_true')
args = parser.parse_args() args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args) main(args)
...@@ -5,7 +5,6 @@ from functools import partial ...@@ -5,7 +5,6 @@ from functools import partial
import pandas as pd import pandas as pd
import ray import ray
import torch
from coati.quant import llama_load_quant, low_resource_init from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder from coati.ray.experience_maker_holder import ExperienceMakerHolder
...@@ -23,13 +22,13 @@ from transformers.modeling_utils import no_init_weights ...@@ -23,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
def get_free_port(): def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0)) s.bind(("", 0))
return s.getsockname()[1] return s.getsockname()[1]
def get_local_ip(): def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80)) s.connect(("8.8.8.8", 80))
return s.getsockname()[0] return s.getsockname()[0]
...@@ -37,23 +36,29 @@ def main(args): ...@@ -37,23 +36,29 @@ def main(args):
master_addr = str(get_local_ip()) master_addr = str(get_local_ip())
# trainer_env_info # trainer_env_info
trainer_port = str(get_free_port()) trainer_port = str(get_free_port())
env_info_trainers = [{ env_info_trainers = [
'local_rank': '0', {
'rank': str(rank), "local_rank": "0",
'world_size': str(args.num_trainers), "rank": str(rank),
'master_port': trainer_port, "world_size": str(args.num_trainers),
'master_addr': master_addr "master_port": trainer_port,
} for rank in range(args.num_trainers)] "master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info # maker_env_info
maker_port = str(get_free_port()) maker_port = str(get_free_port())
env_info_makers = [{ env_info_makers = [
'local_rank': '0', {
'rank': str(rank), "local_rank": "0",
'world_size': str(args.num_makers), "rank": str(rank),
'master_port': maker_port, "world_size": str(args.num_makers),
'master_addr': master_addr "master_port": maker_port,
} for rank in range(args.num_makers)] "master_addr": master_addr,
}
for rank in range(args.num_makers)
]
# configure tokenizer # configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain) tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
...@@ -63,13 +68,18 @@ def main(args): ...@@ -63,13 +68,18 @@ def main(args):
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama': if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model # quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain) actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights(): with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg) initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, initial_model.model = (
args.quant_group_size).cuda().requires_grad_(False) llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else: else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model return actor, critic, reward_model, initial_model
...@@ -78,7 +88,7 @@ def main(args): ...@@ -78,7 +88,7 @@ def main(args):
experience_holder_refs = [ experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote( ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[ detached_trainer_name_list=[
f'trainer{x}' f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False) for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
], ],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy), strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
...@@ -87,8 +97,8 @@ def main(args): ...@@ -87,8 +97,8 @@ def main(args):
kl_coef=0.1, kl_coef=0.1,
debug=args.debug, debug=args.debug,
update_lora_weights=not (args.lora_rank == 0), update_lora_weights=not (args.lora_rank == 0),
# sync_models_from_trainers=True, # sync_models_from_trainers=True,
# generation kwargs: # generation kwargs:
max_length=512, max_length=512,
do_sample=True, do_sample=True,
temperature=1.0, temperature=1.0,
...@@ -128,12 +138,11 @@ def main(args): ...@@ -128,12 +138,11 @@ def main(args):
dataset_size = args.experience_batch_size * 4 dataset_size = args.experience_batch_size * 4
def build_dataloader(): def build_dataloader():
def tokenize_fn(texts): def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()} return {k: v.cuda() for k, v in batch.items()}
dataset = pd.read_csv(args.prompt_path)['prompt'] dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn) dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader return dataloader
...@@ -148,39 +157,44 @@ def main(args): ...@@ -148,39 +157,44 @@ def main(args):
for experience_holder_ref in experience_holder_refs: for experience_holder_ref in experience_holder_refs:
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps)) wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
total_steps = args.experience_batch_size * args.experience_steps * \ total_steps = (
args.num_makers // (args.num_trainers * args.train_batch_size) args.experience_batch_size
* args.experience_steps
* args.num_makers
// (args.num_trainers * args.train_batch_size)
)
for trainer_ref in trainer_refs: for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks) ray.get(wait_tasks)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--prompt_path', type=str, default=None) parser.add_argument("--prompt_path", type=str, default=None)
parser.add_argument('--num_makers', type=int, default=1) parser.add_argument("--num_makers", type=int, default=1)
parser.add_argument('--num_trainers', type=int, default=1) parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument( parser.add_argument(
'--trainer_strategy', "--trainer_strategy",
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'], choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default='ddp') default="ddp",
parser.add_argument('--maker_strategy', choices=['naive'], default='naive') )
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--critic_pretrain', type=str, default=None) parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4) parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument('--train_epochs', type=int, default=1) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument('--update_steps', type=int, default=2) parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4) parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument('--quant_group_size', type=int, default=128) parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument('--debug', action='store_true') parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args() args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
......
pandas>=1.4.1 pandas>=1.4.1
sentencepiece sentencepiece
colossalai==0.3.1 colossalai==0.3.1
\ No newline at end of file
...@@ -20,28 +20,28 @@ from colossalai.nn.optimizer import HybridAdam ...@@ -20,28 +20,28 @@ from colossalai.nn.optimizer import HybridAdam
def main(args): def main(args):
# configure strategy # configure strategy
if args.strategy == 'ddp': if args.strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini': elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif args.strategy == 'colossalai_zero2': elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else: else:
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None: if args.rm_path is not None:
warnings.warn('LoRA weights should be merged with the model weights') warnings.warn("LoRA weights should be merged with the model weights")
state_dict = torch.load(args.rm_path, map_location='cpu') state_dict = torch.load(args.rm_path, map_location="cpu")
with strategy.model_init_context(): with strategy.model_init_context():
# configure model # configure model
if args.model == 'gpt2': if args.model == "gpt2":
initial_model = GPTActor(pretrained=args.pretrain) initial_model = GPTActor(pretrained=args.pretrain)
elif args.model == 'bloom': elif args.model == "bloom":
initial_model = BLOOMActor(pretrained=args.pretrain) initial_model = BLOOMActor(pretrained=args.pretrain)
elif args.model == 'opt': elif args.model == "opt":
initial_model = OPTActor(pretrained=args.pretrain) initial_model = OPTActor(pretrained=args.pretrain)
elif args.model == 'llama': elif args.model == "llama":
initial_model = LlamaActor(pretrained=args.pretrain) initial_model = LlamaActor(pretrained=args.pretrain)
else: else:
raise ValueError(f'Unsupported actor model "{args.model}"') raise ValueError(f'Unsupported actor model "{args.model}"')
...@@ -51,13 +51,13 @@ def main(args): ...@@ -51,13 +51,13 @@ def main(args):
else: else:
rm_model_name = args.rm_model rm_model_name = args.rm_model
if rm_model_name == 'gpt2': if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'bloom': elif rm_model_name == "bloom":
reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'opt': elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'llama': elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else: else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"') raise ValueError(f'Unsupported reward model "{rm_model_name}"')
...@@ -68,24 +68,24 @@ def main(args): ...@@ -68,24 +68,24 @@ def main(args):
initial_model.to(torch.float16).to(torch.cuda.current_device()) initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.float16).to(torch.cuda.current_device())
if args.model == 'gpt2': if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'bloom': elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'opt': elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'llama': elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
else: else:
raise ValueError(f'Unsupported actor model "{args.model}"') raise ValueError(f'Unsupported actor model "{args.model}"')
if rm_model_name == 'gpt2': if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'bloom': elif rm_model_name == "bloom":
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'opt': elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'llama': elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else: else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"') raise ValueError(f'Unsupported reward model "{rm_model_name}"')
...@@ -94,12 +94,12 @@ def main(args): ...@@ -94,12 +94,12 @@ def main(args):
critic.load_state_dict(state_dict, strict=False) critic.load_state_dict(state_dict, strict=False)
del state_dict del state_dict
if args.strategy != 'colossalai_gemini': if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device()) critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device()) actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer # configure optimizer
if args.strategy.startswith('colossalai'): if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7) actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7) critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else: else:
...@@ -107,22 +107,22 @@ def main(args): ...@@ -107,22 +107,22 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7) critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer # configure tokenizer
if args.model == 'gpt2': if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained( tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
'gpt2' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom': elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained( tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt': elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama': elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained( tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
tokenizer.eos_token = '<\s>' )
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
...@@ -132,27 +132,25 @@ def main(args): ...@@ -132,27 +132,25 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else: else:
prompt_sampler = None prompt_sampler = None
prompt_dataloader = DataLoader(prompt_dataset, prompt_dataloader = DataLoader(
shuffle=(prompt_sampler is None), prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size
sampler=prompt_sampler, )
batch_size=args.experience_batch_size)
pretrain_dataset = SupervisedDataset(
pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len
data_path=args.pretrain_dataset, )
max_datasets_size=16384,
max_length=args.max_input_len)
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else: else:
pretrain_sampler = None pretrain_sampler = None
pretrain_dataloader = DataLoader(pretrain_dataset, pretrain_dataloader = DataLoader(
shuffle=(pretrain_sampler is None), pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size
sampler=pretrain_sampler, )
batch_size=args.ptx_batch_size)
# NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized. # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model) (actor, actor_optim), (critic, critic_optim), reward_model, initial_model
)
# configure trainer # configure trainer
trainer = PPOTrainer( trainer = PPOTrainer(
...@@ -173,50 +171,54 @@ def main(args): ...@@ -173,50 +171,54 @@ def main(args):
top_k=50, top_k=50,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
offload_inference_models=args.strategy != 'colossalai_gemini' offload_inference_models=args.strategy != "colossalai_gemini",
) )
trainer.fit(prompt_dataloader=prompt_dataloader, trainer.fit(
pretrain_dataloader=pretrain_dataloader, prompt_dataloader=prompt_dataloader,
num_episodes=args.num_episodes, pretrain_dataloader=pretrain_dataloader,
num_collect_steps=args.num_collect_steps, num_episodes=args.num_episodes,
num_update_steps=args.num_update_steps) num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps,
)
# save model checkpoint after fitting # save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True) strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
if args.need_optim_ckpt: if args.need_optim_ckpt:
strategy.save_optimizer(actor_optim, strategy.save_optimizer(
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
only_rank0=False) )
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset') parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
parser.add_argument('--strategy', parser.add_argument(
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], "--strategy",
default='colossalai_zero2', choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
help='strategy to use') default="colossalai_zero2",
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) help="strategy to use",
parser.add_argument('--tokenizer', type=str, default=None) )
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--rm_path', type=str, default=None) parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--rm_pretrain', type=str, default=None) parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') parser.add_argument("--rm_path", type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument('--num_episodes', type=int, default=10) parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument('--num_collect_steps', type=int, default=10) parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument('--num_update_steps', type=int, default=5) parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument("--num_collect_steps", type=int, default=10)
parser.add_argument('--ptx_batch_size', type=int, default=1) parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument('--kl_coef', type=float, default=0.1) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument('--ptx_coef', type=float, default=0.9) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--max_input_len', type=int, default=96) parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument('--max_seq_len', type=int, default=128) parser.add_argument("--ptx_coef", type=float, default=0.9)
parser.add_argument("--max_input_len", type=int, default=96)
parser.add_argument("--max_seq_len", type=int, default=128)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
...@@ -24,24 +24,24 @@ from colossalai.nn.optimizer import HybridAdam ...@@ -24,24 +24,24 @@ from colossalai.nn.optimizer import HybridAdam
def train(args): def train(args):
# configure strategy # configure strategy
if args.strategy == 'ddp': if args.strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini': elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy='cuda') strategy = GeminiStrategy(placement_policy="cuda")
elif args.strategy == 'colossalai_zero2': elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else: else:
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model # configure model
with strategy.model_init_context(): with strategy.model_init_context():
if args.model == 'bloom': if args.model == "bloom":
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank) model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'opt': elif args.model == "opt":
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'gpt2': elif args.model == "gpt2":
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'llama': elif args.model == "llama":
model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank) model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
...@@ -53,36 +53,36 @@ def train(args): ...@@ -53,36 +53,36 @@ def train(args):
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
# configure tokenizer # configure tokenizer
if args.model == 'gpt2': if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained( tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
'gpt2' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom': elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained( tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt': elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama': elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained( tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
tokenizer.eos_token = '<\s>' )
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
# configure optimizer # configure optimizer
if args.strategy.startswith('colossalai'): if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=5e-6) optim = HybridAdam(model.parameters(), lr=5e-6)
else: else:
optim = Adam(model.parameters(), lr=5e-6) optim = Adam(model.parameters(), lr=5e-6)
# configure loss function # configure loss function
if args.loss_fn == 'log_sig': if args.loss_fn == "log_sig":
loss_fn = LogSigLoss() loss_fn = LogSigLoss()
elif args.loss_fn == 'log_exp': elif args.loss_fn == "log_exp":
loss_fn = LogExpLoss() loss_fn = LogExpLoss()
else: else:
raise ValueError(f'Unsupported loss function "{args.loss_fn}"') raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
...@@ -94,18 +94,18 @@ def train(args): ...@@ -94,18 +94,18 @@ def train(args):
data = load_dataset(args.dataset) data = load_dataset(args.dataset)
if args.test: if args.test:
train_data = data['train'].select(range(20)) train_data = data["train"].select(range(20))
eval_data = data['test'].select(range(5)) eval_data = data["test"].select(range(5))
else: else:
train_data = data['train'] train_data = data["train"]
eval_data = data['test'] eval_data = data["test"]
valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5))) valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
if args.dataset == 'Dahoas/rm-static': if args.dataset == "Dahoas/rm-static":
train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len) train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len) valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len) eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
elif args.dataset == 'Anthropic/hh-rlhf': elif args.dataset == "Anthropic/hh-rlhf":
train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len) train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len) valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len) eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
...@@ -113,90 +113,99 @@ def train(args): ...@@ -113,90 +113,99 @@ def train(args):
raise ValueError(f'Unsupported dataset "{args.dataset}"') raise ValueError(f'Unsupported dataset "{args.dataset}"')
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset, train_sampler = DistributedSampler(
shuffle=True, train_dataset,
seed=42, shuffle=True,
drop_last=True, seed=42,
rank=dist.get_rank(), drop_last=True,
num_replicas=dist.get_world_size()) rank=dist.get_rank(),
valid_sampler = DistributedSampler(valid_dataset, num_replicas=dist.get_world_size(),
shuffle=True, )
seed=42, valid_sampler = DistributedSampler(
drop_last=True, valid_dataset,
rank=dist.get_rank(), shuffle=True,
num_replicas=dist.get_world_size()) seed=42,
eval_sampler = DistributedSampler(eval_dataset, drop_last=True,
shuffle=True, rank=dist.get_rank(),
seed=42, num_replicas=dist.get_world_size(),
drop_last=True, )
rank=dist.get_rank(), eval_sampler = DistributedSampler(
num_replicas=dist.get_world_size()) eval_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
else: else:
train_sampler = None train_sampler = None
valid_sampler = None valid_sampler = None
eval_sampler = None eval_sampler = None
train_dataloader = DataLoader(train_dataset, train_dataloader = DataLoader(
shuffle=(train_sampler is None), train_dataset,
sampler=train_sampler, shuffle=(train_sampler is None),
batch_size=args.batch_size, sampler=train_sampler,
pin_memory=True) batch_size=args.batch_size,
pin_memory=True,
valid_dataloader = DataLoader(valid_dataset, )
shuffle=(valid_sampler is None),
sampler=valid_sampler, valid_dataloader = DataLoader(
batch_size=args.batch_size, valid_dataset,
pin_memory=True) shuffle=(valid_sampler is None),
sampler=valid_sampler,
eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size,
shuffle=(eval_sampler is None), pin_memory=True,
sampler=eval_sampler, )
batch_size=args.batch_size,
pin_memory=True) eval_dataloader = DataLoader(
eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
)
lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100) lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
model = strategy_dict['model'] model = strategy_dict["model"]
optim = strategy_dict['optimizer'] optim = strategy_dict["optimizer"]
lr_scheduler = strategy_dict['lr_scheduler'] lr_scheduler = strategy_dict["lr_scheduler"]
trainer = RewardModelTrainer(model=model, trainer = RewardModelTrainer(
strategy=strategy, model=model,
optim=optim, strategy=strategy,
lr_scheduler=lr_scheduler, optim=optim,
loss_fn=loss_fn, lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs) loss_fn=loss_fn,
max_epochs=args.max_epochs,
)
trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader) trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True) strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
if args.need_optim_ckpt: if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer, strategy.save_optimizer(
'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
only_rank0=False) )
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--strategy', parser.add_argument(
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
default='colossalai_zero2') )
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--model_path', type=str, default=None) parser.add_argument("--model_path", type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument('--dataset', parser.add_argument(
type=str, "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], )
default='Dahoas/rm-static') parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None) parser.add_argument("--save_path", type=str, default="rm_ckpt")
parser.add_argument('--save_path', type=str, default='rm_ckpt') parser.add_argument("--max_epochs", type=int, default=1)
parser.add_argument('--max_epochs', type=int, default=1) parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument('--batch_size', type=int, default=1) parser.add_argument("--max_len", type=int, default=512)
parser.add_argument('--max_len', type=int, default=512) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp']) parser.add_argument("--test", type=bool, default=False)
parser.add_argument('--test', type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
train(args) train(args)
...@@ -6,18 +6,18 @@ import torch ...@@ -6,18 +6,18 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from coati.dataset import SFTDataset, SupervisedDataset from coati.dataset import SFTDataset, SupervisedDataset
from coati.models.bloom import BLOOMActor from coati.models.bloom import BLOOMActor
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.gpt import GPTActor from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor from coati.models.opt import OPTActor
from coati.models.chatglm import ChatGLMActor
from coati.trainer import SFTTrainer from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset from datasets import load_dataset
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler from transformers.trainer import get_scheduler
...@@ -28,14 +28,14 @@ from colossalai.tensor import ColoParameter ...@@ -28,14 +28,14 @@ from colossalai.tensor import ColoParameter
def train(args): def train(args):
# configure strategy # configure strategy
if args.strategy == 'ddp': if args.strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini': elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy='cuda') strategy = GeminiStrategy(placement_policy="cuda")
elif args.strategy == 'colossalai_zero2': elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif args.strategy == 'colossalai_zero2_cpu': elif args.strategy == "colossalai_zero2_cpu":
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu') strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else: else:
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
...@@ -44,23 +44,15 @@ def train(args): ...@@ -44,23 +44,15 @@ def train(args):
warnings.warn("Gradient checkpoint is disabled when using LoRA") warnings.warn("Gradient checkpoint is disabled when using LoRA")
args.grad_checkpoint = False args.grad_checkpoint = False
with strategy.model_init_context(): with strategy.model_init_context():
if args.model == 'bloom': if args.model == "bloom":
model = BLOOMActor(pretrained=args.pretrain, model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
lora_rank=args.lora_rank, elif args.model == "opt":
checkpoint=args.grad_checkpoint) model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
elif args.model == 'opt': elif args.model == "gpt2":
model = OPTActor(pretrained=args.pretrain, model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
lora_rank=args.lora_rank, elif args.model == "llama":
checkpoint=args.grad_checkpoint) model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
elif args.model == 'gpt2': elif args.model == "chatglm":
model = GPTActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
elif args.model == 'llama':
model = LlamaActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
elif args.model == 'chatglm':
model = ChatGLMActor(pretrained=args.pretrain) model = ChatGLMActor(pretrained=args.pretrain)
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
...@@ -68,144 +60,157 @@ def train(args): ...@@ -68,144 +60,157 @@ def train(args):
model.to(torch.float16).to(torch.cuda.current_device()) model.to(torch.float16).to(torch.cuda.current_device())
# configure tokenizer # configure tokenizer
if args.model == 'gpt2': if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained( tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
'gpt2' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom': elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained( tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt': elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama': elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained( tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
tokenizer.eos_token = '<\s>' )
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
elif args.model == 'chatglm': elif args.model == "chatglm":
tokenizer = ChatGLMTokenizer.from_pretrained( tokenizer = ChatGLMTokenizer.from_pretrained(
"THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True) "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True
)
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
if args.model == 'llama' and args.strategy == 'colossalai_gemini': if args.model == "llama" and args.strategy == "colossalai_gemini":
# this is a hack to deal with the resized embedding # this is a hack to deal with the resized embedding
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if not isinstance(param, ColoParameter): if not isinstance(param, ColoParameter):
sub_module_name = '.'.join(name.split('.')[:-1]) sub_module_name = ".".join(name.split(".")[:-1])
weight_name = name.split('.')[-1] weight_name = name.split(".")[-1]
sub_module = model.get_submodule(sub_module_name) sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param)) setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer # configure optimizer
if args.strategy.startswith('colossalai'): if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else: else:
optim = Adam(model.parameters(), lr=args.lr) optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger() logger = get_dist_logger()
# configure dataset # configure dataset
if args.dataset == 'yizhongw/self_instruct': if args.dataset == "yizhongw/self_instruct":
train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train') train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test') eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
train_dataset = SFTDataset(train_data, tokenizer, args.max_len) train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len) eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
else: else:
train_dataset = SupervisedDataset(tokenizer=tokenizer, train_dataset = SupervisedDataset(
data_path=args.dataset, tokenizer=tokenizer,
max_datasets_size=args.max_datasets_size, data_path=args.dataset,
max_length=args.max_len) max_datasets_size=args.max_datasets_size,
max_length=args.max_len,
)
eval_dataset = None eval_dataset = None
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset, train_sampler = DistributedSampler(
shuffle=True, train_dataset,
seed=42, shuffle=True,
drop_last=True, seed=42,
rank=dist.get_rank(), drop_last=True,
num_replicas=dist.get_world_size()) rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
if eval_dataset is not None: if eval_dataset is not None:
eval_sampler = DistributedSampler(eval_dataset, eval_sampler = DistributedSampler(
shuffle=False, eval_dataset,
seed=42, shuffle=False,
drop_last=False, seed=42,
rank=dist.get_rank(), drop_last=False,
num_replicas=dist.get_world_size()) rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
else: else:
train_sampler = None train_sampler = None
eval_sampler = None eval_sampler = None
train_dataloader = DataLoader(train_dataset, train_dataloader = DataLoader(
shuffle=(train_sampler is None), train_dataset,
sampler=train_sampler, shuffle=(train_sampler is None),
batch_size=args.batch_size, sampler=train_sampler,
pin_memory=True) batch_size=args.batch_size,
pin_memory=True,
)
if eval_dataset is not None: if eval_dataset is not None:
eval_dataloader = DataLoader(eval_dataset, eval_dataloader = DataLoader(
shuffle=(eval_sampler is None), eval_dataset,
sampler=eval_sampler, shuffle=(eval_sampler is None),
batch_size=args.batch_size, sampler=eval_sampler,
pin_memory=True) batch_size=args.batch_size,
pin_memory=True,
)
else: else:
eval_dataloader = None eval_dataloader = None
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
lr_scheduler = get_scheduler("cosine", lr_scheduler = get_scheduler(
optim, "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps
num_warmup_steps=math.ceil(max_steps * 0.03), )
num_training_steps=max_steps)
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
model = strategy_dict['model'] model = strategy_dict["model"]
optim = strategy_dict['optimizer'] optim = strategy_dict["optimizer"]
lr_scheduler = strategy_dict['lr_scheduler'] lr_scheduler = strategy_dict["lr_scheduler"]
trainer = SFTTrainer(model=model, trainer = SFTTrainer(
strategy=strategy, model=model,
optim=optim, strategy=strategy,
lr_scheduler=lr_scheduler, optim=optim,
max_epochs=args.max_epochs, lr_scheduler=lr_scheduler,
accumulation_steps=args.accumulation_steps) max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
trainer.fit(train_dataloader=train_dataloader, )
eval_dataloader=eval_dataloader,
logger=logger, trainer.fit(
use_wandb=args.use_wandb) train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb
)
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer) strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
if args.need_optim_ckpt: if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer, strategy.save_optimizer(
'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
only_rank0=False) )
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--strategy', parser.add_argument(
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'], "--strategy",
default='colossalai_zero2') choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"],
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom') default="colossalai_zero2",
parser.add_argument('--tokenizer', type=str, default=None) )
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom")
parser.add_argument('--dataset', type=str, default=None) parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--max_datasets_size', type=int, default=None) parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--save_path', type=str, default='output') parser.add_argument("--dataset", type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument("--max_datasets_size", type=int, default=None)
parser.add_argument('--max_epochs', type=int, default=3) parser.add_argument("--save_path", type=str, default="output")
parser.add_argument('--batch_size', type=int, default=4) parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument('--max_len', type=int, default=512) parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") parser.add_argument("--max_len", type=int, default=512)
parser.add_argument('--lr', type=float, default=5e-6) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--accumulation_steps', type=int, default=8) parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
parser.add_argument('--use_wandb', default=False, action='store_true') parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument('--grad_checkpoint', default=False, action='store_true') parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
train(args) train(args)
...@@ -84,28 +84,34 @@ inst = [instructions[0]] * 4 ...@@ -84,28 +84,34 @@ inst = [instructions[0]] * 4
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'pretrained', "pretrained",
help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
parser.add_argument('--quant', )
choices=['8bit', '4bit'], parser.add_argument(
default=None, "--quant",
help='Quantization mode. Default: None (no quantization, fp16).') choices=["8bit", "4bit"],
default=None,
help="Quantization mode. Default: None (no quantization, fp16).",
)
parser.add_argument( parser.add_argument(
'--gptq_checkpoint', "--gptq_checkpoint",
default=None, default=None,
help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
parser.add_argument('--gptq_group_size', )
type=int, parser.add_argument(
default=128, "--gptq_group_size",
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') type=int,
default=128,
help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
)
args = parser.parse_args() args = parser.parse_args()
if args.quant == '4bit': if args.quant == "4bit":
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
tokenizer = AutoTokenizer.from_pretrained(args.pretrained) tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
if args.quant == '4bit': if args.quant == "4bit":
with low_resource_init(): with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained) config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config) model = LlamaForCausalLM(config)
...@@ -114,12 +120,12 @@ if __name__ == "__main__": ...@@ -114,12 +120,12 @@ if __name__ == "__main__":
else: else:
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
args.pretrained, args.pretrained,
load_in_8bit=(args.quant == '8bit'), load_in_8bit=(args.quant == "8bit"),
torch_dtype=torch.float16, torch_dtype=torch.float16,
device_map="auto", device_map="auto",
) )
if args.quant != '8bit': if args.quant != "8bit":
model.half() # seems to fix bugs for some users. model.half() # seems to fix bugs for some users.
model.eval() model.eval()
total_tokens = 0 total_tokens = 0
...@@ -129,7 +135,7 @@ if __name__ == "__main__": ...@@ -129,7 +135,7 @@ if __name__ == "__main__":
resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1) resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1)
total_tokens += tokens total_tokens += tokens
print(f"Response: {resp}") print(f"Response: {resp}")
print('\n----------------------------\n') print("\n----------------------------\n")
duration = time() - start duration = time() - start
print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s') print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s")
print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB') print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
from json import JSONDecodeError
from locust import HttpUser, task from locust import HttpUser, task
samples = [[ samples = [
dict( [
instruction='Who is the best player in the history of NBA?', dict(
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' instruction="Who is the best player in the history of NBA?",
), response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
dict(instruction='continue this talk', response=''), ),
], [ dict(instruction="continue this talk", response=""),
dict(instruction='Who is the best player in the history of NBA?', response=''), ],
]] [
dict(instruction="Who is the best player in the history of NBA?", response=""),
],
]
class GenerationUser(HttpUser): class GenerationUser(HttpUser):
@task @task
def generate(self): def generate(self):
for sample in samples: for sample in samples:
data = {'max_new_tokens': 64, 'history': sample} data = {"max_new_tokens": 64, "history": sample}
with self.client.post('/generate', json=data, catch_response=True) as response: with self.client.post("/generate", json=data, catch_response=True) as response:
if response.status_code in (200, 406): if response.status_code in (200, 406):
response.success() response.success()
else: else:
response.failure('Response wrong') response.failure("Response wrong")
...@@ -16,7 +16,7 @@ from sse_starlette.sse import EventSourceResponse ...@@ -16,7 +16,7 @@ from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
MAX_LEN = 512 MAX_LEN = 512
running_lock = Lock() running_lock = Lock()
...@@ -36,11 +36,11 @@ app.state.limiter = limiter ...@@ -36,11 +36,11 @@ app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# set CORS # set CORS
origin_spec_from_env = os.environ.get('CORS_ORIGIN', None) origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)
if origin_spec_from_env is not None: if origin_spec_from_env is not None:
# allow CORS from the specified origins # allow CORS from the specified origins
origins = os.environ['CORS_ORIGIN'].split(',') origins = os.environ["CORS_ORIGIN"].split(",")
else: else:
# allow CORS from all origins # allow CORS from all origins
origins = ["*"] origins = ["*"]
...@@ -58,13 +58,13 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): ...@@ -58,13 +58,13 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
# TODO(ver217): streaming generation does not support repetition_penalty now # TODO(ver217): streaming generation does not support repetition_penalty now
model_kwargs = { model_kwargs = {
'max_generate_tokens': max_new_tokens, "max_generate_tokens": max_new_tokens,
'early_stopping': True, "early_stopping": True,
'top_k': top_k, "top_k": top_k,
'top_p': top_p, "top_p": top_p,
'temperature': temperature, "temperature": temperature,
'prepare_inputs_fn': model.prepare_inputs_for_generation, "prepare_inputs_fn": model.prepare_inputs_for_generation,
'update_model_kwargs_fn': update_model_kwargs_fn, "update_model_kwargs_fn": update_model_kwargs_fn,
} }
is_first_word = True is_first_word = True
generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock) generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
...@@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): ...@@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
if is_first_word: if is_first_word:
out_string = out_string.lstrip() out_string = out_string.lstrip()
is_first_word = False is_first_word = False
elif current_sub_tokens[0].startswith('▁'): elif current_sub_tokens[0].startswith("▁"):
# whitespace will be ignored by the frontend # whitespace will be ignored by the frontend
out_string = ' ' + out_string out_string = " " + out_string
yield out_string yield out_string
...@@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator): ...@@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator):
if await request.is_disconnected(): if await request.is_disconnected():
break break
try: try:
yield {'event': 'generate', 'data': next(generator)} yield {"event": "generate", "data": next(generator)}
except StopIteration: except StopIteration:
yield {'event': 'end', 'data': ''} yield {"event": "end", "data": ""}
break break
@app.post('/generate/stream') @app.post("/generate/stream")
@limiter.limit('1/second') @limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request): def generate(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
event_source = event_generator( event_source = event_generator(
request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)) request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
)
return EventSourceResponse(event_source) return EventSourceResponse(event_source)
@app.post('/generate') @app.post("/generate")
@limiter.limit('1/second') @limiter.limit("1/second")
def generate_no_stream(data: GenerationTaskReq, request: Request): def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
if prompt_processor.has_censored_words(prompt): if prompt_processor.has_censored_words(prompt):
return prompt_processor.SAFE_RESPONSE return prompt_processor.SAFE_RESPONSE
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
with running_lock: with running_lock:
output = model.generate(**inputs, **data.dict(exclude={'history'})) output = model.generate(**inputs, **data.dict(exclude={"history"}))
output = output.cpu() output = output.cpu()
prompt_len = inputs['input_ids'].size(1) prompt_len = inputs["input_ids"].size(1)
response = output[0, prompt_len:] response = output[0, prompt_len:]
out_string = tokenizer.decode(response, skip_special_tokens=True) out_string = tokenizer.decode(response, skip_special_tokens=True)
out_string = prompt_processor.postprocess_output(out_string) out_string = prompt_processor.postprocess_output(out_string)
...@@ -126,32 +127,40 @@ def generate_no_stream(data: GenerationTaskReq, request: Request): ...@@ -126,32 +127,40 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
return out_string return out_string
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'pretrained', "pretrained",
help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
parser.add_argument('--quant', )
choices=['8bit', '4bit'],
default=None,
help='Quantization mode. Default: None (no quantization, fp16).')
parser.add_argument( parser.add_argument(
'--gptq_checkpoint', "--quant",
choices=["8bit", "4bit"],
default=None, default=None,
help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') help="Quantization mode. Default: None (no quantization, fp16).",
parser.add_argument('--gptq_group_size', )
type=int, parser.add_argument(
default=128, "--gptq_checkpoint",
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') default=None,
parser.add_argument('--http_host', default='0.0.0.0') help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
parser.add_argument('--http_port', type=int, default=7070) )
parser.add_argument('--profanity_file', parser.add_argument(
default=None, "--gptq_group_size",
help='Path to profanity words list. It should be a JSON file containing a list of words.') type=int,
default=128,
help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
)
parser.add_argument("--http_host", default="0.0.0.0")
parser.add_argument("--http_port", type=int, default=7070)
parser.add_argument(
"--profanity_file",
default=None,
help="Path to profanity words list. It should be a JSON file containing a list of words.",
)
args = parser.parse_args() args = parser.parse_args()
if args.quant == '4bit': if args.quant == "4bit":
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
tokenizer = AutoTokenizer.from_pretrained(args.pretrained) tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
...@@ -161,7 +170,7 @@ if __name__ == '__main__': ...@@ -161,7 +170,7 @@ if __name__ == '__main__':
censored_words = [] censored_words = []
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words) prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
if args.quant == '4bit': if args.quant == "4bit":
with low_resource_init(): with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained) config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config) model = LlamaForCausalLM(config)
...@@ -170,12 +179,12 @@ if __name__ == '__main__': ...@@ -170,12 +179,12 @@ if __name__ == '__main__':
else: else:
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
args.pretrained, args.pretrained,
load_in_8bit=(args.quant == '8bit'), load_in_8bit=(args.quant == "8bit"),
torch_dtype=torch.float16, torch_dtype=torch.float16,
device_map="auto", device_map="auto",
) )
if args.quant != '8bit': if args.quant != "8bit":
model.half() # seems to fix bugs for some users. model.half() # seems to fix bugs for some users.
model.eval() model.eval()
config = uvicorn.Config(app, host=args.http_host, port=args.http_port) config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
......
...@@ -3,41 +3,49 @@ import os ...@@ -3,41 +3,49 @@ import os
from transformers import AutoTokenizer from transformers import AutoTokenizer
from utils import ChatPromptProcessor, Dialogue from utils import ChatPromptProcessor, Dialogue
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH']) tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"])
samples = [ samples = [
([ (
Dialogue( [
instruction='Who is the best player in the history of NBA?', Dialogue(
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' instruction="Who is the best player in the history of NBA?",
), response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
Dialogue(instruction='continue this talk', response=''), ),
], 128, Dialogue(instruction="continue this talk", response=""),
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' ],
128,
"Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
), ),
([ (
Dialogue( [
instruction='Who is the best player in the history of NBA?', Dialogue(
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' instruction="Who is the best player in the history of NBA?",
), response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
Dialogue(instruction='continue this talk', response=''), ),
], 200, Dialogue(instruction="continue this talk", response=""),
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' ],
200,
"Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
), ),
([ (
Dialogue( [
instruction='Who is the best player in the history of NBA?', Dialogue(
response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' instruction="Who is the best player in the history of NBA?",
), response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
Dialogue(instruction='continue this talk', response=''), ),
], 211, Dialogue(instruction="continue this talk", response=""),
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n' ],
211,
"Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n",
), ),
([ (
Dialogue(instruction='Who is the best player in the history of NBA?', response=''), [
], 128, Dialogue(instruction="Who is the best player in the history of NBA?", response=""),
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n' ],
128,
"Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n",
), ),
] ]
...@@ -49,5 +57,5 @@ def test_chat_prompt_processor(): ...@@ -49,5 +57,5 @@ def test_chat_prompt_processor():
assert prompt == result assert prompt == result
if __name__ == '__main__': if __name__ == "__main__":
test_chat_prompt_processor() test_chat_prompt_processor()
...@@ -20,9 +20,9 @@ except ImportError: ...@@ -20,9 +20,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def prepare_logits_processor(top_k: Optional[int] = None, def prepare_logits_processor(
top_p: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
temperature: Optional[float] = None) -> LogitsProcessorList: ) -> LogitsProcessorList:
processor_list = LogitsProcessorList() processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0: if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature)) processor_list.append(TemperatureLogitsWarper(temperature))
...@@ -41,29 +41,30 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: ...@@ -41,29 +41,30 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0 return unfinished_sequences.max() == 0
def sample_streamingly(model: nn.Module, def sample_streamingly(
input_ids: torch.Tensor, model: nn.Module,
max_generate_tokens: int, input_ids: torch.Tensor,
early_stopping: bool = False, max_generate_tokens: int,
eos_token_id: Optional[int] = None, early_stopping: bool = False,
pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
top_k: Optional[int] = None, pad_token_id: Optional[int] = None,
top_p: Optional[float] = None, top_k: Optional[int] = None,
temperature: Optional[float] = None, top_p: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, temperature: Optional[float] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs) -> Generator: update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> Generator:
logits_processor = prepare_logits_processor(top_k, top_p, temperature) logits_processor = prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(max_generate_tokens): for _ in range(max_generate_tokens):
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { model_inputs = (
'input_ids': input_ids prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
} )
outputs = model(**model_inputs) outputs = model(**model_inputs)
next_token_logits = outputs['logits'][:, -1, :] next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution # pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits) next_token_logits = logits_processor(input_ids, next_token_logits)
# sample # sample
...@@ -107,25 +108,26 @@ def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict: ...@@ -107,25 +108,26 @@ def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
if "attention_mask" in model_kwargs: if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"] attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat( model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
return model_kwargs return model_kwargs
class Dialogue(BaseModel): class Dialogue(BaseModel):
instruction: str = Field(min_length=1, example='Count up from 1 to 500.') instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
response: str = Field(example='') response: str = Field(example="")
def _format_dialogue(instruction: str, response: str = ''): def _format_dialogue(instruction: str, response: str = ""):
return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}' return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}"
STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S)) STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S))
class ChatPromptProcessor: class ChatPromptProcessor:
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.' SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []): def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -138,42 +140,48 @@ class ChatPromptProcessor: ...@@ -138,42 +140,48 @@ class ChatPromptProcessor:
def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str: def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
if self.context_len is None: if self.context_len is None:
self.context_len = len(self.tokenizer(self.context)['input_ids']) self.context_len = len(self.tokenizer(self.context)["input_ids"])
if self.dialogue_placeholder_len is None: if self.dialogue_placeholder_len is None:
self.dialogue_placeholder_len = len( self.dialogue_placeholder_len = len(
self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids']) self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"]
)
prompt = self.context prompt = self.context
# the last dialogue must be in the prompt # the last dialogue must be in the prompt
last_dialogue = history.pop() last_dialogue = history.pop()
# the response of the last dialogue is empty # the response of the last dialogue is empty
assert last_dialogue.response == '' assert last_dialogue.response == ""
if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False) if (
['input_ids']) + max_new_tokens + self.context_len >= self.max_len: len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"])
+ max_new_tokens
+ self.context_len
>= self.max_len
):
# to avoid truncate placeholder, apply truncate to the original instruction # to avoid truncate placeholder, apply truncate to the original instruction
instruction_truncated = self.tokenizer(last_dialogue.instruction, instruction_truncated = self.tokenizer(
add_special_tokens=False, last_dialogue.instruction,
truncation=True, add_special_tokens=False,
max_length=(self.max_len - max_new_tokens - self.context_len - truncation=True,
self.dialogue_placeholder_len))['input_ids'] max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len),
)["input_ids"]
instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip() instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
prompt += _format_dialogue(instruction_truncated) prompt += _format_dialogue(instruction_truncated)
return prompt return prompt
res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids']) res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"])
rows = [] rows = []
for dialogue in history[::-1]: for dialogue in history[::-1]:
text = _format_dialogue(dialogue.instruction, dialogue.response) text = _format_dialogue(dialogue.instruction, dialogue.response)
cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids']) cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
if res_len - cur_len < 0: if res_len - cur_len < 0:
break break
res_len -= cur_len res_len -= cur_len
rows.insert(0, text) rows.insert(0, text)
prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction) prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction)
return prompt return prompt
def postprocess_output(self, output: str) -> str: def postprocess_output(self, output: str) -> str:
output = STOP_PAT.sub('', output) output = STOP_PAT.sub("", output)
return output.strip() return output.strip()
def has_censored_words(self, text: str) -> bool: def has_censored_words(self, text: str) -> bool:
...@@ -184,7 +192,6 @@ class ChatPromptProcessor: ...@@ -184,7 +192,6 @@ class ChatPromptProcessor:
class LockedIterator: class LockedIterator:
def __init__(self, it, lock: Lock) -> None: def __init__(self, it, lock: Lock) -> None:
self.lock = lock self.lock = lock
self.it = iter(it) self.it = iter(it)
......
pytest pytest
colossalai==0.3.1 colossalai==0.3.1
\ No newline at end of file
...@@ -2,40 +2,42 @@ from setuptools import find_packages, setup ...@@ -2,40 +2,42 @@ from setuptools import find_packages, setup
def fetch_requirements(path): def fetch_requirements(path):
with open(path, 'r') as fd: with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()] return [r.strip() for r in fd.readlines()]
def fetch_readme(): def fetch_readme():
with open('README.md', encoding='utf-8') as f: with open("README.md", encoding="utf-8") as f:
return f.read() return f.read()
def fetch_version(): def fetch_version():
with open('version.txt', 'r') as f: with open("version.txt", "r") as f:
return f.read().strip() return f.read().strip()
setup( setup(
name='coati', name="coati",
version=fetch_version(), version=fetch_version(),
packages=find_packages(exclude=( packages=find_packages(
'tests', exclude=(
'benchmarks', "tests",
'*.egg-info', "benchmarks",
)), "*.egg-info",
description='Colossal-AI Talking Intelligence', )
),
description="Colossal-AI Talking Intelligence",
long_description=fetch_readme(), long_description=fetch_readme(),
long_description_content_type='text/markdown', long_description_content_type="text/markdown",
license='Apache Software License 2.0', license="Apache Software License 2.0",
url='https://github.com/hpcaitech/Coati', url="https://github.com/hpcaitech/Coati",
install_requires=fetch_requirements('requirements.txt'), install_requires=fetch_requirements("requirements.txt"),
python_requires='>=3.6', python_requires=">=3.6",
classifiers=[ classifiers=[
'Programming Language :: Python :: 3', "Programming Language :: Python :: 3",
'License :: OSI Approved :: Apache Software License', "License :: OSI Approved :: Apache Software License",
'Environment :: GPU :: NVIDIA CUDA', "Environment :: GPU :: NVIDIA CUDA",
'Topic :: Scientific/Engineering :: Artificial Intelligence', "Topic :: Scientific/Engineering :: Artificial Intelligence",
'Topic :: System :: Distributed Computing', "Topic :: System :: Distributed Computing",
], ],
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment