train_reward_model.py
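"""Train a reward model on pairwise human-preference data.

Example launch (a sketch only: the model name and the torchrun invocation are
illustrative assumptions; the flags are the ones defined in the argparse
section at the bottom of this file):

    torchrun --standalone --nproc_per_node 2 train_reward_model.py \
        --model bloom --pretrain bigscience/bloom-560m --strategy colossalai_zero2
"""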
import argparse

import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def train(args):
    # configure strategy
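    # (background, an assumption not stated in this file: NaiveStrategy runs a
    # single process, DDPStrategy wraps torch DistributedDataParallel, and the
    # ColossalAI strategies use ZeRO sharding: stage=2 shards gradients and
    # optimizer states, stage=3 additionally shards parameters via Gemini)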
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
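    # (assumption about the API: model_init_context() lets the strategy control
    # weight placement during construction, e.g. lazy/sharded initialization
    # under ColossalAI; for the naive strategy it is presumably a plain context)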
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'opt':
            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'gpt2':
            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    # use the EOS token for padding across all models
    tokenizer.pad_token = tokenizer.eos_token

    # maximum sequence length used by RewardDataset when tokenizing
    max_len = 512

    # configure optimizer
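    # HybridAdam is ColossalAI's fused Adam variant that can update parameters
    # resident on either CPU or GPU, which matters once ZeRO/Gemini offloads
    # training state off the device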
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=5e-5)
    else:
        optim = Adam(model.parameters(), lr=5e-5)

    # prepare the data and datasets
    data = load_dataset(args.dataset)
    # use only a small subset of the data (100 train / 5 eval samples)
    train_data = data['train'].select(range(100))
    eval_data = data['test'].select(range(5))
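    # RewardDataset presumably tokenizes each sample's chosen and rejected
    # responses up to max_len (the 'Dahoas/rm-static' schema provides
    # prompt/chosen/rejected fields)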
    train_dataset = RewardDataset(train_data, tokenizer, max_len)
    eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

    # batch_size here is expected to be C(k, 2), where k is the number of responses per prompt
    # given the format of the 'Dahoas/rm-static' dataset, a batch_size of 1 is the safer choice
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 train_dataset=train_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)
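    # (assumption about the trainer: reward-model training typically minimizes
    # the pairwise ranking loss  -log(sigmoid(r_chosen - r_rejected))  so that
    # preferred responses score higher than rejected ones)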

    # use_lora receives the LoRA rank; a rank of 0 presumably disables LoRA,
    # matching the --lora_rank default below
    trainer.fit(use_lora=args.lora_rank)

    # save the model checkpoint after fitting, on rank 0 only
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save an optimizer checkpoint on every rank
    strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
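    # (rationale, an assumption: optimizer states may be sharded across ranks
    # under ZeRO, so each device writes its own shard)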


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None, help="pretrained model name or path (the bloom tokenizer is loaded from it)")
    parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
    parser.add_argument('--max_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lora_rank', type=int, default=0, help="rank of the low-rank adaptation (LoRA) matrices")
    args = parser.parse_args()
    train(args)