# train_reward_model.py: train a reward model with coati's RewardModelTrainer.
import argparse

import torch
import torch.distributed as dist
from coati.dataset import HhRlhfDataset, RmStaticDataset
from coati.models import LogExpLoss, LogSigLoss
from coati.models.bloom import BLOOMRM
from coati.models.gpt import GPTRM
from coati.models.llama import LlamaRM
from coati.models.opt import OPTRM
from coati.trainer import RewardModelTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def _configure_strategy(args):
    """Return the training strategy selected by ``args.strategy``."""
    if args.strategy == "ddp":
        return DDPStrategy()
    if args.strategy == "colossalai_gemini":
        return GeminiStrategy(placement_policy="auto")
    if args.strategy == "colossalai_zero2":
        return LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    raise ValueError(f'Unsupported strategy "{args.strategy}"')


def _build_model(args, strategy):
    """Create the reward model for ``args.model`` inside the strategy's init context.

    The model is cast to bfloat16 and moved to the current CUDA device. When
    ``args.model_path`` is given, its state dict is loaded into the model.
    """
    with strategy.model_init_context():
        if args.model == "bloom":
            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == "opt":
            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == "gpt2":
            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == "llama":
            model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

        # nn.Module.to() modifies the module in place, so its return value is not kept.
        model.to(torch.bfloat16).to(torch.cuda.current_device())

        if args.model_path is not None:
            # Without map_location, torch.load restores tensors onto the device they were
            # saved from, so every rank would allocate the checkpoint on that same GPU.
            # Load to CPU instead; load_state_dict copies the values into the model's own params.
            state_dict = torch.load(args.model_path, map_location="cpu")
            model.load_state_dict(state_dict)
    return model


def _build_tokenizer(args):
    """Load the tokenizer for ``args.model``; ``args.tokenizer`` overrides the default checkpoint."""

    def _name_or(default):
        return default if args.tokenizer is None else args.tokenizer

    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained(_name_or("gpt2"))
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained(_name_or("bigscience/bloom-560m"))
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained(_name_or("facebook/opt-350m"))
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(_name_or("hf-internal-testing/llama-tokenizer"))
        # LLaMA's end-of-sequence token is "</s>" (forward slash); a "<\s>" literal would
        # contain a backslash and name a different, non-existent token.
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    return tokenizer


def _build_loss_fn(args):
    """Return the pairwise ranking loss selected by ``args.loss_fn``."""
    if args.loss_fn == "log_sig":
        return LogSigLoss()
    if args.loss_fn == "log_exp":
        return LogExpLoss()
    raise ValueError(f'Unsupported loss function "{args.loss_fn}"')


def _build_datasets(args, tokenizer):
    """Load ``args.dataset`` and return its (train, eval) splits wrapped in the matching coati dataset."""
    if args.subset is not None:
        data = load_dataset(args.dataset, data_dir=args.subset)
    else:
        data = load_dataset(args.dataset)

    # Cap each split at max_datasets_size examples.
    train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"]))))
    eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"]))))

    if args.dataset == "Dahoas/rm-static":
        dataset_cls = RmStaticDataset
    elif args.dataset == "Anthropic/hh-rlhf":
        dataset_cls = HhRlhfDataset
    else:
        raise ValueError(f'Unsupported dataset "{args.dataset}"')
    return dataset_cls(train_data, tokenizer, args.max_len), dataset_cls(eval_data, tokenizer, args.max_len)


def _build_dataloader(dataset, batch_size):
    """Wrap ``dataset`` in a DataLoader, sharding it across ranks when more than one process is running."""
    sampler = None
    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(
            dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
    # DataLoader rejects shuffle=True together with a sampler, so only shuffle when no sampler is set.
    return DataLoader(dataset, shuffle=(sampler is None), sampler=sampler, batch_size=batch_size, pin_memory=True)


def train(args):
    """Train a reward model as configured by the parsed command-line ``args``.

    After fitting, the model is saved to ``args.save_path`` (rank 0 only). When
    ``args.need_optim_ckpt`` is set, each rank also saves its optimizer state.

    Raises:
        ValueError: if ``args.strategy``, ``args.model``, ``args.loss_fn`` or
            ``args.dataset`` is not one of the supported values.
    """
    strategy = _configure_strategy(args)
    model = _build_model(args, strategy)
    tokenizer = _build_tokenizer(args)

    # ColossalAI strategies use ColossalAI's HybridAdam; DDP uses plain torch Adam.
    if args.strategy.startswith("colossalai"):
        optim = HybridAdam(model.parameters(), lr=args.lr)
    else:
        optim = Adam(model.parameters(), lr=args.lr)

    loss_fn = _build_loss_fn(args)

    train_dataset, eval_dataset = _build_datasets(args, tokenizer)
    train_dataloader = _build_dataloader(train_dataset, args.batch_size)
    eval_dataloader = _build_dataloader(eval_dataset, args.batch_size)

    # T_max must be >= 1: with fewer than 100 training batches, len // 100 is 0 and
    # CosineAnnealingLR divides by T_max (ZeroDivisionError) once it is stepped.
    lr_scheduler = CosineAnnealingLR(optim, max(1, len(train_dataloader) // 100))
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model = strategy_dict["model"]
    optim = strategy_dict["optimizer"]
    lr_scheduler = strategy_dict["lr_scheduler"]

    trainer = RewardModelTrainer(
        model=model,
        strategy=strategy,
        optim=optim,
        lr_scheduler=lr_scheduler,
        loss_fn=loss_fn,
        max_epochs=args.max_epochs,
    )
    trainer.fit(
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    # Save the model checkpoint after fitting, on rank 0 only.
    strategy.save_model(model, args.save_path, only_rank0=True)
    # Save the optimizer checkpoint on all ranks (file name keyed by the local CUDA device index).
    if args.need_optim_ckpt:
        strategy.save_optimizer(
            trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
        )


def parse_bool(value):
    """Convert a command-line string such as ``True``/``False``/``1``/``0`` into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as True. This converter interprets the text instead. An empty
    string maps to False, matching ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "1", "yes", "y", "on"):
        return True
    if normalized in ("false", "0", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Return the command-line argument parser for reward-model training."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
    )
    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--model_path", type=str, default=None)
    # Accepts "--need_optim_ckpt True/False" as before, and also a bare "--need_optim_ckpt".
    parser.add_argument("--need_optim_ckpt", type=parse_bool, nargs="?", const=True, default=False)
    parser.add_argument(
        "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
    )
    # The literal string "None" is mapped to None so a subset can be disabled explicitly.
    parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
    parser.add_argument("--max_datasets_size", type=int, default=1000000)
    parser.add_argument("--save_path", type=str, default="rm_ckpt")
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_len", type=int, default=512)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--lr", type=float, default=9e-6)
    parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
    parser.add_argument("--log_dir", default="logs", type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    train(args)