Unverified Commit 1e58d31b authored by ver217's avatar ver217 Committed by GitHub
Browse files

[chatgpt] fix trainer generate kwargs (#3166)

parent c474fda2
......@@ -63,6 +63,7 @@ class PPOTrainer(Trainer):
**generate_kwargs) -> None:
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
self.actor = actor
......@@ -73,7 +74,6 @@ class PPOTrainer(Trainer):
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self._set_default_generate_kwargs(generate_kwargs, actor)
def training_step(self, experience: Experience) -> Dict[str, float]:
self.actor.train()
......@@ -102,11 +102,15 @@ class PPOTrainer(Trainer):
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
origin_model = self.strategy._unwrap_actor(actor)
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs:
generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy._unwrap_actor(actor)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs:
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
return new_kwargs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment