OpenDAS / ColossalAI / Commits

Commit c9e27f0d (unverified)
Authored Mar 02, 2023 by BlueRum; committed via GitHub on Mar 02, 2023
[chatgpt]fix lora bug (#2974)

* fix lora bug
* polish

Parent: 82149e9d
Showing 3 changed files, with 8 additions and 6 deletions (+8 -6):

* applications/ChatGPT/chatgpt/nn/reward_model.py (+2 -2)
* applications/ChatGPT/chatgpt/trainer/rm.py (+3 -1)
* applications/ChatGPT/examples/train_reward_model.py (+3 -3)
applications/ChatGPT/chatgpt/nn/reward_model.py (view file @ c9e27f0d)

@@ -23,7 +23,7 @@ class RewardModel(LoRAModule):
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
-        self.body = model
+        self.model = model
         if value_head is not None:
             if value_head.out_features != 1:
                 raise ValueError("The value head of reward model's output dim should be 1!")

@@ -34,7 +34,7 @@ class RewardModel(LoRAModule):
         self.convert_to_lora()

     def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        outputs = self.body(sequences, attention_mask=attention_mask)
+        outputs = self.model(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs['last_hidden_state']
         values = self.value_head(last_hidden_states)[:, :-1]
         value = values.mean(dim=1).squeeze(1)    # ensure shape is (B)
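The substance of the fix in this file is the rename of the backbone attribute from self.body to self.model, so that it matches the self.model.model reference the trainer uses when freezing non-LoRA weights (see rm.py below). A minimal sketch of the pattern, assuming loralib is installed; TinyRewardModel and its layer sizes are illustrative stand-ins, not the project's code:

import torch.nn as nn
import loralib as lora

class TinyRewardModel(nn.Module):
    """Illustrative stand-in: a backbone plus a scalar value head."""

    def __init__(self, backbone: nn.Module) -> None:
        super().__init__()
        self.model = backbone              # was `self.body` before this commit
        self.value_head = nn.Linear(8, 1)  # out_features must be 1

rm = TinyRewardModel(nn.Sequential(lora.Linear(8, 8, r=4)))
lora.mark_only_lora_as_trainable(rm.model)  # same call shape as in rm.py below
# Only the LoRA factors remain trainable after the call:
print([n for n, p in rm.model.named_parameters() if p.requires_grad])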
applications/ChatGPT/chatgpt/trainer/rm.py (view file @ c9e27f0d)

@@ -44,6 +44,8 @@ class RewardModelTrainer(ABC):
         self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)

         self.model = strategy.setup_model(model)
+        if "DDP" in str(self.strategy):
+            self.model = self.model.module
         self.loss_fn = PairWiseLoss()
         self.optimizer = strategy.setup_optimizer(optim, self.model)

@@ -56,7 +58,7 @@ class RewardModelTrainer(ABC):
             # train
             if use_lora > 0:
                 print("Using Lora")
-                lora.mark_only_lora_as_trainable(self.model.body)
+                lora.mark_only_lora_as_trainable(self.model.model)
             else:
                 self.model.train()
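The two added lines unwrap the model when the strategy is DDP-based: torch.nn.parallel.DistributedDataParallel stores the wrapped module under .module, and it proxies forward() but not arbitrary attributes, so a lookup like self.model.model only works on the unwrapped object. A minimal sketch of the same idea, using an isinstance check rather than the commit's string match on the strategy:

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap_ddp(model: nn.Module) -> nn.Module:
    # DDP forwards calls to forward(), but submodule attributes such as
    # `.model` must be reached through `.module` on the wrapper.
    return model.module if isinstance(model, DDP) else model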
applications/ChatGPT/examples/train_reward_model.py (view file @ c9e27f0d)

@@ -61,8 +61,8 @@ def train(args):
     # prepare for data and dataset
     data = load_dataset(args.dataset)
-    train_data = data["train"].select(range(100))
-    eval_data = data['test'].select(range(5))
+    train_data = data["train"]
+    eval_data = data['test']
     train_dataset = RewardDataset(train_data, tokenizer, max_len)
     eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

@@ -93,7 +93,7 @@ if __name__ == '__main__':
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
     parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
-    parser.add_argument('--max_epochs', type=int, default=10)
+    parser.add_argument('--max_epochs', type=int, default=1)
     parser.add_argument('--batch_size', type=int, default=4)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     args = parser.parse_args()
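The removed .select(range(100)) and .select(range(5)) calls look like leftover debugging caps, meaning training was silently limited to 100 train and 5 eval samples; dropping them uses the full Dahoas/rm-static splits, and the --max_epochs default is reduced from 10 to 1 accordingly. If a quick smoke test is still wanted, one option is to gate the cap behind a flag; the subset parameter below is hypothetical, not an option the script provides:

from datasets import load_dataset

def load_splits(name: str = 'Dahoas/rm-static', subset: int = 0):
    data = load_dataset(name)
    train_data, eval_data = data['train'], data['test']
    if subset > 0:  # debug-only cap (hypothetical --subset flag), off by default
        train_data = train_data.select(range(min(subset, len(train_data))))
        eval_data = eval_data.select(range(min(subset, len(eval_data))))
    return train_data, eval_data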