# train_reward_model.py
import argparse
from random import randint

import torch
5
import torch.distributed as dist
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
6
7
8
9
10
11
12
from coati.dataset import HhRlhfDataset, RmStaticDataset
from coati.models import LogExpLoss, LogSigLoss
from coati.models.bloom import BLOOMRM
from coati.models.gpt import GPTRM
from coati.models.llama import LlamaRM
from coati.models.opt import OPTRM
from coati.trainer import RewardModelTrainer
13
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
14
15
from datasets import load_dataset
from torch.optim import Adam
16
from torch.optim.lr_scheduler import CosineAnnealingLR
17
18
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
19
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
20
21
22
23
24
25
26
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def train(args):
    """Train a reward model and save the resulting checkpoint(s).

    The function builds, in order: the distributed strategy, the reward
    model, the tokenizer, the optimizer, the loss function, the
    train/valid/eval datasets and dataloaders, and a cosine learning-rate
    scheduler. It then runs ``RewardModelTrainer.fit`` and saves the model
    (and optionally the optimizer state).

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``strategy``, ``model``, ``pretrain``, ``lora_rank``,
            ``model_path``, ``tokenizer``, ``loss_fn``, ``dataset``,
            ``subset``, ``test``, ``max_len``, ``batch_size``,
            ``max_epochs``, ``save_path`` and ``need_optim_ckpt``.

    Raises:
        ValueError: If ``args.strategy``, ``args.model``, ``args.loss_fn``
            or ``args.dataset`` is not one of the supported values.
    """
    # configure strategy
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = GeminiStrategy(placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'opt':
            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'gpt2':
            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'llama':
            model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

        # Cast to half precision and move to this process's current GPU.
        model.to(torch.float16).to(torch.cuda.current_device())

        if args.model_path is not None:
            # NOTE(review): torch.load unpickles the file; only pass
            # checkpoints that come from a trusted source.
            state_dict = torch.load(args.model_path)
            model.load_state_dict(state_dict)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained(
            'gpt2' if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(
            'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained(
            "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
        # The LLaMA end-of-sequence token is '</s>' (forward slash).
        tokenizer.eos_token = '</s>'
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=5e-6)
    else:
        optim = Adam(model.parameters(), lr=5e-6)

    # configure loss function
    if args.loss_fn == 'log_sig':
        loss_fn = LogSigLoss()
    elif args.loss_fn == 'log_exp':
        loss_fn = LogExpLoss()
    else:
        raise ValueError(f'Unsupported loss function "{args.loss_fn}"')

    # prepare for data and dataset
    if args.subset is not None:
        data = load_dataset(args.dataset, data_dir=args.subset)
    else:
        data = load_dataset(args.dataset)

    if args.test:
        # Smoke-test mode: only a handful of samples per split.
        train_data = data['train'].select(range(20))
        eval_data = data['test'].select(range(5))
    else:
        train_data = data['train']
        eval_data = data['test']

    # The validation set is a random sample (with replacement) of the test
    # split, one fifth the size of the evaluation set. The indices are
    # materialised as a list rather than handed over as a one-shot generator.
    valid_indices = [randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)]
    valid_data = data['test'].select(valid_indices)

    if args.dataset == 'Dahoas/rm-static':
        dataset_cls = RmStaticDataset
    elif args.dataset == 'Anthropic/hh-rlhf':
        dataset_cls = HhRlhfDataset
    else:
        raise ValueError(f'Unsupported dataset "{args.dataset}"')

    train_dataset = dataset_cls(train_data, tokenizer, args.max_len)
    valid_dataset = dataset_cls(valid_data, tokenizer, args.max_len)
    eval_dataset = dataset_cls(eval_data, tokenizer, args.max_len)

    use_distributed_sampler = dist.is_initialized() and dist.get_world_size() > 1

    def _make_dataloader(dataset):
        """Build a DataLoader, sharded across ranks when running distributed."""
        sampler = None
        if use_distributed_sampler:
            sampler = DistributedSampler(dataset,
                                         shuffle=True,
                                         seed=42,
                                         drop_last=True,
                                         rank=dist.get_rank(),
                                         num_replicas=dist.get_world_size())
        # DataLoader forbids shuffle=True together with an explicit sampler,
        # so only shuffle here when no sampler is in charge of the order.
        return DataLoader(dataset,
                          shuffle=(sampler is None),
                          sampler=sampler,
                          batch_size=args.batch_size,
                          pin_memory=True)

    train_dataloader = _make_dataloader(train_dataset)
    valid_dataloader = _make_dataloader(valid_dataset)
    eval_dataloader = _make_dataloader(eval_dataset)

    # With fewer than 100 batches the integer division yields T_max == 0,
    # which makes CosineAnnealingLR divide by zero when stepped; clamp to 1.
    lr_scheduler = CosineAnnealingLR(optim, max(1, len(train_dataloader) // 100))
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model = strategy_dict['model']
    optim = strategy_dict['optimizer']
    lr_scheduler = strategy_dict['lr_scheduler']
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 lr_scheduler=lr_scheduler,
                                 loss_fn=loss_fn,
                                 max_epochs=args.max_epochs)

    trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
    # save model checkpoint after fitting on only rank0
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                f'rm_optim_checkpoint_{torch.cuda.current_device()}.pt',
                                only_rank0=False)


def str2bool(value):
    """Convert a command-line string to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(<string>)``, which is True
    for every non-empty string (including ``"False"``). This converter
    interprets the usual spellings instead.

    Args:
        value: The raw option value (a string), or an actual ``bool``.

    Returns:
        The parsed boolean. An empty string maps to False, as ``bool('')`` does.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got "{value}"')


def parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='colossalai_zero2')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--need_optim_ckpt', type=str2bool, default=False)
    parser.add_argument('--dataset',
                        type=str,
                        choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
                        default='Dahoas/rm-static')
    # The literal string 'None' on the command line means "no subset".
    parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None)
    parser.add_argument('--save_path', type=str, default='rm_ckpt')
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_len', type=int, default=512)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
    parser.add_argument('--test', type=str2bool, default=False)
    return parser.parse_args(argv)


if __name__ == '__main__':
    train(parse_args())