Unverified Commit c474fda2 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[chatgpt] fix ppo training hanging problem with gemini (#3162)

* [chatgpt] fix generation early stopping

* [chatgpt] fix train prompts example
parent 6ae8ed04
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
try: try:
...@@ -27,6 +28,14 @@ def prepare_logits_processor(top_k: Optional[int] = None, ...@@ -27,6 +28,14 @@ def prepare_logits_processor(top_k: Optional[int] = None,
return processor_list return processor_list
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
if dist.is_initialized() and dist.get_world_size() > 1:
# consider DP
unfinished_sequences = unfinished_sequences.clone()
dist.all_reduce(unfinished_sequences)
return unfinished_sequences.max() == 0
def sample(model: nn.Module, def sample(model: nn.Module,
input_ids: torch.Tensor, input_ids: torch.Tensor,
max_length: int, max_length: int,
...@@ -74,7 +83,7 @@ def sample(model: nn.Module, ...@@ -74,7 +83,7 @@ def sample(model: nn.Module,
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
# stop when each sentence is finished if early_stopping=True # stop when each sentence is finished if early_stopping=True
if early_stopping and unfinished_sequences.max() == 0: if early_stopping and _is_sequence_finished(unfinished_sequences):
break break
return input_ids return input_ids
......
...@@ -46,7 +46,6 @@ def main(args): ...@@ -46,7 +46,6 @@ def main(args):
initial_model = deepcopy(actor) initial_model = deepcopy(actor)
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device()) reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
# configure optimizer # configure optimizer
if args.strategy.startswith('colossalai'): if args.strategy.startswith('colossalai'):
actor_optim = HybridAdam(actor.parameters(), lr=5e-6) actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
...@@ -70,7 +69,9 @@ def main(args): ...@@ -70,7 +69,9 @@ def main(args):
dataset = pd.read_csv(args.prompt_path)['prompt'] dataset = pd.read_csv(args.prompt_path)['prompt']
def tokenize_fn(texts): def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True) # MUST padding to max length to ensure inputs of all ranks have the same length
# Different length may lead to hang when using gemini, as different generation steps
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
return {k: v.cuda() for k, v in batch.items()} return {k: v.cuda() for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare( (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
...@@ -101,7 +102,7 @@ def main(args): ...@@ -101,7 +102,7 @@ def main(args):
num_episodes=args.num_episodes, num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps, max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps) update_timesteps=args.update_timesteps)
# save model checkpoint after fitting # save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True) strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
if args.need_optim_ckpt: if args.need_optim_ckpt:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment