Unverified Commit 9998d5ef authored by Yuanchen, committed by GitHub

[chatgpt] add reward model code for deberta (#3199)


Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
parent 1e1b9d2f
chatgpt/models/deberta/__init__.py

from .deberta_critic import DebertaCritic
from .deberta_rm import DebertaRM

__all__ = ['DebertaCritic', 'DebertaRM']
chatgpt/models/deberta/deberta_critic.py

from typing import Optional

import torch.nn as nn
from transformers import DebertaV2Config, DebertaV2Model

from ..base import Critic


class DebertaCritic(Critic):
    """
    DeBERTa Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (DebertaV2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA decomposition.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            model = DebertaV2Model(config)
        else:
            model = DebertaV2Model(DebertaV2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias)
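For context, a minimal usage sketch (hypothetical, not part of this commit): the critic can be built either from a pretrained checkpoint or from a bare config, assuming the package layout shown in the imports above.

# Hypothetical usage sketch, not part of the diff.
from transformers import DebertaV2Config
from chatgpt.models.deberta import DebertaCritic

# From a pretrained checkpoint, with a rank-4 LoRA decomposition:
critic = DebertaCritic(pretrained='microsoft/deberta-v3-large', lora_rank=4)

# Or from a randomly initialized config, e.g. for smoke tests:
critic = DebertaCritic(config=DebertaV2Config())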
chatgpt/models/deberta/deberta_rm.py

from typing import Optional

import torch.nn as nn
from transformers import DebertaV2Config, DebertaV2Model

from ..base import RewardModel


class DebertaRM(RewardModel):
    """
    DeBERTa Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (DebertaV2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA decomposition.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            model = DebertaV2Model(config)
        else:
            model = DebertaV2Model(DebertaV2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        # Small-variance init keeps initial reward predictions near zero.
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
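Similarly, a hedged scoring sketch for the reward model. The forward signature (token ids plus attention mask in, one scalar reward per sequence out) is an assumption about the RewardModel base class, not something this diff shows.

# Hypothetical scoring sketch, not part of the diff.
import torch
from transformers import DebertaV2Tokenizer
from chatgpt.models.deberta import DebertaRM

tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
model = DebertaRM(pretrained='microsoft/deberta-v3-large').eval()

# Score one prompt/response pair; a higher reward means more preferred.
inputs = tokenizer('Human: Hi\n\nAssistant: Hello!', return_tensors='pt')
with torch.no_grad():
    # Assumed base-class signature: forward(sequences, attention_mask).
    reward = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
print(reward.item())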
@@ -88,4 +88,10 @@
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
    --test True --lora_rank 4

torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
    --strategy colossalai_zero2 --loss_fn 'log_sig' \
    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
    --test True --lora_rank 4

rm -rf ${BASE}/rm_ckpt.pt
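The new test command selects --loss_fn 'log_sig'. As a reference, a minimal sketch of what a pairwise log-sigmoid ranking loss typically computes (the repo's own loss class may differ in details):

import torch
import torch.nn.functional as F

def log_sig_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    # -log(sigmoid(r_chosen - r_rejected)): pushes the reward of the
    # human-preferred response above that of the rejected one.
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()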
@@ -8,12 +8,13 @@ from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMRM
from chatgpt.models.gpt import GPTRM
from chatgpt.models.opt import OPTRM
from chatgpt.models.deberta import DebertaRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from random import randint
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
@@ -39,6 +40,8 @@ def train(args):
        model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
    elif args.model == 'gpt2':
        model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
    elif args.model == 'deberta':
        model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
@@ -54,6 +57,8 @@ def train(args):
        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif args.model == 'deberta':
        tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    max_len = args.max_len
@@ -119,7 +124,7 @@ if __name__ == '__main__':
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
...
set_n_least_used_CUDA_VISIBLE_DEVICES 1
python train_reward_model.py --pretrain 'microsoft/deberta-v3-large' \
    --model 'deberta' \
    --strategy naive \
    --loss_fn 'log_exp' \
    --save_path 'rmstatic.pt' \
...