Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9998d5ef
Unverified
Commit
9998d5ef
authored
Mar 22, 2023
by
Yuanchen
Committed by
GitHub
Mar 22, 2023
Browse files
[chatgpt] add reward model code for deberta (#3199)
Co-authored-by:
Yuanchen Xu
<
yuanchen.xu00@gmail.com
>
parent
1e1b9d2f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
93 additions
and
4 deletions
+93
-4
applications/ChatGPT/chatgpt/models/deberta/__init__.py
applications/ChatGPT/chatgpt/models/deberta/__init__.py
+4
-0
applications/ChatGPT/chatgpt/models/deberta/deberta_critic.py
...ications/ChatGPT/chatgpt/models/deberta/deberta_critic.py
+36
-0
applications/ChatGPT/chatgpt/models/deberta/deberta_rm.py
applications/ChatGPT/chatgpt/models/deberta/deberta_rm.py
+37
-0
applications/ChatGPT/examples/requirements.txt
applications/ChatGPT/examples/requirements.txt
+1
-0
applications/ChatGPT/examples/test_ci.sh
applications/ChatGPT/examples/test_ci.sh
+6
-0
applications/ChatGPT/examples/train_reward_model.py
applications/ChatGPT/examples/train_reward_model.py
+7
-2
applications/ChatGPT/examples/train_rm.sh
applications/ChatGPT/examples/train_rm.sh
+2
-2
No files found.
applications/ChatGPT/chatgpt/models/deberta/__init__.py
0 → 100644
View file @
9998d5ef
"""Public API of the DeBERTa model package: critic and reward model."""

from .deberta_critic import DebertaCritic
from .deberta_rm import DebertaRM

__all__ = ['DebertaCritic', 'DebertaRM']
applications/ChatGPT/chatgpt/models/deberta/deberta_critic.py
0 → 100644
View file @
9998d5ef
from
typing
import
Optional
import
torch.nn
as
nn
from
transformers
import
DebertaV2Config
,
DebertaV2Model
from
..base
import
Critic
class DebertaCritic(Critic):
    """
    DeBERTa Critic model.

    Wraps a DeBERTa-v2 backbone with a scalar value head and hands both to
    the shared ``Critic`` base class, which applies LoRA as configured.

    Args:
        pretrained (str): Pretrained model name or path.
        config (DebertaV2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA (low-rank adaptation) decomposition.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Backbone selection: pretrained weights win over an explicit config;
        # with neither given, fall back to the default DeBERTa-v2 config.
        if pretrained is not None:
            backbone = DebertaV2Model.from_pretrained(pretrained)
        else:
            backbone = DebertaV2Model(config if config is not None else DebertaV2Config())
        if checkpoint:
            backbone.gradient_checkpointing_enable()
        # One scalar value per position, projected from the hidden states.
        value_head = nn.Linear(backbone.config.hidden_size, 1)
        super().__init__(backbone, value_head, lora_rank, lora_train_bias)
applications/ChatGPT/chatgpt/models/deberta/deberta_rm.py
0 → 100644
View file @
9998d5ef
from
typing
import
Optional
import
torch.nn
as
nn
from
transformers
import
DebertaV2Config
,
DebertaV2Model
from
..base
import
RewardModel
class DebertaRM(RewardModel):
    """
    DeBERTa Reward model.

    Wraps a DeBERTa-v2 backbone with a scalar value head and hands both to
    the shared ``RewardModel`` base class, which applies LoRA as configured.

    Args:
        pretrained (str): Pretrained model name or path.
        config (DebertaV2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA (low-rank adaptation) decomposition.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 # Annotation fixed from `str = None` to `Optional[str]`:
                 # the default is None, matching DebertaCritic's signature.
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Backbone selection: pretrained weights take priority over an
        # explicit config; otherwise use the default DeBERTa-v2 config.
        if pretrained is not None:
            model = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            model = DebertaV2Model(config)
        else:
            model = DebertaV2Model(DebertaV2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        # Small-scale init (std = 1/(hidden_size+1)) — presumably so initial
        # reward predictions start near zero; TODO confirm intent upstream.
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
applications/ChatGPT/examples/requirements.txt
View file @
9998d5ef
pandas>=1.4.1
sentencepiece
applications/ChatGPT/examples/test_ci.sh
View file @
9998d5ef
...
...
@@ -88,4 +88,10 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--dataset
'Anthropic/hh-rlhf'
--subset
'harmless-base'
\
--test
True
--lora_rank
4
torchrun
--standalone
--nproc_per_node
=
2
${
BASE
}
/train_reward_model.py
\
--pretrain
'microsoft/deberta-v3-large'
--model
'deberta'
\
--strategy
colossalai_zero2
--loss_fn
'log_sig'
\
--dataset
'Anthropic/hh-rlhf'
--subset
'harmless-base'
\
--test
True
--lora_rank
4
rm
-rf
${
BASE
}
/rm_ckpt.pt
applications/ChatGPT/examples/train_reward_model.py
View file @
9998d5ef
...
...
@@ -8,12 +8,13 @@ from chatgpt.models.base import RewardModel
from
chatgpt.models.bloom
import
BLOOMRM
from
chatgpt.models.gpt
import
GPTRM
from
chatgpt.models.opt
import
OPTRM
from
chatgpt.models.deberta
import
DebertaRM
from
chatgpt.trainer
import
RewardModelTrainer
from
chatgpt.trainer.strategies
import
ColossalAIStrategy
,
DDPStrategy
,
NaiveStrategy
from
datasets
import
load_dataset
from
random
import
randint
from
torch.optim
import
Adam
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
,
DebertaV2Tokenizer
from
transformers.models.gpt2.tokenization_gpt2
import
GPT2Tokenizer
from
colossalai.nn.optimizer
import
HybridAdam
...
...
@@ -39,6 +40,8 @@ def train(args):
model
=
OPTRM
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
).
to
(
torch
.
cuda
.
current_device
())
elif
args
.
model
==
'gpt2'
:
model
=
GPTRM
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
).
to
(
torch
.
cuda
.
current_device
())
elif
args
.
model
==
'deberta'
:
model
=
DebertaRM
(
pretrained
=
args
.
pretrain
,
lora_rank
=
args
.
lora_rank
).
to
(
torch
.
cuda
.
current_device
())
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
...
...
@@ -54,6 +57,8 @@ def train(args):
tokenizer
=
BloomTokenizerFast
.
from_pretrained
(
'bigscience/bloom-560m'
)
elif
args
.
model
==
'opt'
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/opt-350m"
)
elif
args
.
model
==
'deberta'
:
tokenizer
=
DebertaV2Tokenizer
.
from_pretrained
(
'microsoft/deberta-v3-large'
)
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
max_len
=
args
.
max_len
...
...
@@ -119,7 +124,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--strategy'
,
choices
=
[
'naive'
,
'ddp'
,
'colossalai_gemini'
,
'colossalai_zero2'
],
default
=
'naive'
)
parser
.
add_argument
(
'--model'
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
],
default
=
'bloom'
)
parser
.
add_argument
(
'--model'
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
,
'deberta'
],
default
=
'bloom'
)
parser
.
add_argument
(
'--pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--model_path'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--need_optim_ckpt'
,
type
=
bool
,
default
=
False
)
...
...
applications/ChatGPT/examples/train_rm.sh
View file @
9998d5ef
set_n_least_used_CUDA_VISIBLE_DEVICES 1
python train_reward_model.py
--pretrain
'
/home/lczht/data2/bloom-560m
'
\
--model
'
bloom
'
\
python train_reward_model.py
--pretrain
'
microsoft/deberta-v3-large
'
\
--model
'
deberta
'
\
--strategy
naive
\
--loss_fn
'log_exp'
\
--save_path
'rmstatic.pt'
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment