Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
2e16f842
Unverified
Commit
2e16f842
authored
Feb 22, 2023
by
BlueRum
Committed by
GitHub
Feb 22, 2023
Browse files
[chatgpt]support opt & gpt for rm training (#2876)
parent
c52edcf0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
48 additions
and
17 deletions
+48
-17
applications/ChatGPT/chatgpt/nn/bloom_rm.py
applications/ChatGPT/chatgpt/nn/bloom_rm.py
+0
-1
applications/ChatGPT/chatgpt/nn/gpt_rm.py
applications/ChatGPT/chatgpt/nn/gpt_rm.py
+7
-2
applications/ChatGPT/chatgpt/nn/opt_rm.py
applications/ChatGPT/chatgpt/nn/opt_rm.py
+7
-3
applications/ChatGPT/examples/train_reward_model.py
applications/ChatGPT/examples/train_reward_model.py
+31
-10
applications/ChatGPT/examples/train_rm.sh
applications/ChatGPT/examples/train_rm.sh
+3
-1
No files found.
applications/ChatGPT/chatgpt/nn/bloom_rm.py
View file @
2e16f842
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
BloomConfig
,
BloomForCausalLM
,
BloomModel
from
transformers
import
BloomConfig
,
BloomForCausalLM
,
BloomModel
...
...
applications/ChatGPT/chatgpt/nn/gpt_rm.py
View file @
2e16f842
...
@@ -15,12 +15,16 @@ class GPTRM(RewardModel):
...
@@ -15,12 +15,16 @@ class GPTRM(RewardModel):
pretrained (str): Pretrained model name or path.
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
pretrained
:
Optional
[
str
]
=
None
,
pretrained
:
Optional
[
str
]
=
None
,
config
:
Optional
[
GPT2Config
]
=
None
,
config
:
Optional
[
GPT2Config
]
=
None
,
checkpoint
:
bool
=
False
)
->
None
:
checkpoint
:
bool
=
False
,
lora_rank
:
int
=
0
,
lora_train_bias
:
str
=
'none'
)
->
None
:
if
pretrained
is
not
None
:
if
pretrained
is
not
None
:
model
=
GPT2Model
.
from_pretrained
(
pretrained
)
model
=
GPT2Model
.
from_pretrained
(
pretrained
)
elif
config
is
not
None
:
elif
config
is
not
None
:
...
@@ -29,5 +33,6 @@ class GPTRM(RewardModel):
...
@@ -29,5 +33,6 @@ class GPTRM(RewardModel):
model
=
GPT2Model
(
GPT2Config
())
model
=
GPT2Model
(
GPT2Config
())
if
checkpoint
:
if
checkpoint
:
model
.
gradient_checkpointing_enable
()
model
.
gradient_checkpointing_enable
()
value_head
=
nn
.
Linear
(
model
.
config
.
n_embd
,
1
)
value_head
=
nn
.
Linear
(
model
.
config
.
n_embd
,
1
)
super
().
__init__
(
model
,
value_head
)
super
().
__init__
(
model
,
value_head
,
lora_rank
,
lora_train_bias
)
applications/ChatGPT/chatgpt/nn/opt_rm.py
View file @
2e16f842
from
typing
import
Optional
from
typing
import
Optional
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers.models.opt.configuration_opt
import
OPTConfig
from
transformers
import
OPTConfig
,
OPTModel
from
transformers.models.opt.modeling_opt
import
OPTModel
from
.reward_model
import
RewardModel
from
.reward_model
import
RewardModel
...
@@ -14,6 +13,7 @@ class OPTRM(RewardModel):
...
@@ -14,6 +13,7 @@ class OPTRM(RewardModel):
Args:
Args:
pretrained (str): Pretrained model name or path.
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
lora_train_bias (str): LoRA bias training mode.
"""
"""
...
@@ -21,6 +21,7 @@ class OPTRM(RewardModel):
...
@@ -21,6 +21,7 @@ class OPTRM(RewardModel):
def
__init__
(
self
,
def
__init__
(
self
,
pretrained
:
Optional
[
str
]
=
None
,
pretrained
:
Optional
[
str
]
=
None
,
config
:
Optional
[
OPTConfig
]
=
None
,
config
:
Optional
[
OPTConfig
]
=
None
,
checkpoint
:
bool
=
False
,
lora_rank
:
int
=
0
,
lora_rank
:
int
=
0
,
lora_train_bias
:
str
=
'none'
)
->
None
:
lora_train_bias
:
str
=
'none'
)
->
None
:
if
pretrained
is
not
None
:
if
pretrained
is
not
None
:
...
@@ -29,5 +30,8 @@ class OPTRM(RewardModel):
...
@@ -29,5 +30,8 @@ class OPTRM(RewardModel):
model
=
OPTModel
(
config
)
model
=
OPTModel
(
config
)
else
:
else
:
model
=
OPTModel
(
OPTConfig
())
model
=
OPTModel
(
OPTConfig
())
value_head
=
nn
.
Linear
(
model
.
config
.
hidden_size
,
1
)
if
checkpoint
:
model
.
gradient_checkpointing_enable
()
value_head
=
nn
.
Linear
(
model
.
config
.
word_embed_proj_dim
,
1
)
super
().
__init__
(
model
,
value_head
,
lora_rank
,
lora_train_bias
)
super
().
__init__
(
model
,
value_head
,
lora_rank
,
lora_train_bias
)
applications/ChatGPT/examples/train_reward_model.py
View file @
2e16f842
...
@@ -3,12 +3,13 @@ import argparse
...
@@ -3,12 +3,13 @@ import argparse
import
loralib
as
lora
import
loralib
as
lora
import
torch
import
torch
from
chatgpt.dataset
import
RewardDataset
from
chatgpt.dataset
import
RewardDataset
from
chatgpt.nn
import
BLOOMRM
from
chatgpt.nn
import
BLOOMRM
,
GPTRM
,
OPTRM
from
chatgpt.trainer
import
RewardModelTrainer
from
chatgpt.trainer
import
RewardModelTrainer
from
chatgpt.trainer.strategies
import
ColossalAIStrategy
,
DDPStrategy
,
NaiveStrategy
from
chatgpt.trainer.strategies
import
ColossalAIStrategy
,
DDPStrategy
,
NaiveStrategy
from
datasets
import
load_dataset
from
datasets
import
load_dataset
from
torch.optim
import
Adam
from
torch.optim
import
Adam
from
transformers
import
BloomTokenizerFast
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
from
transformers.models.gpt2.tokenization_gpt2
import
GPT2Tokenizer
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
...
@@ -27,11 +28,30 @@ def train(args):
...
@@ -27,11 +28,30 @@ def train(args):
raise
ValueError
(
f
'Unsupported strategy "
{
args
.
strategy
}
"'
)
raise
ValueError
(
f
'Unsupported strategy "
{
args
.
strategy
}
"'
)
# configure model
# configure model
tokenizer
=
BloomTokenizerFast
.
from_pretrained
(
args
.
pretrain
)
tokenizer
.
pad_token
=
tokenizer
.
eos_token
with
strategy
.
model_init_context
():
with
strategy
.
model_init_context
():
model
=
BLOOMRM
(
pretrained
=
args
.
pretrain
).
cuda
()
if
args
.
model
==
'bloom'
:
max_len
=
1024
model
=
BLOOMRM
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
).
cuda
()
elif
args
.
model
==
'opt'
:
model
=
OPTRM
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
).
cuda
()
elif
args
.
model
==
'gpt2'
:
model
=
GPTRM
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
).
cuda
()
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
# configure tokenizer
if
args
.
model
==
'gpt2'
:
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
'gpt2'
)
tokenizer
.
pad_token
=
tokenizer
.
eos_token
elif
args
.
model
==
'bloom'
:
tokenizer
=
BloomTokenizerFast
.
from_pretrained
(
args
.
pretrain
)
tokenizer
.
pad_token
=
tokenizer
.
eos_token
elif
args
.
model
==
'opt'
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/opt-350m"
)
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
tokenizer
.
pad_token
=
tokenizer
.
eos_token
max_len
=
512
# configure optimizer
# configure optimizer
if
args
.
strategy
.
startswith
(
'colossalai'
):
if
args
.
strategy
.
startswith
(
'colossalai'
):
...
@@ -58,10 +78,10 @@ def train(args):
...
@@ -58,10 +78,10 @@ def train(args):
trainer
.
fit
(
use_lora
=
args
.
lora_rank
)
trainer
.
fit
(
use_lora
=
args
.
lora_rank
)
if
args
.
lora_rank
>
0
:
# save model checkpoint after fitting on only rank0
torch
.
save
({
'model_state_dict'
:
lora
.
lora_state_dict
(
trainer
.
model
)},
args
.
save_path
)
strategy
.
save_model
(
model
,
'rm_checkpoint.pt'
,
only_rank0
=
True
)
else
:
# save optimizer checkpoint on all ranks
torch
.
save
(
trainer
.
model
,
args
.
save_path
)
strategy
.
save_optimizer
(
optim
,
'rm_optim_checkpoint_%d.pt'
%
(
torch
.
cuda
.
current_device
()),
only_rank0
=
False
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
@@ -69,6 +89,7 @@ if __name__ == '__main__':
...
@@ -69,6 +89,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--strategy'
,
parser
.
add_argument
(
'--strategy'
,
choices
=
[
'naive'
,
'ddp'
,
'colossalai_gemini'
,
'colossalai_zero2'
],
choices
=
[
'naive'
,
'ddp'
,
'colossalai_gemini'
,
'colossalai_zero2'
],
default
=
'naive'
)
default
=
'naive'
)
parser
.
add_argument
(
'--model'
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
],
default
=
'bloom'
)
parser
.
add_argument
(
'--pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
default
=
'Dahoas/rm-static'
)
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
default
=
'Dahoas/rm-static'
)
parser
.
add_argument
(
'--save_path'
,
type
=
str
,
default
=
'rm_ckpt.pth'
)
parser
.
add_argument
(
'--save_path'
,
type
=
str
,
default
=
'rm_ckpt.pth'
)
...
...
applications/ChatGPT/examples/train_rm.sh
View file @
2e16f842
...
@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
...
@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 2
set_n_least_used_CUDA_VISIBLE_DEVICES 2
torchrun
--standalone
--nproc_per_node
=
2 train_reward_model.py
--pretrain
'/data2/users/lczht/bloom-560m'
--strategy
colossalai_zero2
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2
torchrun
--standalone
--nproc_per_node
=
2 train_reward_model.py
--model
'gpt2'
--strategy
colossalai_zero2
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment