Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8257e105
Unverified
Commit
8257e105
authored
Mar 29, 2023
by
BlueRum
Committed by
GitHub
Mar 29, 2023
Browse files
[chat]polish prompts training (#3300)
* polish train_prompts * polish readme
parent
62f71561
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
39 deletions
+55
-39
applications/Chat/examples/README.md
applications/Chat/examples/README.md
+1
-0
applications/Chat/examples/train_prompts.py
applications/Chat/examples/train_prompts.py
+54
-39
No files found.
applications/Chat/examples/README.md
View file @
8257e105
...
@@ -125,6 +125,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \
...
@@ -125,6 +125,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \
-
--strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive'
-
--strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive'
-
--model: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
-
--model: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
-
--pretrain: pretrain model, type=str, default=None
-
--pretrain: pretrain model, type=str, default=None
-
--rm_model: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None
-
--rm_pretrain: pretrain model for reward model, type=str, default=None
-
--rm_pretrain: pretrain model for reward model, type=str, default=None
-
--rm_path: the path of rm model, type=str, default=None
-
--rm_path: the path of rm model, type=str, default=None
-
--save_path: path to save the model, type=str, default='output'
-
--save_path: path to save the model, type=str, default='output'
...
...
applications/Chat/examples/train_prompts.py
View file @
8257e105
import
argparse
import
argparse
import
pandas
as
pd
import
pandas
as
pd
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
coati.dataset
import
DataCollatorForSupervisedDataset
,
PromptDataset
,
SupervisedDataset
from
coati.models.bloom
import
BLOOMActor
,
BLOOMRM
,
BLOOMCritic
from
coati.models.bloom
import
BLOOMRM
,
BLOOMActor
,
BLOOMCritic
from
coati.models.gpt
import
GPTActor
,
GPTRM
,
GPTCritic
from
coati.models.gpt
import
GPTRM
,
GPTActor
,
GPTCritic
from
coati.models.opt
import
OPTActor
,
OPTRM
,
OPTCritic
from
coati.models.llama
import
LlamaActor
from
coati.models.llama
import
LlamaActor
,
LlamaRM
,
LlamaCritic
from
coati.models.opt
import
OPTRM
,
OPTActor
,
OPTCritic
from
coati.trainer
import
PPOTrainer
from
coati.trainer
import
PPOTrainer
from
coati.trainer.strategies
import
ColossalAIStrategy
,
DDPStrategy
,
NaiveStrategy
from
coati.trainer.strategies
import
ColossalAIStrategy
,
DDPStrategy
,
NaiveStrategy
from
coati.utils
import
prepare_llama_tokenizer_and_embedding
from
torch.optim
import
Adam
from
torch.optim
import
Adam
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
,
GPT2Tokenizer
,
LlamaTokenizer
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
,
LlamaTokenizer
,
GPT2Tokenizer
from
coati.dataset
import
SupervisedDataset
,
DataCollatorForSupervisedDataset
,
PromptDataset
from
coati.utils
import
prepare_llama_tokenizer_and_embedding
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
...
@@ -38,44 +37,66 @@ def main(args):
...
@@ -38,44 +37,66 @@ def main(args):
# configure model
# configure model
if
args
.
model
==
'gpt2'
:
if
args
.
model
==
'gpt2'
:
initial_model
=
GPTActor
(
pretrained
=
args
.
pretrain
)
initial_model
=
GPTActor
(
pretrained
=
args
.
pretrain
)
reward_model
=
GPTRM
(
pretrained
=
args
.
rm_pretrain
)
elif
args
.
model
==
'bloom'
:
elif
args
.
model
==
'bloom'
:
initial_model
=
BLOOMActor
(
pretrained
=
args
.
pretrain
)
initial_model
=
BLOOMActor
(
pretrained
=
args
.
pretrain
)
reward_model
=
BLOOMRM
(
pretrained
=
args
.
rm_pretrain
)
elif
args
.
model
==
'opt'
:
elif
args
.
model
==
'opt'
:
initial_model
=
OPTActor
(
pretrained
=
args
.
pretrain
)
initial_model
=
OPTActor
(
pretrained
=
args
.
pretrain
)
reward_model
=
OPTRM
(
pretrained
=
args
.
rm_pretrain
)
elif
args
.
model
==
'llama'
:
elif
args
.
model
==
'llama'
:
initial_model
=
LlamaActor
(
pretrained
=
args
.
pretrain
)
initial_model
=
LlamaActor
(
pretrained
=
args
.
pretrain
)
else
:
raise
ValueError
(
f
'Unsupported actor model "
{
args
.
model
}
"'
)
if
args
.
rm_model
==
None
:
rm_model_name
=
args
.
model
else
:
rm_model_name
=
args
.
rm_model
if
rm_model_name
==
'gpt2'
:
reward_model
=
GPTRM
(
pretrained
=
args
.
rm_pretrain
)
elif
rm_model_name
==
'bloom'
:
reward_model
=
BLOOMRM
(
pretrained
=
args
.
rm_pretrain
)
reward_model
=
BLOOMRM
(
pretrained
=
args
.
rm_pretrain
)
elif
rm_model_name
==
'opt'
:
reward_model
=
OPTRM
(
pretrained
=
args
.
rm_pretrain
)
elif
rm_model_name
==
'llama'
:
reward_model
=
LlamaRM
(
pretrained
=
args
.
rm_pretrain
)
else
:
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
raise
ValueError
(
f
'Unsupported reward model "
{
rm_model_name
}
"'
)
if
args
.
rm_path
is
not
None
:
if
args
.
rm_path
is
not
None
:
reward_model
.
load_state_dict
(
state_dict
)
reward_model
.
load_state_dict
(
state_dict
)
if
args
.
strategy
!=
'colossalai_gemini'
:
if
args
.
strategy
!=
'colossalai_gemini'
:
initial_model
.
to
(
torch
.
float16
).
to
(
torch
.
cuda
.
current_device
())
initial_model
.
to
(
torch
.
float16
).
to
(
torch
.
cuda
.
current_device
())
reward_model
.
to
(
torch
.
float16
).
to
(
torch
.
cuda
.
current_device
())
reward_model
.
to
(
torch
.
float16
).
to
(
torch
.
cuda
.
current_device
())
with
strategy
.
model_init_context
():
with
strategy
.
model_init_context
():
if
args
.
model
==
'gpt2'
:
if
args
.
model
==
'gpt2'
:
actor
=
GPTActor
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
)
actor
=
GPTActor
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
)
critic
=
GPTCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
elif
args
.
model
==
'bloom'
:
elif
args
.
model
==
'bloom'
:
actor
=
BLOOMActor
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
)
actor
=
BLOOMActor
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
)
critic
=
BLOOMCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
elif
args
.
model
==
'opt'
:
elif
args
.
model
==
'opt'
:
actor
=
OPTActor
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
)
actor
=
OPTActor
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
)
critic
=
OPTCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
elif
args
.
model
==
'llama'
:
elif
args
.
model
==
'llama'
:
actor
=
LlamaActor
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
)
actor
=
LlamaActor
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
)
else
:
raise
ValueError
(
f
'Unsupported actor model "
{
args
.
model
}
"'
)
if
rm_model_name
==
'gpt2'
:
critic
=
GPTCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
elif
rm_model_name
==
'bloom'
:
critic
=
BLOOMCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
critic
=
BLOOMCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
elif
rm_model_name
==
'opt'
:
critic
=
OPTCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
elif
rm_model_name
==
'llama'
:
critic
=
LlamaCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
else
:
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
raise
ValueError
(
f
'Unsupported reward model "
{
rm_model_name
}
"'
)
if
args
.
rm_path
is
not
None
:
if
args
.
rm_path
is
not
None
:
critic
.
load_state_dict
(
state_dict
)
critic
.
load_state_dict
(
state_dict
)
del
state_dict
del
state_dict
if
args
.
strategy
!=
'colossalai_gemini'
:
if
args
.
strategy
!=
'colossalai_gemini'
:
critic
.
to
(
torch
.
float16
).
to
(
torch
.
cuda
.
current_device
())
critic
.
to
(
torch
.
float16
).
to
(
torch
.
cuda
.
current_device
())
actor
.
to
(
torch
.
float16
).
to
(
torch
.
cuda
.
current_device
())
actor
.
to
(
torch
.
float16
).
to
(
torch
.
cuda
.
current_device
())
...
@@ -100,38 +121,32 @@ def main(args):
...
@@ -100,38 +121,32 @@ def main(args):
tokenizer
.
eos_token
=
'<\s>'
tokenizer
.
eos_token
=
'<\s>'
else
:
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
if
args
.
model
==
'llama'
:
if
args
.
model
==
'llama'
:
tokenizer
=
prepare_llama_tokenizer_and_embedding
(
tokenizer
,
actor
)
tokenizer
=
prepare_llama_tokenizer_and_embedding
(
tokenizer
,
actor
)
else
:
else
:
tokenizer
.
pad_token
=
tokenizer
.
eos_token
tokenizer
.
pad_token
=
tokenizer
.
eos_token
data_collator
=
DataCollatorForSupervisedDataset
(
tokenizer
=
tokenizer
)
data_collator
=
DataCollatorForSupervisedDataset
(
tokenizer
=
tokenizer
)
prompt_dataset
=
PromptDataset
(
tokenizer
=
tokenizer
,
data_path
=
args
.
prompt_path
,
max_datasets_size
=
16384
)
prompt_dataset
=
PromptDataset
(
tokenizer
=
tokenizer
,
data_path
=
args
.
prompt_path
,
max_datasets_size
=
16384
)
if
dist
.
is_initialized
()
and
dist
.
get_world_size
()
>
1
:
if
dist
.
is_initialized
()
and
dist
.
get_world_size
()
>
1
:
prompt_sampler
=
DistributedSampler
(
prompt_dataset
,
shuffle
=
True
,
seed
=
42
,
drop_last
=
True
)
prompt_sampler
=
DistributedSampler
(
prompt_dataset
,
shuffle
=
True
,
seed
=
42
,
drop_last
=
True
)
prompt_dataloader
=
DataLoader
(
prompt_dataset
,
prompt_dataloader
=
DataLoader
(
prompt_dataset
,
shuffle
=
(
prompt_sampler
is
None
),
sampler
=
prompt_sampler
,
batch_size
=
args
.
train_batch_size
)
shuffle
=
(
prompt_sampler
is
None
),
sampler
=
prompt_sampler
,
batch_size
=
args
.
train_batch_size
)
pretrain_dataset
=
SupervisedDataset
(
tokenizer
=
tokenizer
,
data_path
=
args
.
pretrain_dataset
,
max_datasets_size
=
16384
)
pretrain_dataset
=
SupervisedDataset
(
tokenizer
=
tokenizer
,
data_path
=
args
.
pretrain_dataset
,
max_datasets_size
=
16384
)
if
dist
.
is_initialized
()
and
dist
.
get_world_size
()
>
1
:
if
dist
.
is_initialized
()
and
dist
.
get_world_size
()
>
1
:
pretrain_sampler
=
DistributedSampler
(
pretrain_dataset
,
shuffle
=
True
,
seed
=
42
,
drop_last
=
True
)
pretrain_sampler
=
DistributedSampler
(
pretrain_dataset
,
shuffle
=
True
,
seed
=
42
,
drop_last
=
True
)
pretrain_dataloader
=
DataLoader
(
pretrain_dataset
,
pretrain_dataloader
=
DataLoader
(
pretrain_dataset
,
shuffle
=
(
pretrain_sampler
is
None
),
sampler
=
pretrain_sampler
,
batch_size
=
args
.
ptx_batch_size
,
collate_fn
=
data_collator
)
shuffle
=
(
pretrain_sampler
is
None
),
sampler
=
pretrain_sampler
,
batch_size
=
args
.
ptx_batch_size
,
collate_fn
=
data_collator
)
def
tokenize_fn
(
texts
):
def
tokenize_fn
(
texts
):
# MUST padding to max length to ensure inputs of all ranks have the same length
# MUST padding to max length to ensure inputs of all ranks have the same length
# Different length may lead to hang when using gemini, as different generation steps
# Different length may lead to hang when using gemini, as different generation steps
batch
=
tokenizer
(
texts
,
return_tensors
=
'pt'
,
max_length
=
96
,
padding
=
'max_length'
,
truncation
=
True
)
batch
=
tokenizer
(
texts
,
return_tensors
=
'pt'
,
max_length
=
96
,
padding
=
'max_length'
,
truncation
=
True
)
return
{
k
:
v
.
to
(
torch
.
cuda
.
current_device
())
for
k
,
v
in
batch
.
items
()}
return
{
k
:
v
.
to
(
torch
.
cuda
.
current_device
())
for
k
,
v
in
batch
.
items
()}
(
actor
,
actor_optim
),
(
critic
,
critic_optim
)
=
strategy
.
prepare
((
actor
,
actor_optim
),
(
critic
,
critic_optim
))
(
actor
,
actor_optim
),
(
critic
,
critic_optim
)
=
strategy
.
prepare
(
(
actor
,
actor_optim
),
(
critic
,
critic_optim
))
# configure trainer
# configure trainer
trainer
=
PPOTrainer
(
trainer
=
PPOTrainer
(
...
@@ -177,10 +192,10 @@ if __name__ == '__main__':
...
@@ -177,10 +192,10 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--pretrain_dataset'
,
type
=
str
,
default
=
None
,
help
=
'path to the pretrained dataset'
)
parser
.
add_argument
(
'--pretrain_dataset'
,
type
=
str
,
default
=
None
,
help
=
'path to the pretrained dataset'
)
parser
.
add_argument
(
'--strategy'
,
parser
.
add_argument
(
'--strategy'
,
choices
=
[
'naive'
,
'ddp'
,
'colossalai_gemini'
,
'colossalai_zero2'
],
choices
=
[
'naive'
,
'ddp'
,
'colossalai_gemini'
,
'colossalai_zero2'
],
default
=
'naive'
,
default
=
'naive'
,
help
=
'strategy to use'
)
help
=
'strategy to use'
)
parser
.
add_argument
(
'--model'
,
default
=
'gpt2'
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
,
'llama'
])
parser
.
add_argument
(
'--model'
,
default
=
'gpt2'
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
,
'llama'
])
parser
.
add_argument
(
'--pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--rm_model'
,
default
=
None
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
,
'llama'
])
parser
.
add_argument
(
'--rm_path'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--rm_path'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--rm_pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--rm_pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--save_path'
,
type
=
str
,
default
=
'actor_checkpoint_prompts'
)
parser
.
add_argument
(
'--save_path'
,
type
=
str
,
default
=
'actor_checkpoint_prompts'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment