Unverified Commit 9c0943ec authored by ver217's avatar ver217 Committed by GitHub
Browse files

[chatgpt] optimize generation kwargs (#2717)

* [chatgpt] ppo trainer use default generate args

* [chatgpt] example remove generation preparing fn

* [chatgpt] benchmark remove generation preparing fn

* [chatgpt] fix ci
parent 21d6a48f
...@@ -34,6 +34,7 @@ jobs: ...@@ -34,6 +34,7 @@ jobs:
- name: Execute Examples - name: Execute Examples
run: | run: |
cd applications/ChatGPT
./examples/test_ci.sh ./examples/test_ci.sh
env: env:
NCCL_SHM_DISABLE: 1 NCCL_SHM_DISABLE: 1
......
...@@ -35,6 +35,7 @@ jobs: ...@@ -35,6 +35,7 @@ jobs:
- name: Execute Unit Testing - name: Execute Unit Testing
run: | run: |
cd applications/ChatGPT
pytest tests/ pytest tests/
env: env:
NCCL_SHM_DISABLE: 1 NCCL_SHM_DISABLE: 1
......
...@@ -5,7 +5,6 @@ import torch ...@@ -5,7 +5,6 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from chatgpt.nn import GPTActor, GPTCritic, RewardModel from chatgpt.nn import GPTActor, GPTCritic, RewardModel
from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
from chatgpt.trainer import PPOTrainer from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
...@@ -151,8 +150,6 @@ def main(args): ...@@ -151,8 +150,6 @@ def main(args):
top_k=50, top_k=50,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
prepare_inputs_fn=gpt_prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
callbacks=[performance_evaluator]) callbacks=[performance_evaluator])
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device()) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
......
...@@ -5,7 +5,6 @@ import torch ...@@ -5,7 +5,6 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from chatgpt.nn import OPTActor, OPTCritic, RewardModel from chatgpt.nn import OPTActor, OPTCritic, RewardModel
from chatgpt.nn.generation_utils import opt_prepare_inputs_fn, update_model_kwargs_fn
from chatgpt.trainer import PPOTrainer from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
...@@ -144,8 +143,6 @@ def main(args): ...@@ -144,8 +143,6 @@ def main(args):
top_k=50, top_k=50,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
prepare_inputs_fn=opt_prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
callbacks=[performance_evaluator]) callbacks=[performance_evaluator])
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device()) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
......
...@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional ...@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
import torch.nn as nn import torch.nn as nn
from chatgpt.experience_maker import Experience, NaiveExperienceMaker from chatgpt.experience_maker import Experience, NaiveExperienceMaker
from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
from chatgpt.nn.generation_utils import update_model_kwargs_fn
from chatgpt.replay_buffer import NaiveReplayBuffer from chatgpt.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer from torch.optim import Optimizer
...@@ -59,6 +60,7 @@ class PPOTrainer(Trainer): ...@@ -59,6 +60,7 @@ class PPOTrainer(Trainer):
dataloader_pin_memory: bool = True, dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [], callbacks: List[Callback] = [],
**generate_kwargs) -> None: **generate_kwargs) -> None:
self._set_default_generate_kwargs(generate_kwargs, actor)
actor = Actor(strategy.setup_model(actor.model)) actor = Actor(strategy.setup_model(actor.model))
critic = strategy.setup_model(critic) critic = strategy.setup_model(critic)
reward_model = strategy.setup_model(reward_model) reward_model = strategy.setup_model(reward_model)
...@@ -102,3 +104,11 @@ class PPOTrainer(Trainer): ...@@ -102,3 +104,11 @@ class PPOTrainer(Trainer):
self.critic_optim.zero_grad() self.critic_optim.zero_grad()
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
    """Fill in missing generation hooks in *generate_kwargs* (mutated in place).

    When the caller did not supply ``prepare_inputs_fn``, fall back to the
    wrapped HuggingFace model's own ``prepare_inputs_for_generation`` method
    (if the model exposes one); when ``update_model_kwargs_fn`` is missing,
    fall back to the package-level ``update_model_kwargs_fn`` helper.
    """
    # Prefer the HuggingFace model's native hook when the model provides it.
    hf_prepare = getattr(actor.model, 'prepare_inputs_for_generation', None)
    if hf_prepare is not None and 'prepare_inputs_fn' not in generate_kwargs:
        generate_kwargs['prepare_inputs_fn'] = hf_prepare
    # Default the kwargs-update hook to the shared helper.
    generate_kwargs.setdefault('update_model_kwargs_fn', update_model_kwargs_fn)
...@@ -3,12 +3,6 @@ from copy import deepcopy ...@@ -3,12 +3,6 @@ from copy import deepcopy
import torch import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.nn.generation_utils import (
bloom_prepare_inputs_fn,
gpt_prepare_inputs_fn,
opt_prepare_inputs_fn,
update_model_kwargs_fn,
)
from chatgpt.trainer import PPOTrainer from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam from torch.optim import Adam
...@@ -66,36 +60,33 @@ def main(args): ...@@ -66,36 +60,33 @@ def main(args):
if args.model == 'gpt2': if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
prepare_inputs_fn = gpt_prepare_inputs_fn
elif args.model == 'bloom': elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
prepare_inputs_fn = bloom_prepare_inputs_fn
elif args.model == 'opt': elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
prepare_inputs_fn = opt_prepare_inputs_fn
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
# configure trainer # configure trainer
trainer = PPOTrainer(strategy, trainer = PPOTrainer(
actor, strategy,
critic, actor,
reward_model, critic,
initial_model, reward_model,
actor_optim, initial_model,
critic_optim, actor_optim,
max_epochs=args.max_epochs, critic_optim,
train_batch_size=args.train_batch_size, max_epochs=args.max_epochs,
tokenizer=preprocess_batch, train_batch_size=args.train_batch_size,
max_length=128, tokenizer=preprocess_batch,
do_sample=True, max_length=128,
temperature=1.0, do_sample=True,
top_k=50, temperature=1.0,
pad_token_id=tokenizer.pad_token_id, top_k=50,
eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
prepare_inputs_fn=prepare_inputs_fn, eos_token_id=tokenizer.eos_token_id,
update_model_kwargs_fn=update_model_kwargs_fn) )
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device()) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
trainer.fit(random_prompts, trainer.fit(random_prompts,
......
...@@ -3,7 +3,6 @@ from copy import deepcopy ...@@ -3,7 +3,6 @@ from copy import deepcopy
import pandas as pd import pandas as pd
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
from chatgpt.trainer import PPOTrainer from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam from torch.optim import Adam
...@@ -70,24 +69,24 @@ def main(args): ...@@ -70,24 +69,24 @@ def main(args):
return {k: v.cuda() for k, v in batch.items()} return {k: v.cuda() for k, v in batch.items()}
# configure trainer # configure trainer
trainer = PPOTrainer(strategy, trainer = PPOTrainer(
actor, strategy,
critic, actor,
reward_model, critic,
initial_model, reward_model,
actor_optim, initial_model,
critic_optim, actor_optim,
max_epochs=args.max_epochs, critic_optim,
train_batch_size=args.train_batch_size, max_epochs=args.max_epochs,
tokenizer=tokenize_fn, train_batch_size=args.train_batch_size,
max_length=128, tokenizer=tokenize_fn,
do_sample=True, max_length=128,
temperature=1.0, do_sample=True,
top_k=50, temperature=1.0,
pad_token_id=tokenizer.pad_token_id, top_k=50,
eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
prepare_inputs_fn=gpt_prepare_inputs_fn, eos_token_id=tokenizer.eos_token_id,
update_model_kwargs_fn=update_model_kwargs_fn) )
trainer.fit(dataset, trainer.fit(dataset,
num_episodes=args.num_episodes, num_episodes=args.num_episodes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment