Unverified Commit 34ca324b authored by BlueRum's avatar BlueRum Committed by GitHub
Browse files

[chatgpt] Support saving ckpt in examples (#2846)

* [chatgpt]fix train_rm bug with lora

* [chatgpt]support colossalai strategy to train rm

* fix pre-commit

* fix pre-commit 2

* [chatgpt]fix rm eval typo

* fix rm eval

* fix pre commit

* add support of saving ckpt in examples

* fix single-gpu save
parent 59791431
...@@ -97,6 +97,13 @@ def main(args): ...@@ -97,6 +97,13 @@ def main(args):
max_timesteps=args.max_timesteps, max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps) update_timesteps=args.update_timesteps)
# save model checkpoint after fitting on only rank0
strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
# save optimizer checkpoint on all ranks
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
...@@ -2,6 +2,7 @@ import argparse ...@@ -2,6 +2,7 @@ import argparse
from copy import deepcopy from copy import deepcopy
import pandas as pd import pandas as pd
import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
...@@ -95,6 +96,12 @@ def main(args): ...@@ -95,6 +96,12 @@ def main(args):
num_episodes=args.num_episodes, num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps, max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps) update_timesteps=args.update_timesteps)
# save model checkpoint after fitting on only rank0
strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
# save optimizer checkpoint on all ranks
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment