Unverified commit 2e16f842, authored by BlueRum and committed by GitHub
Browse files

[chatgpt]support opt & gpt for rm training (#2876)

parent c52edcf0
from typing import Optional from typing import Optional
import torch
import torch.nn as nn import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, BloomModel from transformers import BloomConfig, BloomForCausalLM, BloomModel
......
...@@ -15,12 +15,16 @@ class GPTRM(RewardModel): ...@@ -15,12 +15,16 @@ class GPTRM(RewardModel):
pretrained (str): Pretrained model name or path. pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config. config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing. checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None, config: Optional[GPT2Config] = None,
checkpoint: bool = False) -> None: checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None: if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained) model = GPT2Model.from_pretrained(pretrained)
elif config is not None: elif config is not None:
...@@ -29,5 +33,6 @@ class GPTRM(RewardModel): ...@@ -29,5 +33,6 @@ class GPTRM(RewardModel):
model = GPT2Model(GPT2Config()) model = GPT2Model(GPT2Config())
if checkpoint: if checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1) value_head = nn.Linear(model.config.n_embd, 1)
super().__init__(model, value_head) super().__init__(model, value_head, lora_rank, lora_train_bias)
from typing import Optional from typing import Optional
import torch.nn as nn import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig from transformers import OPTConfig, OPTModel
from transformers.models.opt.modeling_opt import OPTModel
from .reward_model import RewardModel from .reward_model import RewardModel
...@@ -14,6 +13,7 @@ class OPTRM(RewardModel): ...@@ -14,6 +13,7 @@ class OPTRM(RewardModel):
Args: Args:
pretrained (str): Pretrained model name or path. pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config. config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation. lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
...@@ -21,6 +21,7 @@ class OPTRM(RewardModel): ...@@ -21,6 +21,7 @@ class OPTRM(RewardModel):
def __init__(self, def __init__(self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None, config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none') -> None: lora_train_bias: str = 'none') -> None:
if pretrained is not None: if pretrained is not None:
...@@ -29,5 +30,8 @@ class OPTRM(RewardModel): ...@@ -29,5 +30,8 @@ class OPTRM(RewardModel):
model = OPTModel(config) model = OPTModel(config)
else: else:
model = OPTModel(OPTConfig()) model = OPTModel(OPTConfig())
value_head = nn.Linear(model.config.hidden_size, 1) if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias) super().__init__(model, value_head, lora_rank, lora_train_bias)
...@@ -3,12 +3,13 @@ import argparse ...@@ -3,12 +3,13 @@ import argparse
import loralib as lora import loralib as lora
import torch import torch
from chatgpt.dataset import RewardDataset from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
from chatgpt.trainer import RewardModelTrainer from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset from datasets import load_dataset
from torch.optim import Adam from torch.optim import Adam
from transformers import BloomTokenizerFast from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
...@@ -27,11 +28,30 @@ def train(args): ...@@ -27,11 +28,30 @@ def train(args):
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model # configure model
with strategy.model_init_context():
if args.model == 'bloom':
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'opt':
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'gpt2':
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
with strategy.model_init_context(): elif args.model == 'opt':
model = BLOOMRM(pretrained=args.pretrain).cuda() tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
max_len = 1024 else:
raise ValueError(f'Unsupported model "{args.model}"')
tokenizer.pad_token = tokenizer.eos_token
max_len = 512
# configure optimizer # configure optimizer
if args.strategy.startswith('colossalai'): if args.strategy.startswith('colossalai'):
...@@ -58,10 +78,10 @@ def train(args): ...@@ -58,10 +78,10 @@ def train(args):
trainer.fit(use_lora=args.lora_rank) trainer.fit(use_lora=args.lora_rank)
if args.lora_rank > 0: # save model checkpoint after fitting on only rank0
torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path) strategy.save_model(model, 'rm_checkpoint.pt', only_rank0=True)
else: # save optimizer checkpoint on all ranks
torch.save(trainer.model, args.save_path) strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
if __name__ == '__main__': if __name__ == '__main__':
...@@ -69,6 +89,7 @@ if __name__ == '__main__': ...@@ -69,6 +89,7 @@ if __name__ == '__main__':
parser.add_argument('--strategy', parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive') default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static') parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth') parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
......
...@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { ...@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 2 set_n_least_used_CUDA_VISIBLE_DEVICES 2
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --strategy colossalai_zero2 # torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_reward_model.py --model 'gpt2' --strategy colossalai_zero2
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment