import argparse
from random import randint

import loralib as lora
import torch
from chatgpt.dataset import HhRlhfDataset, RmStaticDataset
from chatgpt.models import LogSigLoss, LogExpLoss
from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMRM
from chatgpt.models.deberta import DebertaRM
from chatgpt.models.gpt import GPTRM
from chatgpt.models.opt import OPTRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam

def train(args):
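    """Train a reward model on pairwise human-preference data.

    The reward model scores (prompt, response) pairs; training minimizes a
    ranking loss that drives the chosen response's score above the rejected one's.
    """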
    # configure strategy
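    # 'colossalai_gemini' uses ZeRO stage 3 with Gemini heterogeneous memory
    # management; 'colossalai_zero2' uses ZeRO stage 2 (gradients and optimizer
    # states are sharded, full parameters are kept on every rank)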
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
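    # lora_rank > 0 injects LoRA low-rank adapters into the linear layers;
    # lora_rank == 0 means full-parameter fine-tuning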
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'opt':
            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'gpt2':
            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'deberta':
            model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

        # optionally warm-start from a previously saved reward-model state dict
        if args.model_path is not None:
            state_dict = torch.load(args.model_path)
            model.load_state_dict(state_dict)
        
    # configure tokenizer
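    # NOTE: tokenizers below are loaded from fixed Hub checkpoints rather than
    # args.pretrain, so args.pretrain must be vocabulary-compatible with them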
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
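        # GPT-2 ships without a pad token; reuse EOS so padded batching works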
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    elif args.model == 'deberta':
        tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    max_len = args.max_len

    # configure optimizer
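    # HybridAdam is ColossalAI's hybrid CPU/GPU fused Adam, which cooperates
    # with Gemini/ZeRO parameter placement; plain Adam suffices otherwise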
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=1.5e-5)
    else:
        optim = Adam(model.parameters(), lr=1.5e-5)
    
    # configure loss function
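    # both losses implement the pairwise ranking objective on (chosen, rejected)
    # pairs: LogSigLoss is -log(sigmoid(r_chosen - r_rejected)), as in InstructGPT;
    # LogExpLoss is the mathematically equivalent log(1 + exp(r_rejected - r_chosen))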
    if args.loss_fn == 'log_sig':
        loss_fn = LogSigLoss()
    elif args.loss_fn == 'log_exp':
        loss_fn = LogExpLoss()
    else:
        raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
    
    # prepare the dataset
    # --subset is forwarded to load_dataset(..., data_dir=...) to select a
    # sub-directory of the dataset repo (e.g. 'harmless-base' for Anthropic/hh-rlhf)
    if args.subset is not None:
        data = load_dataset(args.dataset, data_dir=args.subset)
    else:
        data = load_dataset(args.dataset)
    
    if args.test:
        train_data = data['train'].select(range(100))
        eval_data = data['test'].select(range(10)) 
    else:
        train_data = data['train']
        eval_data = data['test']
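    # validation set: a random ~10% subsample (with replacement) of the test split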
    valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data)//10)))
    
    if args.dataset == 'Dahoas/rm-static':
        train_dataset = RmStaticDataset(train_data, tokenizer, max_len)
        valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len)
        eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len)
    elif args.dataset == 'Anthropic/hh-rlhf':
        train_dataset = HhRlhfDataset(train_data, tokenizer, max_len)
        valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len)
        eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len)
    else:
        raise ValueError(f'Unsupported dataset "{args.dataset}"')
    
    # configure trainer
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 loss_fn=loss_fn,
                                 train_dataset=train_dataset,
                                 valid_dataset=valid_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)

    trainer.fit()
    # save the model checkpoint after fitting (rank 0 only)
    strategy.save_model(trainer.model, args.save_path, only_rank0=True)
    # save the optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                f'rm_optim_checkpoint_{torch.cuda.current_device()}.pt',
                                only_rank0=False)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--need_optim_ckpt', action='store_true')
    parser.add_argument('--dataset', type=str,
                        choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
                        default='Dahoas/rm-static')
    parser.add_argument('--subset', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pt')
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_len', type=int, default=512)
    parser.add_argument('--lora_rank', type=int, default=0, help="rank of the LoRA adaptation matrices (0 disables LoRA)")
    parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
    parser.add_argument('--test', action='store_true', help='smoke-test on a small slice of the dataset')
    args = parser.parse_args()
    train(args)
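
# A minimal launch sketch (an assumption, not from this repo's docs; adjust
# --nproc_per_node and the flags defined above to your setup):
#   torchrun --standalone --nproc_per_node=2 train_reward_model.py \
#       --model bloom --pretrain bigscience/bloom-560m \
#       --strategy colossalai_zero2 --dataset Dahoas/rm-static \
#       --batch_size 4 --max_epochs 1 --save_path rm_ckpt.pt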