train_reward_model.py 9.39 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
4
5
import argparse
from random import randint

import loralib as lora
import torch
6
import torch.distributed as dist
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
7
8
9
10
11
12
13
14
from coati.dataset import HhRlhfDataset, RmStaticDataset
from coati.models import LogExpLoss, LogSigLoss
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMRM
from coati.models.deberta import DebertaRM
from coati.models.gpt import GPTRM
from coati.models.llama import LlamaRM
from coati.models.opt import OPTRM
15
from coati.models.roberta import RoBERTaRM
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
16
17
18
19
20
from coati.trainer import RewardModelTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from torch.optim import Adam
21
22
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
23
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def _make_distributed_sampler(dataset):
    """Build a ``DistributedSampler`` over ``dataset`` for the current rank.

    All three splits use the same settings (``shuffle=True``, ``seed=42``,
    ``drop_last=True``), so they are defined once here.
    """
    return DistributedSampler(dataset,
                              shuffle=True,
                              seed=42,
                              drop_last=True,
                              rank=dist.get_rank(),
                              num_replicas=dist.get_world_size())


def _make_dataloader(dataset, batch_size, distributed):
    """Build a pinned-memory ``DataLoader``, sharded across ranks if ``distributed``.

    ``DataLoader`` does not accept ``shuffle=True`` together with a sampler,
    so loader-level shuffling is enabled only when no sampler is used.
    """
    sampler = _make_distributed_sampler(dataset) if distributed else None
    return DataLoader(dataset,
                      shuffle=(sampler is None),
                      sampler=sampler,
                      batch_size=batch_size,
                      pin_memory=True)


def train(args):
    """Train a reward model as described by ``args`` and save the result.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``strategy``, ``model``, ``pretrain``, ``lora_rank``,
            ``model_path``, ``max_len``, ``loss_fn``, ``dataset``, ``subset``,
            ``test``, ``batch_size``, ``max_epochs``, ``save_path`` and
            ``need_optim_ckpt``.

    Raises:
        ValueError: if ``args.strategy``, ``args.model``, ``args.loss_fn`` or
            ``args.dataset`` is not one of the values handled below.
    """
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model: every reward-model class is constructed the same way,
    # so only the class itself depends on ``args.model``
    model_classes = {
        'bloom': BLOOMRM,
        'opt': OPTRM,
        'gpt2': GPTRM,
        'deberta': DebertaRM,
        'llama': LlamaRM,
        'roberta': RoBERTaRM,
    }
    with strategy.model_init_context():
        model_cls = model_classes.get(args.model)
        if model_cls is None:
            raise ValueError(f'Unsupported model "{args.model}"')
        model = model_cls(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())

        if args.model_path is not None:
            # Load to CPU first: without map_location the tensors are restored
            # onto the device they were saved from, which may not exist on this
            # host. load_state_dict then copies them into the model parameters.
            # NOTE(security): torch.load unpickles the file; only pass trusted
            # checkpoints via --model_path.
            state_dict = torch.load(args.model_path, map_location='cpu')
            model.load_state_dict(state_dict)

    model = model.to(torch.float16)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif args.model == 'deberta':
        tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
    elif args.model == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    max_len = args.max_len

    if args.model == 'llama':
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
    else:
        tokenizer.pad_token = tokenizer.eos_token

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=5e-6)
    else:
        optim = Adam(model.parameters(), lr=5e-6)

    # configure loss function
    if args.loss_fn == 'log_sig':
        loss_fn = LogSigLoss()
    elif args.loss_fn == 'log_exp':
        loss_fn = LogExpLoss()
    else:
        raise ValueError(f'Unsupported loss function "{args.loss_fn}"')

    # prepare for data and dataset
    if args.subset is not None:
        data = load_dataset(args.dataset, data_dir=args.subset)
    else:
        data = load_dataset(args.dataset)

    if args.test:
        # smoke-test mode: only the first 100 train / 10 test rows
        train_data = data['train'].select(range(100))
        eval_data = data['test'].select(range(10))
    else:
        train_data = data['train']
        eval_data = data['test']
    # Validation rows are drawn at random (with replacement) from the index
    # range of eval_data, one fifth as many as eval_data has rows. The indices
    # are materialised as a list so ``select`` receives a concrete sequence.
    valid_indices = [randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)]
    valid_data = data['test'].select(valid_indices)

    if args.dataset == 'Dahoas/rm-static':
        dataset_cls = RmStaticDataset
    elif args.dataset == 'Anthropic/hh-rlhf':
        dataset_cls = HhRlhfDataset
    else:
        raise ValueError(f'Unsupported dataset "{args.dataset}"')
    train_dataset = dataset_cls(train_data, tokenizer, max_len)
    valid_dataset = dataset_cls(valid_data, tokenizer, max_len)
    eval_dataset = dataset_cls(eval_data, tokenizer, max_len)

    # shard the data across ranks only when actually running multi-process
    distributed = dist.is_initialized() and dist.get_world_size() > 1
    train_dataloader = _make_dataloader(train_dataset, args.batch_size, distributed)
    valid_dataloader = _make_dataloader(valid_dataset, args.batch_size, distributed)
    eval_dataloader = _make_dataloader(eval_dataset, args.batch_size, distributed)

    (model, optim) = strategy.prepare((model, optim))
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 loss_fn=loss_fn,
                                 train_dataloader=train_dataloader,
                                 valid_dataloader=valid_dataloader,
                                 eval_dataloader=eval_dataloader,
                                 max_epochs=args.max_epochs)

    trainer.fit()
    # save model checkpoint after fitting on only rank0
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                f'rm_optim_checkpoint_{torch.cuda.current_device()}.pt',
                                only_rank0=False)


def str2bool(value):
    """Convert a command-line string to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` and ``"0"`` -- parses as ``True``.
    This converter accepts the usual true/false spellings (case-insensitive)
    and rejects anything else.

    Args:
        value: the raw argument string (an actual ``bool`` is passed through).

    Returns:
        The parsed boolean. The empty string maps to ``False``, matching what
        ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got "{value}"')


if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to train().
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='colossalai_zero2')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--model_path', type=str, default=None)
    # str2bool instead of bool so that "--need_optim_ckpt False" is really False
    parser.add_argument('--need_optim_ckpt', type=str2bool, default=False)
    parser.add_argument('--dataset',
                        type=str,
                        choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
                        default='Dahoas/rm-static')
    parser.add_argument('--subset', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='rm_ckpt')
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_len', type=int, default=512)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
    # str2bool instead of bool so that "--test False" is really False
    parser.add_argument('--test', type=str2bool, default=False)
    args = parser.parse_args()
    train(args)