# train_reward_model.py
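"""Train a BLOOM-based reward model on a pairwise preference dataset.

Example (illustrative, not from the original file; assumes the
bigscience/bloom-560m checkpoint and a CUDA GPU are available):
    python train_reward_model.py --pretrain bigscience/bloom-560m --strategy naive
"""
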
import argparse

import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import BloomTokenizerFast

from colossalai.nn.optimizer import HybridAdam


def train(args):
    # configure strategy
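    # 'colossalai_gemini' is ZeRO stage 3 with Gemini memory management; 'colossalai_zero2' is ZeRO stage 2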
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
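    # pad sequences with the EOS token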
    tokenizer.pad_token = tokenizer.eos_token
    model = BLOOMRM(pretrained=args.pretrain).cuda()
    max_len = 1024

    # configure optimizer
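    # ColossalAI strategies pair with HybridAdam, which can update parameters resident on CPU or GPU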
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=5e-5)
    else:
        optim = Adam(model.parameters(), lr=5e-5)

    # prepare the datasets
    data = load_dataset(args.dataset)
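    # subsample so the example finishes quickly; enlarge these ranges for a real training run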
    train_data = data['train'].select(range(100))
    eval_data = data['test'].select(range(5))
    train_dataset = RewardDataset(train_data, tokenizer, max_len)
    eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

    # batch_size here is expected to be C(k, 2), where k is the number of responses per prompt
    # because 'Dahoas/rm-static' provides a single chosen/rejected pair per prompt, keep batch_size at 1
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 train_dataset=train_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)

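    # train; use_lora receives the LoRA rank (0 disables LoRA)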
    trainer.fit(use_lora=args.lora_rank)

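    # with LoRA enabled, persist only the low-rank adapter weights; otherwise save the full model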
    if args.lora_rank > 0:
        torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path)
    else:
        torch.save(trainer.model, args.save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
    parser.add_argument('--max_epochs', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--lora_rank', type=int, default=0, help='rank of the low-rank adaptation (LoRA) matrices (0 disables LoRA)')
    args = parser.parse_args()
    train(args)