train_reward_model.py
import argparse
from random import randint

import torch
import torch.distributed as dist
from coati.dataset import HhRlhfDataset, RmStaticDataset
from coati.models import LogExpLoss, LogSigLoss
from coati.models.bloom import BLOOMRM
from coati.models.gpt import GPTRM
from coati.models.llama import LlamaRM
from coati.models.opt import OPTRM
from coati.trainer import RewardModelTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def train(args):
    # configure strategy
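    # "ddp" uses vanilla PyTorch DistributedDataParallel; the "colossalai_*"
    # strategies enable ColossalAI's Gemini and ZeRO stage-2 memory optimizations.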
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="cuda")
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        if args.model == "bloom":
            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == "opt":
            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == "gpt2":
            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == "llama":
            model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

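        # cast to fp16 and move to the current CUDA device before training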
        model.to(torch.float16).to(torch.cuda.current_device())

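    # optionally warm-start from a previously saved reward-model state dict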
        if args.model_path is not None:
            state_dict = torch.load(args.model_path)
            model.load_state_dict(state_dict)

    # configure tokenizer
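    # each branch ensures a pad token is defined so sequences can be batch-padded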
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained(
            "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
        )
        tokenizer.eos_token = "</s>"  # LLaMA's end-of-sequence token is "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure optimizer
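    # HybridAdam is ColossalAI's fused Adam that can keep optimizer states on CPU or GPU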
    if args.strategy.startswith("colossalai"):
        optim = HybridAdam(model.parameters(), lr=5e-6)
    else:
        optim = Adam(model.parameters(), lr=5e-6)

    # configure loss function
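    # Both are pairwise ranking losses over (chosen, rejected) reward pairs:
    # log_sig computes -log(sigmoid(r_chosen - r_reject)) and log_exp computes
    # log(1 + exp(r_reject - r_chosen)); the two are mathematically equivalent.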
    if args.loss_fn == "log_sig":
        loss_fn = LogSigLoss()
    elif args.loss_fn == "log_exp":
        loss_fn = LogExpLoss()
    else:
        raise ValueError(f'Unsupported loss function "{args.loss_fn}"')

    # prepare for data and dataset
    if args.subset is not None:
        data = load_dataset(args.dataset, data_dir=args.subset)
    else:
        data = load_dataset(args.dataset)

    if args.test:
        train_data = data["train"].select(range(20))
        eval_data = data["test"].select(range(5))
    else:
        train_data = data["train"]
        eval_data = data["test"]
    valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))

    if args.dataset == "Dahoas/rm-static":
        train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
        valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
        eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
    elif args.dataset == "Anthropic/hh-rlhf":
        train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
        valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
        eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
    else:
        raise ValueError(f'Unsupported dataset "{args.dataset}"')

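    # when running multi-GPU, shard each split across ranks with DistributedSampler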
    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(
            train_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        valid_sampler = DistributedSampler(
            valid_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        eval_sampler = DistributedSampler(
            eval_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
    else:
        train_sampler = None
        valid_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        batch_size=args.batch_size,
        pin_memory=True,
    )

    valid_dataloader = DataLoader(
        valid_dataset,
        shuffle=(valid_sampler is None),
        sampler=valid_sampler,
        batch_size=args.batch_size,
        pin_memory=True,
    )

    eval_dataloader = DataLoader(
        eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
    )

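    # cosine-anneal the learning rate; T_max is one hundredth of an epoch's steps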
    lr_scheduler = CosineAnnealingLR(optim, len(train_dataloader) // 100)
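    # let the strategy wrap the model, optimizer, and scheduler (e.g. DDP or ZeRO sharding)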
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model = strategy_dict["model"]
    optim = strategy_dict["optimizer"]
    lr_scheduler = strategy_dict["lr_scheduler"]
    trainer = RewardModelTrainer(
        model=model,
        strategy=strategy,
        optim=optim,
        lr_scheduler=lr_scheduler,
        loss_fn=loss_fn,
        max_epochs=args.max_epochs,
    )

    trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
    # save the model checkpoint after fitting (rank 0 only)
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(
            trainer.optimizer, f"rm_optim_checkpoint_{torch.cuda.current_device()}.pt", only_rank0=False
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
    )
    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument(
        "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
    )
    parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
    parser.add_argument("--save_path", type=str, default="rm_ckpt")
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_len", type=int, default=512)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
    parser.add_argument("--test", type=bool, default=False)
    args = parser.parse_args()
    train(args)
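
# Example launch (a sketch; the model name, GPU count, and paths are illustrative,
# and the torchrun entry point is an assumption about your environment):
#   torchrun --standalone --nproc_per_node=4 train_reward_model.py \
#       --model bloom --pretrain bigscience/bloom-560m --strategy colossalai_zero2 \
#       --dataset Dahoas/rm-static --loss_fn log_sig --batch_size 4 --save_path rm_ckpt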