Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
739cfe33
"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "641b1ee71a19e2337f3363620b228dd355835b04"
Commit
739cfe33
authored
Apr 22, 2023
by
zhang-yi-chi
Browse files
[chat] fix enable single gpu training bug
parent
d7bf2847
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
1 deletion
+5
-1
applications/Chat/examples/train_prompts.py
applications/Chat/examples/train_prompts.py
+5
-1
No files found.
applications/Chat/examples/train_prompts.py
View file @
739cfe33
...
...
@@ -8,7 +8,7 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from
coati.models.gpt
import
GPTRM
,
GPTActor
,
GPTCritic
from
coati.models.llama
import
LlamaActor
,
LlamaCritic
,
LlamaRM
from
coati.models.opt
import
OPTRM
,
OPTActor
,
OPTCritic
from
coati.models.roberta
import
RoBERTaRM
,
RoBERTaActor
,
RoBERTaCritic
from
coati.models.roberta
import
RoBERTaActor
,
RoBERTaCritic
,
RoBERTaRM
from
coati.trainer
import
PPOTrainer
from
coati.trainer.strategies
import
ColossalAIStrategy
,
DDPStrategy
,
NaiveStrategy
from
coati.utils
import
prepare_llama_tokenizer_and_embedding
...
...
@@ -143,6 +143,8 @@ def main(args):
prompt_dataset
=
PromptDataset
(
tokenizer
=
tokenizer
,
data_path
=
args
.
prompt_path
,
max_datasets_size
=
16384
)
if
dist
.
is_initialized
()
and
dist
.
get_world_size
()
>
1
:
prompt_sampler
=
DistributedSampler
(
prompt_dataset
,
shuffle
=
True
,
seed
=
42
,
drop_last
=
True
)
else
:
prompt_sampler
=
None
prompt_dataloader
=
DataLoader
(
prompt_dataset
,
shuffle
=
(
prompt_sampler
is
None
),
sampler
=
prompt_sampler
,
...
...
@@ -151,6 +153,8 @@ def main(args):
pretrain_dataset
=
SupervisedDataset
(
tokenizer
=
tokenizer
,
data_path
=
args
.
pretrain_dataset
,
max_datasets_size
=
16384
)
if
dist
.
is_initialized
()
and
dist
.
get_world_size
()
>
1
:
pretrain_sampler
=
DistributedSampler
(
pretrain_dataset
,
shuffle
=
True
,
seed
=
42
,
drop_last
=
True
)
else
:
pretrain_sampler
=
None
pretrain_dataloader
=
DataLoader
(
pretrain_dataset
,
shuffle
=
(
pretrain_sampler
is
None
),
sampler
=
pretrain_sampler
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment