"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "85d8d02d70d9371cf36a4ea004dd5e94c2f4f62e"
Unverified Commit 287d6049 authored by LuGY's avatar LuGY Committed by GitHub
Browse files

[chatgpt] Add saving ckpt callback for PPO (#2880)



* add checkpoint callback for chatgpt

* add save ckpt callbacks for ppo

---------
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
parent e5887034
from .base import Callback from .base import Callback
from .performance_evaluator import PerformanceEvaluator from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
__all__ = ['Callback', 'PerformanceEvaluator'] __all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
import os
import torch.distributed as dist
from chatgpt.trainer.strategies import ColossalAIStrategy, Strategy
from chatgpt.trainer.utils import is_rank_0
from torch import nn
from torch.optim import Optimizer
from .base import Callback
class SaveCheckpoint(Callback):
    """
    The callback for saving checkpoint for chatgpt.

    Only support saving actor and critic model.
    A typical architecture of the saved checkpoint would be:
        - checkpoint
            - episode_x
                - actor.pt
                - actor-optim-rank-0.pt
                - actor-optim-rank-1.pt
                - critic.pt
                - critic-optim-rank-0.pt
                - critic-optim-rank-1.pt
                - ...

    Args:
        path(str): the base path you want to save checkpoint, the checkpoint would be saved at `path/checkpoint`
        interval(int): the interval episode of saving checkpoint
        strategy(Strategy): the strategy used to train
        actor(nn.Module): the actor model
        critic(nn.Module): the critic model
        actor_optim(Optimizer): the optimizer of actor
        critic_optim(Optimizer): the optimizer of critic
    """

    def __init__(self,
                 path: str,
                 interval: int,
                 strategy: Strategy,
                 actor: nn.Module = None,
                 critic: nn.Module = None,
                 actor_optim: Optimizer = None,
                 critic_optim: Optimizer = None) -> None:
        super().__init__()
        self.path = os.path.join(path, 'checkpoint')
        self.interval = interval
        self.strategy = strategy
        # Maps a model name to its (module, optimizer) pair; either entry may be None.
        self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}

    def on_episode_end(self, episode: int) -> None:
        """Save model weights and optimizer states every `self.interval` episodes."""
        if (episode + 1) % self.interval != 0:
            return
        base_path = os.path.join(self.path, f'episode_{episode}')
        # exist_ok=True avoids a check-then-act race (FileExistsError) when
        # several distributed ranks create the directory concurrently.
        os.makedirs(base_path, exist_ok=True)

        for name, (model, optim) in self.model_dict.items():
            # save model
            if model is None:
                # saving only optimizer states is meaningless, so it would be skipped
                continue
            model_path = os.path.join(base_path, f'{name}.pt')
            self.strategy.save_model(model=model, path=model_path, only_rank0=True)

            # save optimizer
            if optim is None:
                continue
            # NOTE(review): for ColossalAIStrategy the optimizer state appears to be
            # sharded per rank, so every rank writes its own file; other strategies
            # save from rank 0 only.
            only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
            rank = 0 if is_rank_0() else dist.get_rank()
            optim_path = os.path.join(base_path, f'{name}-optim-rank-{rank}.pt')
            self.strategy.save_optimizer(optimizer=optim, path=optim_path, only_rank0=only_rank0)
...@@ -4,6 +4,7 @@ from copy import deepcopy ...@@ -4,6 +4,7 @@ from copy import deepcopy
import torch import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import SaveCheckpoint
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast from transformers import AutoTokenizer, BloomTokenizerFast
...@@ -71,26 +72,38 @@ def main(args): ...@@ -71,26 +72,38 @@ def main(args):
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare( (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model) (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
callbacks = []
if args.save_ckpt_path:
ckpt_callback = SaveCheckpoint(
args.save_ckpt_path,
args.save_ckpt_interval,
strategy,
actor,
critic,
actor_optim,
critic_optim,
)
callbacks.append(ckpt_callback)
# configure trainer # configure trainer
trainer = PPOTrainer(
strategy, trainer = PPOTrainer(strategy,
actor, actor,
critic, critic,
reward_model, reward_model,
initial_model, initial_model,
actor_optim, actor_optim,
critic_optim, critic_optim,
max_epochs=args.max_epochs, max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size, train_batch_size=args.train_batch_size,
experience_batch_size=args.experience_batch_size, tokenizer=preprocess_batch,
tokenizer=preprocess_batch, max_length=128,
max_length=128, do_sample=True,
do_sample=True, temperature=1.0,
temperature=1.0, top_k=50,
top_k=50, pad_token_id=tokenizer.pad_token_id,
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id, callbacks=callbacks)
)
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device()) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
trainer.fit(random_prompts, trainer.fit(random_prompts,
...@@ -120,5 +133,10 @@ if __name__ == '__main__': ...@@ -120,5 +133,10 @@ if __name__ == '__main__':
parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--save_ckpt_path',
type=str,
default=None,
help="path to save checkpoint, None means not to save")
parser.add_argument('--save_ckpt_interval', type=int, default=1, help="the interval of episode to save checkpoint")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment