train_reward_model.py 8.13 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
import argparse
2
import warnings
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
3
4

import torch
5
import torch.distributed as dist
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
6
7
8
9
10
11
12
from coati.dataset import HhRlhfDataset, RmStaticDataset
from coati.models import LogExpLoss, LogSigLoss
from coati.models.bloom import BLOOMRM
from coati.models.gpt import GPTRM
from coati.models.llama import LlamaRM
from coati.models.opt import OPTRM
from coati.trainer import RewardModelTrainer
13
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
14
15
from datasets import load_dataset
from torch.optim import Adam
16
from torch.optim.lr_scheduler import CosineAnnealingLR
17
18
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
19
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
20
21
22
23
24
25
26
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def _build_strategy(name):
    """Return the training strategy selected by ``--strategy``.

    Raises:
        ValueError: if ``name`` is not a supported strategy.
    """
    if name == "ddp":
        return DDPStrategy()
    if name == "colossalai_gemini":
        return GeminiStrategy(placement_policy="auto")
    if name == "colossalai_zero2":
        return LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    raise ValueError(f'Unsupported strategy "{name}"')


def _build_reward_model(args):
    """Instantiate the reward model for ``args.model`` and optionally restore weights.

    The model is cast to bfloat16 and moved to the current CUDA device. When
    ``args.model_path`` is set, its state dict is loaded into the model.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    if args.model == "bloom":
        model_cls = BLOOMRM
    elif args.model == "opt":
        model_cls = OPTRM
    elif args.model == "gpt2":
        model_cls = GPTRM
    elif args.model == "llama":
        model_cls = LlamaRM
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    model = model_cls(pretrained=args.pretrain, lora_rank=args.lora_rank)
    model.to(torch.bfloat16).to(torch.cuda.current_device())

    if args.model_path is not None:
        # Load onto CPU first so every rank does not materialise the checkpoint
        # on the device it was saved from; load_state_dict copies the values
        # into the parameters that already live on this rank's device.
        # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
        state_dict = torch.load(args.model_path, map_location="cpu")
        model.load_state_dict(state_dict)
    return model


def _build_tokenizer(args):
    """Create the tokenizer matching ``args.model``.

    ``args.tokenizer`` overrides the default pretrained tokenizer name when it
    is not ``None``. A padding token is assigned for every architecture.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    override = args.tokenizer
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if override is None else override)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m" if override is None else override)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if override is None else override)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer" if override is None else override
        )
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    return tokenizer


def _build_loss_fn(name):
    """Return the pairwise loss selected by ``--loss_fn``.

    Raises:
        ValueError: if ``name`` is not a supported loss function.
    """
    if name == "log_sig":
        return LogSigLoss()
    if name == "log_exp":
        return LogExpLoss()
    raise ValueError(f'Unsupported loss function "{name}"')


def _build_datasets(args, tokenizer):
    """Load ``args.dataset`` and wrap its train/test splits for reward-model training.

    Each split is truncated to at most ``args.max_datasets_size`` examples.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    if args.subset is not None:
        data = load_dataset(args.dataset, data_dir=args.subset)
    else:
        data = load_dataset(args.dataset)

    train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"]))))
    eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"]))))

    if args.dataset == "Dahoas/rm-static":
        dataset_cls = RmStaticDataset
    elif args.dataset == "Anthropic/hh-rlhf":
        dataset_cls = HhRlhfDataset
    else:
        raise ValueError(f'Unsupported dataset "{args.dataset}"')

    train_dataset = dataset_cls(train_data, tokenizer, args.max_len)
    eval_dataset = dataset_cls(eval_data, tokenizer, args.max_len)
    return train_dataset, eval_dataset


def _build_dataloader(dataset, batch_size):
    """Wrap ``dataset`` in a ``DataLoader``, sharding it when running distributed.

    With more than one initialized process a ``DistributedSampler`` (fixed seed
    42, shuffling, ``drop_last=True``) is used; otherwise the ``DataLoader``
    shuffles by itself.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # NOTE(review): the same shuffling / drop_last sampler is used for the
        # evaluation split as well, as in the original code -- confirm that
        # dropping the tail of the eval set is intended.
        sampler = DistributedSampler(
            dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
    else:
        sampler = None

    return DataLoader(
        dataset,
        shuffle=(sampler is None),
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
    )


def train(args):
    """Train a reward model end to end and write the resulting checkpoints.

    Steps: pick the distributed strategy, build the model / tokenizer /
    optimizer / loss, load and wrap the dataset, run ``RewardModelTrainer.fit``,
    then save the model state dict (rank 0 only) and, if requested, the
    per-rank optimizer checkpoint.

    Args:
        args: parsed command-line namespace. Fields read here: ``strategy``,
            ``model``, ``tokenizer``, ``pretrain``, ``model_path``,
            ``lora_rank``, ``merge_lora_weights``, ``lr``, ``loss_fn``,
            ``dataset``, ``subset``, ``max_datasets_size``, ``max_len``,
            ``batch_size``, ``max_epochs``, ``log_dir``, ``use_wandb``,
            ``save_path`` and ``need_optim_ckpt``.

    Raises:
        ValueError: if the strategy, model, loss function or dataset name is
            not supported.
    """
    # configure strategy
    strategy = _build_strategy(args.strategy)

    # configure model
    if args.lora_rank > 0:
        warnings.warn("Lora is not supported yet.")
        args.lora_rank = 0

    with strategy.model_init_context():
        model = _build_reward_model(args)

    # configure tokenizer
    tokenizer = _build_tokenizer(args)

    # configure optimizer
    if args.strategy.startswith("colossalai"):
        optim = HybridAdam(model.parameters(), lr=args.lr)
    else:
        optim = Adam(model.parameters(), lr=args.lr)

    # configure loss function
    loss_fn = _build_loss_fn(args.loss_fn)

    # prepare for data and dataset
    train_dataset, eval_dataset = _build_datasets(args, tokenizer)
    train_dataloader = _build_dataloader(train_dataset, args.batch_size)
    eval_dataloader = _build_dataloader(eval_dataset, args.batch_size)

    # T_max must be at least 1: with fewer than 100 batches the integer
    # division yields 0 and the cosine schedule would divide by zero.
    lr_scheduler = CosineAnnealingLR(optim, max(1, len(train_dataloader) // 100))
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model = strategy_dict["model"]
    optim = strategy_dict["optimizer"]
    lr_scheduler = strategy_dict["lr_scheduler"]

    trainer = RewardModelTrainer(
        model=model,
        strategy=strategy,
        optim=optim,
        lr_scheduler=lr_scheduler,
        loss_fn=loss_fn,
        max_epochs=args.max_epochs,
    )
    trainer.fit(
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    # lora_rank is forced to 0 above, so this branch is currently never taken;
    # it is kept for when LoRA support is re-enabled.
    if args.lora_rank > 0 and args.merge_lora_weights:
        from coati.models.lora import LORA_MANAGER

        # NOTE: set model to eval to merge LoRA weights
        LORA_MANAGER.merge_weights = True
        model.eval()

    # save model checkpoint after fitting on only rank0
    # state_dict() is still called on every rank (it may need all ranks to
    # participate for sharded strategies -- TODO confirm); only rank 0 writes,
    # so ranks do not race on the same output file.
    state_dict = model.state_dict()
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(state_dict, args.save_path)

    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(
            trainer.optimizer, f"rm_optim_checkpoint_{torch.cuda.current_device()}.pt", only_rank0=False
        )
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
180
181


182
if __name__ == "__main__":
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
183
    parser = argparse.ArgumentParser()
184
185
186
187
188
189
190
191
192
193
194
195
    parser.add_argument(
        "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
    )
    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument(
        "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
    )
    parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
196
    parser.add_argument("--max_datasets_size", type=int, default=1000000)
197
198
199
200
201
    parser.add_argument("--save_path", type=str, default="rm_ckpt")
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_len", type=int, default=512)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
202
    parser.add_argument("--merge_lora_weights", type=bool, default=True)
203
    parser.add_argument("--lr", type=float, default=9e-6)
204
    parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
205
206
    parser.add_argument("--log_dir", default="logs", type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
207
208
    args = parser.parse_args()
    train(args)