Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f5ca0397
Unverified
Commit
f5ca0397
authored
Mar 03, 2023
by
BlueRum
Committed by
GitHub
Mar 03, 2023
Browse files
[chatgpt] fix lora gemini conflict in RM training (#2984)
* fix lora bug * polish * fix lora gemini
parent
19ad49fb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
10 deletions
+3
-10
applications/ChatGPT/chatgpt/nn/reward_model.py
applications/ChatGPT/chatgpt/nn/reward_model.py
+2
-2
applications/ChatGPT/chatgpt/trainer/rm.py
applications/ChatGPT/chatgpt/trainer/rm.py
+1
-6
applications/ChatGPT/examples/train_reward_model.py
applications/ChatGPT/examples/train_reward_model.py
+0
-2
No files found.
applications/ChatGPT/chatgpt/nn/reward_model.py
View file @
f5ca0397
...
@@ -24,14 +24,14 @@ class RewardModel(LoRAModule):
...
@@ -24,14 +24,14 @@ class RewardModel(LoRAModule):
lora_train_bias
:
str
=
'none'
)
->
None
:
lora_train_bias
:
str
=
'none'
)
->
None
:
super
().
__init__
(
lora_rank
=
lora_rank
,
lora_train_bias
=
lora_train_bias
)
super
().
__init__
(
lora_rank
=
lora_rank
,
lora_train_bias
=
lora_train_bias
)
self
.
model
=
model
self
.
model
=
model
self
.
convert_to_lora
()
if
value_head
is
not
None
:
if
value_head
is
not
None
:
if
value_head
.
out_features
!=
1
:
if
value_head
.
out_features
!=
1
:
raise
ValueError
(
"The value head of reward model's output dim should be 1!"
)
raise
ValueError
(
"The value head of reward model's output dim should be 1!"
)
self
.
value_head
=
value_head
self
.
value_head
=
value_head
else
:
else
:
self
.
value_head
=
nn
.
Linear
(
model
.
config
.
n_embd
,
1
)
self
.
value_head
=
nn
.
Linear
(
model
.
config
.
n_embd
,
1
)
self
.
convert_to_lora
()
def
forward
(
self
,
sequences
:
torch
.
LongTensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
def
forward
(
self
,
sequences
:
torch
.
LongTensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
outputs
=
self
.
model
(
sequences
,
attention_mask
=
attention_mask
)
outputs
=
self
.
model
(
sequences
,
attention_mask
=
attention_mask
)
...
...
applications/ChatGPT/chatgpt/trainer/rm.py
View file @
f5ca0397
...
@@ -56,12 +56,7 @@ class RewardModelTrainer(ABC):
...
@@ -56,12 +56,7 @@ class RewardModelTrainer(ABC):
desc
=
'Train step of epoch %d'
%
epoch
,
desc
=
'Train step of epoch %d'
%
epoch
,
disable
=
not
is_rank_0
())
disable
=
not
is_rank_0
())
# train
# train
if
use_lora
>
0
:
self
.
model
.
train
()
print
(
"Using Lora"
)
lora
.
mark_only_lora_as_trainable
(
self
.
model
.
model
)
else
:
self
.
model
.
train
()
for
chosen_ids
,
c_mask
,
reject_ids
,
r_mask
in
self
.
train_dataloader
:
for
chosen_ids
,
c_mask
,
reject_ids
,
r_mask
in
self
.
train_dataloader
:
chosen_ids
=
chosen_ids
.
squeeze
(
1
).
cuda
()
chosen_ids
=
chosen_ids
.
squeeze
(
1
).
cuda
()
c_mask
=
c_mask
.
squeeze
(
1
).
cuda
()
c_mask
=
c_mask
.
squeeze
(
1
).
cuda
()
...
...
applications/ChatGPT/examples/train_reward_model.py
View file @
f5ca0397
...
@@ -66,8 +66,6 @@ def train(args):
...
@@ -66,8 +66,6 @@ def train(args):
train_dataset
=
RewardDataset
(
train_data
,
tokenizer
,
max_len
)
train_dataset
=
RewardDataset
(
train_data
,
tokenizer
,
max_len
)
eval_dataset
=
RewardDataset
(
eval_data
,
tokenizer
,
max_len
)
eval_dataset
=
RewardDataset
(
eval_data
,
tokenizer
,
max_len
)
# batch_size here is expected to be C(k,2), k means # response of each prompt
# be limited with the format of dataset 'Dahoas/rm-static', we'd better use batch_size as 1
trainer
=
RewardModelTrainer
(
model
=
model
,
trainer
=
RewardModelTrainer
(
model
=
model
,
strategy
=
strategy
,
strategy
=
strategy
,
optim
=
optim
,
optim
=
optim
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment