Unverified commit 7b9b8644, authored by Wenhao Chen, committed by GitHub

[chat]: update rm, add wandb and fix bugs (#4471)



* feat: modify forward fn of critic and reward model

* feat: modify calc_action_log_probs

* to: add wandb in sft and rm trainer

* feat: update train_sft

* feat: update train_rm

* style: modify type annotation and add warning

* feat: pass tokenizer to ppo trainer

* to: modify trainer base and maker base

* feat: add wandb in ppo trainer

* feat: pass tokenizer to generate

* test: update generate fn tests

* test: update train tests

* fix: remove action_mask

* feat: remove unused code

* fix: fix wrong ignore_index

* fix: fix mock tokenizer

* chore: update requirements

* revert: modify make_experience

* fix: fix inference

* fix: add padding side

* style: modify _on_learn_batch_end

* test: use mock tokenizer

* fix: use bf16 to avoid overflow

* fix: fix workflow

* [chat] fix gemini strategy

* [chat] fix

* sync: update colossalai strategy

* fix: fix args and model dtype

* fix: fix checkpoint test

* fix: fix requirements

* fix: fix missing import and wrong arg

* fix: temporarily skip gemini test in stage 3

* style: apply pre-commit

* fix: temporarily skip gemini test in stage 1&2

---------
Co-authored-by: Mingyan Jiang <1829166702@qq.com>
parent 07c2e3d0
@@ -45,9 +45,17 @@ def eval(args):
         raise ValueError(f'Unsupported model "{args.model}"')
     actor.eval()
+    tokenizer.padding_side = "left"
     input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
     outputs = generate(
-        actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1
+        actor,
+        input_ids,
+        tokenizer=tokenizer,
+        max_length=args.max_length,
+        do_sample=True,
+        top_k=50,
+        top_p=0.95,
+        num_return_sequences=1,
     )
     output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
     print(f"[Output]: {''.join(output)}")
...
 pandas>=1.4.1
 sentencepiece
-colossalai==0.3.1
+colossalai>=0.3.1
@@ -23,7 +23,7 @@ def main(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     else:
@@ -65,8 +65,8 @@ def main(args):
     if args.rm_path is not None:
         reward_model.load_state_dict(state_dict, strict=False)
-    initial_model.to(torch.float16).to(torch.cuda.current_device())
-    reward_model.to(torch.float16).to(torch.cuda.current_device())
+    initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
+    reward_model.to(torch.bfloat16).to(torch.cuda.current_device())
     if args.model == "gpt2":
         actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
@@ -80,13 +80,13 @@ def main(args):
         raise ValueError(f'Unsupported actor model "{args.model}"')
     if rm_model_name == "gpt2":
-        critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+        critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
     elif rm_model_name == "bloom":
-        critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+        critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
     elif rm_model_name == "opt":
-        critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+        critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
     elif rm_model_name == "llama":
-        critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+        critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
     else:
         raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@@ -94,17 +94,16 @@ def main(args):
         critic.load_state_dict(state_dict, strict=False)
         del state_dict
-    if args.strategy != "colossalai_gemini":
-        critic.to(torch.float16).to(torch.cuda.current_device())
-        actor.to(torch.float16).to(torch.cuda.current_device())
+    actor.to(torch.bfloat16).to(torch.cuda.current_device())
+    critic.to(torch.bfloat16).to(torch.cuda.current_device())
     # configure optimizer
     if args.strategy.startswith("colossalai"):
-        actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
-        critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
+        actor_optim = HybridAdam(actor.parameters(), lr=args.lr)
+        critic_optim = HybridAdam(critic.parameters(), lr=args.lr)
     else:
-        actor_optim = Adam(actor.parameters(), lr=1e-7)
-        critic_optim = Adam(critic.parameters(), lr=1e-7)
+        actor_optim = Adam(actor.parameters(), lr=args.lr)
+        critic_optim = Adam(critic.parameters(), lr=args.lr)
     # configure tokenizer
     if args.model == "gpt2":
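Several hunks in this commit move models from `torch.float16` to `torch.bfloat16`, matching the "use bf16 to avoid overflow" commit message: float16 saturates at about 65504, which reward and value accumulations can exceed, while bfloat16 keeps float32's exponent range at the cost of mantissa precision. A quick illustration:

    import torch

    x = torch.tensor(70000.0)
    print(x.to(torch.float16))   # inf: exceeds float16's max of ~65504
    print(x.to(torch.bfloat16))  # ~70144: coarser, but finite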
@@ -126,8 +125,15 @@ def main(args):
         tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
-    prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
+    # NOTE: generate() requires padding_side to be "left"
+    tokenizer.padding_side = "left"
+    prompt_dataset = PromptDataset(
+        tokenizer=tokenizer,
+        data_path=args.prompt_dataset,
+        max_datasets_size=args.max_datasets_size,
+        max_length=args.max_input_len,
+    )
     if dist.is_initialized() and dist.get_world_size() > 1:
         prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
     else:
@@ -137,7 +143,10 @@ def main(args):
     )
     pretrain_dataset = SupervisedDataset(
-        tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len
+        tokenizer=tokenizer,
+        data_path=args.pretrain_dataset,
+        max_datasets_size=args.max_datasets_size,
+        max_length=args.max_input_len,
     )
     if dist.is_initialized() and dist.get_world_size() > 1:
         pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
@@ -161,6 +170,7 @@ def main(args):
         initial_model,
         actor_optim,
         critic_optim,
+        tokenizer=tokenizer,
         kl_coef=args.kl_coef,
         ptx_coef=args.ptx_coef,
         train_batch_size=args.train_batch_size,
@@ -169,17 +179,17 @@ def main(args):
         do_sample=True,
         temperature=1.0,
         top_k=50,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
         offload_inference_models=args.strategy != "colossalai_gemini",
     )
     trainer.fit(
+        prompt_dataloader=prompt_dataloader,
+        pretrain_dataloader=pretrain_dataloader,
         num_episodes=args.num_episodes,
         num_collect_steps=args.num_collect_steps,
         num_update_steps=args.num_update_steps,
-        prompt_dataloader=prompt_dataloader,
-        pretrain_dataloader=pretrain_dataloader,
+        log_dir=args.log_dir,
+        use_wandb=args.use_wandb,
     )
     # save model checkpoint after fitting
@@ -195,6 +205,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
     parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
+    parser.add_argument("--max_datasets_size", type=int, default=50000)
     parser.add_argument(
         "--strategy",
         choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
@@ -216,9 +227,12 @@ if __name__ == "__main__":
     parser.add_argument("--ptx_batch_size", type=int, default=1)
    parser.add_argument("--experience_batch_size", type=int, default=8)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--lr", type=float, default=1e-7)
     parser.add_argument("--kl_coef", type=float, default=0.1)
     parser.add_argument("--ptx_coef", type=float, default=0.9)
     parser.add_argument("--max_input_len", type=int, default=96)
     parser.add_argument("--max_seq_len", type=int, default=128)
+    parser.add_argument("--log_dir", default="logs", type=str)
+    parser.add_argument("--use_wandb", default=False, action="store_true")
     args = parser.parse_args()
     main(args)
 import argparse
-from random import randint
 import torch
 import torch.distributed as dist
@@ -27,7 +26,7 @@ def train(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="cuda")
+        strategy = GeminiStrategy(placement_policy="auto")
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     else:
@@ -46,7 +45,7 @@ def train(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
-    model.to(torch.float16).to(torch.cuda.current_device())
+    model.to(torch.bfloat16).to(torch.cuda.current_device())
     if args.model_path is not None:
         state_dict = torch.load(args.model_path)
@@ -75,9 +74,9 @@ def train(args):
     # configure optimizer
     if args.strategy.startswith("colossalai"):
-        optim = HybridAdam(model.parameters(), lr=5e-6)
+        optim = HybridAdam(model.parameters(), lr=args.lr)
     else:
-        optim = Adam(model.parameters(), lr=5e-6)
+        optim = Adam(model.parameters(), lr=args.lr)
     # configure loss function
     if args.loss_fn == "log_sig":
@@ -93,21 +92,14 @@ def train(args):
     else:
         data = load_dataset(args.dataset)
-        if args.test:
-            train_data = data["train"].select(range(20))
-            eval_data = data["test"].select(range(5))
-        else:
-            train_data = data["train"]
-            eval_data = data["test"]
-        valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
+        train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"]))))
+        eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"]))))
     if args.dataset == "Dahoas/rm-static":
         train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
-        valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
         eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
     elif args.dataset == "Anthropic/hh-rlhf":
         train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
-        valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
         eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
     else:
         raise ValueError(f'Unsupported dataset "{args.dataset}"')
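The ad-hoc `--test` flag and the random `valid` split sampled with `randint` are replaced by a single deterministic `--max_datasets_size` cap. The `select(range(min(n, len(ds))))` pattern works on any Hugging Face `datasets` split; a minimal sketch (dataset name taken from this script):

    from datasets import load_dataset

    data = load_dataset("Dahoas/rm-static")
    n = 8  # cap each split at n examples, e.g. for CI smoke tests
    train_data = data["train"].select(range(min(n, len(data["train"]))))
    assert len(train_data) <= n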
@@ -121,14 +113,6 @@ def train(args):
             rank=dist.get_rank(),
             num_replicas=dist.get_world_size(),
         )
-        valid_sampler = DistributedSampler(
-            valid_dataset,
-            shuffle=True,
-            seed=42,
-            drop_last=True,
-            rank=dist.get_rank(),
-            num_replicas=dist.get_world_size(),
-        )
         eval_sampler = DistributedSampler(
             eval_dataset,
             shuffle=True,
@@ -139,7 +123,6 @@ def train(args):
         )
     else:
         train_sampler = None
-        valid_sampler = None
         eval_sampler = None
     train_dataloader = DataLoader(
@@ -150,14 +133,6 @@ def train(args):
         pin_memory=True,
     )
-    valid_dataloader = DataLoader(
-        valid_dataset,
-        shuffle=(valid_sampler is None),
-        sampler=valid_sampler,
-        batch_size=args.batch_size,
-        pin_memory=True,
-    )
     eval_dataloader = DataLoader(
         eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
     )
@@ -176,7 +151,12 @@ def train(args):
         max_epochs=args.max_epochs,
     )
-    trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
+    trainer.fit(
+        train_dataloader=train_dataloader,
+        eval_dataloader=eval_dataloader,
+        log_dir=args.log_dir,
+        use_wandb=args.use_wandb,
+    )
     # save model checkpoint after fitting on only rank0
     strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -200,12 +180,15 @@ if __name__ == "__main__":
         "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
     )
     parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
+    parser.add_argument("--max_datasets_size", type=int, default=1000000)
     parser.add_argument("--save_path", type=str, default="rm_ckpt")
     parser.add_argument("--max_epochs", type=int, default=1)
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--lr", type=float, default=9e-6)
     parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
-    parser.add_argument("--test", type=bool, default=False)
+    parser.add_argument("--log_dir", default="logs", type=str)
+    parser.add_argument("--use_wandb", default=False, action="store_true")
     args = parser.parse_args()
     train(args)
@@ -16,7 +16,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
 set_n_least_used_CUDA_VISIBLE_DEVICES 2
 torchrun --standalone --nproc_per_node=2 train_reward_model.py \
-    --model 'bloom' \
+    --pretrain 'gpt2' \
+    --model 'gpt2' \
     --strategy colossalai_zero2 \
-    --loss_fn 'log_sig' \
-    --dataset 'Anthropic/hh-rlhf'
+    --loss_fn 'log_exp' \
+    --dataset 'Anthropic/hh-rlhf' \
+    --batch_size 16 \
+    --max_epochs 10
@@ -23,7 +23,6 @@ from transformers.trainer import get_scheduler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ColoParameter
 def train(args):
@@ -31,7 +30,7 @@ def train(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="cuda")
+        strategy = GeminiStrategy(placement_policy="auto")
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif args.strategy == "colossalai_zero2_cpu":
@@ -57,7 +56,7 @@ def train(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
-    model.to(torch.float16).to(torch.cuda.current_device())
+    model.to(torch.bfloat16).to(torch.cuda.current_device())
     # configure tokenizer
     if args.model == "gpt2":
@@ -84,28 +83,21 @@ def train(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
-    if args.model == "llama" and args.strategy == "colossalai_gemini":
-        # this is a hack to deal with the resized embedding
-        # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
-        for name, param in model.named_parameters():
-            if not isinstance(param, ColoParameter):
-                sub_module_name = ".".join(name.split(".")[:-1])
-                weight_name = name.split(".")[-1]
-                sub_module = model.get_submodule(sub_module_name)
-                setattr(sub_module, weight_name, ColoParameter(param))
     # configure optimizer
     if args.strategy.startswith("colossalai"):
         optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
     else:
         optim = Adam(model.parameters(), lr=args.lr)
-    logger = get_dist_logger()
     # configure dataset
     if args.dataset == "yizhongw/self_instruct":
         train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
         eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
+        if args.max_datasets_size is not None:
+            train_data = train_data.select(range(min(args.max_datasets_size, len(train_data))))
+            eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data))))
         train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
         eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
@@ -176,8 +168,13 @@ def train(args):
         accumulation_steps=args.accumulation_steps,
     )
+    logger = get_dist_logger()
     trainer.fit(
-        train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb
+        train_dataloader=train_dataloader,
+        eval_dataloader=eval_dataloader,
+        logger=logger,
+        log_dir=args.log_dir,
+        use_wandb=args.use_wandb,
     )
     # save model checkpoint after fitting on only rank0
@@ -207,9 +204,9 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
+    parser.add_argument("--log_dir", default="logs", type=str)
     parser.add_argument("--use_wandb", default=False, action="store_true")
     parser.add_argument("--grad_checkpoint", default=False, action="store_true")
     args = parser.parse_args()
...
@@ -19,7 +19,6 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
     --pretrain "/path/to/LLaMa-7B/" \
     --model 'llama' \
     --strategy colossalai_zero2 \
-    --log_interval 10 \
     --save_path /path/to/Coati-7B \
     --dataset /path/to/data.json \
     --batch_size 4 \
...
 pytest
-colossalai==0.3.1
+colossalai>=0.3.1
@@ -2,7 +2,7 @@ transformers>=4.20.1
 tqdm
 datasets
 loralib
-colossalai==0.3.1
+colossalai>=0.3.1
 torch<2.0.0, >=1.12.1
 langchain
 tokenizers
@@ -11,3 +11,4 @@ sse_starlette
 wandb
 sentencepiece
 gpustat
+tensorboard
@@ -25,8 +25,8 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict:
 def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
     data = get_data(batch_size)
     action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
-    actor_output = actor(data["input_ids"], data["attention_mask"])
-    action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1))
+    actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"]
+    action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1))
     loss = action_log_probs.sum()
     strategy.backward(loss, actor, actor_optim)
     strategy.optimizer_step(actor_optim)
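`calc_action_log_probs` now takes the logits tensor directly instead of the full model output dict, so callers unpack `["logits"]` themselves. The real helper lives in `coati.models.utils`; a plausible reimplementation consistent with these call sites (an assumption for illustration, not the library's code):

    import torch
    import torch.nn.functional as F

    def calc_action_log_probs(logits: torch.Tensor, sequences: torch.Tensor, num_actions: int) -> torch.Tensor:
        # logits[:, t] is the distribution over sequences[:, t + 1]
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
        # pick the log-prob of the token that actually follows each position
        token_log_probs = log_probs.gather(dim=-1, index=sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
        # keep only the generated (action) tail
        return token_log_probs[:, -num_actions:]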
@@ -36,7 +36,7 @@ def run_test_checkpoint(strategy_name: str, shard: bool):
     if strategy_name == "ddp":
         strategy = DDPStrategy()
     elif strategy_name == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
     elif strategy_name == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     else:
...
@@ -226,7 +226,9 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
     check_content(input_ids.masked_select(attention_mask), tokenizer, model)
     assert torch.all(attention_mask)
     ignore_mask = labels == IGNORE_INDEX
-    check_content(input_ids.masked_select(ignore_mask), tokenizer, model)
+    prompt_mask = torch.logical_and(ignore_mask, attention_mask)
+    check_content(input_ids.masked_select(prompt_mask), tokenizer, model)
+    assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id)
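The sharpened assertion distinguishes the two reasons a label can be `IGNORE_INDEX`: prompt tokens (attended, but excluded from the loss) and padding. Intersecting with `attention_mask` isolates the prompt, and the XOR leaves exactly the padding positions, which must all equal `pad_token_id`. A toy walkthrough of the mask algebra (values assumed; `IGNORE_INDEX` is whatever the dataset module defines):

    import torch

    IGNORE_INDEX = -100  # assumed loss-masking value
    # toy row: 2 prompt tokens, 2 response tokens, 2 pads
    labels = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 5, 6, IGNORE_INDEX, IGNORE_INDEX])
    attention_mask = torch.tensor([1, 1, 1, 1, 0, 0], dtype=torch.bool)

    ignore_mask = labels == IGNORE_INDEX        # prompt + padding
    prompt_mask = ignore_mask & attention_mask  # prompt only
    pad_mask = ignore_mask ^ prompt_mask        # padding only
    assert pad_mask.tolist() == [False, False, False, False, True, True]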
 if __name__ == "__main__":
...
+import copy
 import os
-from copy import deepcopy
 import pytest
 import torch
@@ -8,6 +8,7 @@ from coati.experience_buffer import NaiveExperienceBuffer
 from coati.experience_maker import NaiveExperienceMaker
 from coati.models.base import RewardModel
 from coati.models.gpt import GPTActor, GPTCritic
+from coati.trainer.ppo import _set_default_generate_kwargs
 from coati.trainer.strategies import DDPStrategy, GeminiStrategy
 from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
@@ -42,27 +43,38 @@ def make_and_consume_experience(strategy):
     elif strategy == "colossalai-zero2":
         strategy = LowLevelZeroStrategy()
     elif strategy == "colossalai-gemini":
-        strategy = GeminiStrategy(placement_policy="cuda")
+        strategy = GeminiStrategy(placement_policy="static")
     else:
         raise ValueError(f'Unsupported strategy "{strategy}"')
-    actor = GPTActor(config=GPT_CONFIG).cuda()
-    critic = GPTCritic(config=GPT_CONFIG).cuda()
-    initial_model = deepcopy(actor)
-    reward_model = RewardModel(deepcopy(critic.model)).cuda()
-    experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
+    with strategy.model_init_context():
+        actor = GPTActor(config=GPT_CONFIG).cuda()
+        critic = GPTCritic(config=GPT_CONFIG).cuda()
+        initial_model = GPTActor(config=GPT_CONFIG).cuda()
+        reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda()
+    actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model)
+
+    class MockTokenizer:
+        def __init__(self):
+            self.padding_side = "left"
+            self.eos_token_id = 0
+            self.pad_token_id = 0
+
+    tokenizer = MockTokenizer()
+    experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
     data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
+    generate_kwargs = dict(do_sample=True, max_length=16)
+    generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
     # experience of all ranks should be the same
     for _ in range(2):
         data = get_data(EXPERIENCE_BATCH_SIZE)
         assert gather_and_equal(data["input_ids"])
         assert gather_and_equal(data["attention_mask"])
-        experience = experience_maker.make_experience(
-            **data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256
-        )
+        experience = experience_maker.make_experience(**data, do_sample=True, max_length=16)
         assert gather_and_equal(experience.sequences)
         assert gather_and_equal(experience.action_log_probs)
         assert gather_and_equal(experience.values)
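The tests stub the tokenizer rather than loading a real one: as far as this diff shows, the experience maker and `generate` only read `padding_side`, `pad_token_id`, and `eos_token_id`, so any object exposing those attributes satisfies the interface. The same duck-typing idea in one line (an illustration, not the project's helper):

    from types import SimpleNamespace

    tokenizer_stub = SimpleNamespace(padding_side="left", pad_token_id=0, eos_token_id=0)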
@@ -115,4 +127,4 @@ def test_experience(world_size, strategy):
 if __name__ == "__main__":
-    test_experience(2, "colossalai")
+    test_experience(2, "colossalai-zero2")
@@ -14,7 +14,7 @@ from coati.models.llama import LlamaActor
 from coati.models.lora import LoraLinear, convert_to_lora_module
 from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
 from coati.models.opt import OPTRM, OPTActor, OPTCritic
-from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
+from coati.models.utils import calc_action_log_probs, masked_mean
 @pytest.mark.parametrize("batch_size", [4])
...@@ -27,7 +27,6 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea ...@@ -27,7 +27,6 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
# HACK: skip llama due to long execution time # HACK: skip llama due to long execution time
# lambda: LlamaActor(), # lambda: LlamaActor(),
lambda: OPTActor(), lambda: OPTActor(),
# lambda: ChatGLMActor(),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -43,9 +42,16 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea ...@@ -43,9 +42,16 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
], ],
) )
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
class MockTokenizer:
def __init__(self):
self.padding_side = "left"
self.eos_token_id = 0
self.pad_token_id = 0
actor = actor_maker() actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
sequences = generate(actor.cuda(), input_ids, **generate_kwargs) tokenizer = MockTokenizer()
sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs)
assert sequences.shape == (batch_size, generate_kwargs["max_length"]) assert sequences.shape == (batch_size, generate_kwargs["max_length"])
@@ -55,24 +61,12 @@ def test_utils():
     assert fn_output.dim() == 0
     assert torch.allclose(fn_output, torch.tensor(1.0))
-    batch_size = 4
-    num_labels = 10
-    fn_input = {
-        "r": torch.ones((batch_size,)),
-        "kl_coef": 1.0,
-        "log_probs": torch.randn((batch_size, num_labels)),
-        "log_probs_base": torch.randn((batch_size, num_labels)),
-        "action_mask": torch.randint(0, 2, (batch_size, num_labels)),
-    }
-    fn_output = compute_reward(**fn_input)
-    assert fn_output.shape == (batch_size,)
     batch_size = 4
     seq_len = 32
     num_labels = 10
     num_actions = 2
     fn_input = {
-        "output": {"logits": torch.randn((batch_size, seq_len, num_labels))},
+        "logits": torch.randn((batch_size, seq_len, num_labels)),
         "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
         "num_actions": num_actions,
     }
@@ -135,7 +129,6 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], b
     }
     critic_input = {
         "sequences": torch.randint(0, 100, (batch_size, seq_len)),
-        "action_mask": torch.randint(0, 2, (batch_size, seq_len)),
         "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
     }
     rm_input = {
...
@@ -24,8 +24,8 @@ if [ -z "$SFT_DATASET" ]; then
     exit 1
 fi
-if [ -z "$PROMPT_PATH" ]; then
-    echo "Please set \$PROMPT_PATH to the path to prompts csv."
+if [ -z "$PROMPT_DATASET" ]; then
+    echo "Please set \$PROMPT_DATASET to the path to prompts csv."
     exit 1
 fi
@@ -74,11 +74,15 @@ echo "[Test]: testing sft ..."
 # FIXME: This is a hack to skip tests that are not working
 # - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
 # - llama-*: These tests can be passed locally, skipped for long execution time
+# - *-gemini: Gemini plugin does not support `from_pretrained` yet
 SKIPPED_TESTS=(
     "gpt2-ddp"
     "llama-ddp"
     "llama-colossalai_gemini"
     "llama-colossalai_zero2"
+    "gpt2-colossalai_gemini"
+    "opt-colossalai_gemini"
+    "bloom-colossalai_gemini"
 )
 GRAD_CKPTS=('' '--grad_checkpoint')
@@ -105,7 +109,7 @@ for lora_rank in '0' '4'; do
                 $pretrain_model --tokenizer $MODELS_DIR/$model \
                 --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
                 --dataset $SFT_DATASET --max_datasets_size 8 \
-                --max_epochs 1 --batch_size 1 --accumulation_steps 1 \
+                --max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \
                 --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
             passed=$?
             if [ $passed -eq 0 ]; then
@@ -125,11 +129,15 @@ echo "[Test]: testing reward model ..."
 # FIXME: This is a hack to skip tests that are not working
 # - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
 # - llama-*: These tests can be passed locally, skipped for long execution time
+# - *-gemini: Gemini plugin does not support `from_pretrained` yet
 SKIPPED_TESTS=(
     "gpt2-ddp"
     "llama-ddp"
     "llama-colossalai_gemini"
     "llama-colossalai_zero2"
+    "gpt2-colossalai_gemini"
+    "opt-colossalai_gemini"
+    "bloom-colossalai_gemini"
 )
 LOSS_FNS=('log_sig' 'log_exp')
@@ -157,8 +165,9 @@ for lora_rank in '0' '4'; do
             echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
             torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
                 $pretrain_model --tokenizer $MODELS_DIR/$model \
-                --model $model --strategy $strategy --lora_rank $lora_rank --loss_fn $loss_fn \
-                --dataset $dataset --subset $subset --test True --batch_size 1 \
+                --dataset $dataset --subset $subset --max_datasets_size 8 \
+                --model $model --strategy $strategy --lora_rank $lora_rank \
+                --loss_fn $loss_fn --batch_size 1 --lr 1e-8 \
                 --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
             passed=$?
             if [ $passed -eq 0 ]; then
@@ -178,11 +187,15 @@ echo "[Test]: testing RLHF ..."
 # FIXME: This is a hack to skip tests that are not working
 # - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
 # - llama-*: These tests can be passed locally, skipped for long execution time
+# - *-gemini: Gemini plugin does not support `from_pretrained` yet
 SKIPPED_TESTS=(
     "gpt2-ddp"
     "llama-ddp"
     "llama-colossalai_gemini"
     "llama-colossalai_zero2"
+    "gpt2-colossalai_gemini"
+    "opt-colossalai_gemini"
+    "bloom-colossalai_gemini"
 )
 for model in ${MODELS[@]}; do
@@ -204,9 +217,9 @@ for model in ${MODELS[@]}; do
         for i in $(seq $NUM_RETRY); do
             echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
             torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
-                --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+                --prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \
                 --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
-                --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 \
+                --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \
                 --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
                 --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
                 $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
...
@@ -3,6 +3,7 @@ import time
 import pytest
 import torch
+from model_zoo import GPTLMLoss, get_gpt2_components
 from torch.utils._pytree import tree_map
 import colossalai
@@ -13,7 +14,6 @@ from colossalai.fx.profiler import parameter_size
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import spawn
 from colossalai.utils import get_current_device
-from model_zoo import GPTLMLoss, get_gpt2_components
 def parse_args():
...
@@ -3,6 +3,7 @@ import time
 from functools import partial
 import torch
+from model_zoo import model_builder
 from torch import nn
 from colossalai.fx import ColoTracer
@@ -12,7 +13,6 @@ from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
 from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine
 from colossalai.legacy.pipeline.rpc.utils import rpc_run
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from model_zoo import model_builder
 def parse_args():
...