Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1e58d31b
Unverified
Commit
1e58d31b
authored
Mar 17, 2023
by
ver217
Committed by
GitHub
Mar 17, 2023
Browse files
[chatgpt] fix trainer generate kwargs (#3166)
parent
c474fda2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
8 deletions
+12
-8
applications/ChatGPT/chatgpt/trainer/ppo.py
applications/ChatGPT/chatgpt/trainer/ppo.py
+12
-8
No files found.
applications/ChatGPT/chatgpt/trainer/ppo.py
View file @
1e58d31b
...
...
@@ -63,6 +63,7 @@ class PPOTrainer(Trainer):
**
generate_kwargs
)
->
None
:
experience_maker
=
NaiveExperienceMaker
(
actor
,
critic
,
reward_model
,
initial_model
,
kl_coef
)
replay_buffer
=
NaiveReplayBuffer
(
train_batch_size
,
buffer_limit
,
buffer_cpu_offload
)
generate_kwargs
=
_set_default_generate_kwargs
(
strategy
,
generate_kwargs
,
actor
)
super
().
__init__
(
strategy
,
experience_maker
,
replay_buffer
,
experience_batch_size
,
max_epochs
,
tokenizer
,
sample_replay_buffer
,
dataloader_pin_memory
,
callbacks
,
**
generate_kwargs
)
self
.
actor
=
actor
...
...
@@ -73,7 +74,6 @@ class PPOTrainer(Trainer):
self
.
actor_optim
=
actor_optim
self
.
critic_optim
=
critic_optim
self
.
_set_default_generate_kwargs
(
generate_kwargs
,
actor
)
def
training_step
(
self
,
experience
:
Experience
)
->
Dict
[
str
,
float
]:
self
.
actor
.
train
()
...
...
@@ -102,11 +102,15 @@ class PPOTrainer(Trainer):
return
{
'actor_loss'
:
actor_loss
.
item
(),
'critic_loss'
:
critic_loss
.
item
()}
def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
    """Fill missing generation hooks into *generate_kwargs* **in place**.

    Unwraps the (possibly strategy-wrapped) actor and, when the caller did
    not supply them, installs two defaults:

    * ``prepare_inputs_fn`` — taken from the underlying HuggingFace model's
      ``prepare_inputs_for_generation`` method, if the model has one.
    * ``update_model_kwargs_fn`` — the module-level ``update_model_kwargs_fn``
      helper (defined elsewhere in this file).

    Args:
        generate_kwargs: keyword arguments forwarded to text generation;
            mutated in place, existing keys are never overwritten.
        actor: the actor model, possibly wrapped by ``self.strategy``.

    NOTE(review): this mutates the caller's dict and returns ``None``; the
    sibling module-level function of the same name instead copies and
    returns a new dict — presumably this method is the superseded variant.
    """
    # Unwrap distributed/strategy wrappers to reach the raw HF model.
    origin_model = self.strategy._unwrap_actor(actor)
    # use huggingface models method directly
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
        generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation

    if 'update_model_kwargs_fn' not in generate_kwargs:
        generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> dict:
    """Return a copy of *generate_kwargs* with default generation hooks filled in.

    Unwraps the (possibly strategy-wrapped) actor and, for any hook the
    caller did not supply, installs a default:

    * ``prepare_inputs_fn`` — taken from the underlying HuggingFace model's
      ``prepare_inputs_for_generation`` method, if the model has one.
    * ``update_model_kwargs_fn`` — the module-level ``update_model_kwargs_fn``
      helper (defined elsewhere in this file).

    Args:
        strategy: training strategy used to unwrap the actor wrapper.
        generate_kwargs: caller-supplied generation keyword arguments;
            NOT mutated — a shallow copy is returned instead.
        actor: the actor model, possibly wrapped by *strategy*.

    Returns:
        A new dict containing the caller's kwargs plus any injected defaults.
    """
    # Unwrap distributed/strategy wrappers to reach the raw HF model.
    origin_model = strategy._unwrap_actor(actor)
    # Shallow copy so the caller's dict is left untouched.
    new_kwargs = {**generate_kwargs}
    # use huggingface models method directly
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation

    if 'update_model_kwargs_fn' not in generate_kwargs:
        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

    return new_kwargs
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment