# train_reward_model.py — train a reward model for RLHF on a pairwise-preference dataset.
import argparse

import loralib as lora
import torch
from datasets import load_dataset
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from chatgpt.dataset import RewardDataset
from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMRM
from chatgpt.models.gpt import GPTRM
from chatgpt.models.opt import OPTRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from colossalai.nn.optimizer import HybridAdam


def train(args):
    """Train a reward model on a pairwise (chosen/rejected) preference dataset.

    Builds the distributed-training strategy, the reward model and its
    tokenizer from ``args``, fine-tunes on the chosen dataset, then saves
    the model checkpoint (rank 0 only) and per-rank optimizer checkpoints.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block);
            must provide ``strategy``, ``model``, ``pretrain``, ``dataset``,
            ``save_path``, ``max_epochs``, ``batch_size`` and ``lora_rank``.

    Raises:
        ValueError: if ``args.strategy`` or ``args.model`` is unsupported.
    """
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model inside the strategy's init context (e.g. for sharded init)
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'opt':
            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'gpt2':
            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    # these tokenizers may lack a pad token; reuse EOS for padding
    # (set once here — previously duplicated inside the gpt2/bloom branches)
    tokenizer.pad_token = tokenizer.eos_token

    max_len = 512

    # configure optimizer: ColossalAI strategies require HybridAdam
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=5e-5)
    else:
        optim = Adam(model.parameters(), lr=5e-5)

    # prepare data and datasets
    data = load_dataset(args.dataset)
    train_data = data["train"]
    eval_data = data['test']
    train_dataset = RewardDataset(train_data, tokenizer, max_len)
    eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 train_dataset=train_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)

    trainer.fit(use_lora=args.lora_rank)

    # save model checkpoint after fitting on only rank0
    # fix: honor --save_path instead of a hard-coded 'rm_checkpoint.pt'
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    return parser.parse_args()


if __name__ == '__main__':
    train(_parse_args())