Unverified Commit 73afb635 authored by Dr-Corgi's avatar Dr-Corgi Committed by GitHub
Browse files

[chat]fix save_model(#3377)

The function save_model should be a part of PPOTrainer.
parent 57a3c4db
...@@ -116,6 +116,9 @@ class PPOTrainer(Trainer): ...@@ -116,6 +116,9 @@ class PPOTrainer(Trainer):
self.critic_optim.zero_grad() self.critic_optim.zero_grad()
return {'reward': experience.reward.mean().item()} return {'reward': experience.reward.mean().item()}
def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment