"vscode:/vscode.git/clone" did not exist on "660eed912495eb0f9473ba53dd191e4b44e1d31f"
Unverified Commit b9231390 authored by Yuanchen's avatar Yuanchen Committed by GitHub
Browse files

fix save_model indent error in ppo trainer (#3450)


Co-authored-by: default avatarYuanchen Xu <yuanchen.xu00@gmail.com>
parent ffcdbf0f
...@@ -117,6 +117,9 @@ class PPOTrainer(Trainer): ...@@ -117,6 +117,9 @@ class PPOTrainer(Trainer):
return {'reward': experience.reward.mean().item()} return {'reward': experience.reward.mean().item()}
def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy._unwrap_actor(actor) origin_model = strategy._unwrap_actor(actor)
...@@ -129,7 +132,3 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto ...@@ -129,7 +132,3 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
return new_kwargs return new_kwargs
def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment