Commit 739cfe33 authored by zhang-yi-chi's avatar zhang-yi-chi
Browse files

[chat] fix enable single gpu training bug

parent d7bf2847
...@@ -8,7 +8,7 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic ...@@ -8,7 +8,7 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.roberta import RoBERTaRM, RoBERTaActor, RoBERTaCritic from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer import PPOTrainer from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding from coati.utils import prepare_llama_tokenizer_and_embedding
...@@ -143,6 +143,8 @@ def main(args): ...@@ -143,6 +143,8 @@ def main(args):
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384) prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384)
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
prompt_dataloader = DataLoader(prompt_dataset, prompt_dataloader = DataLoader(prompt_dataset,
shuffle=(prompt_sampler is None), shuffle=(prompt_sampler is None),
sampler=prompt_sampler, sampler=prompt_sampler,
...@@ -151,6 +153,8 @@ def main(args): ...@@ -151,6 +153,8 @@ def main(args):
pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384) pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
pretrain_dataloader = DataLoader(pretrain_dataset, pretrain_dataloader = DataLoader(pretrain_dataset,
shuffle=(pretrain_sampler is None), shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler, sampler=pretrain_sampler,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment