train_prompts.py 10.5 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
import argparse
2
import warnings
3

Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
4
5
import torch
import torch.distributed as dist
6
from coati.dataset import PromptDataset, SupervisedDataset
7
8
9
10
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
11
from coati.trainer import PPOTrainer
12
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
13
14
15
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
16
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
17
18
19
20
21
22

from colossalai.nn.optimizer import HybridAdam


def _build_strategy(strategy_name):
    """Return the training strategy selected by ``--strategy``.

    Raises:
        ValueError: if ``strategy_name`` is not one of the supported strategies.
    """
    if strategy_name == "ddp":
        return DDPStrategy()
    if strategy_name == "colossalai_gemini":
        return GeminiStrategy(placement_policy="auto", initial_scale=2**5)
    if strategy_name == "colossalai_zero2":
        return LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    raise ValueError(f'Unsupported strategy "{strategy_name}"')


def _lookup_model_class(registry, model_name, role):
    """Return the model class registered under ``model_name``.

    Args:
        registry: mapping from a model-family name ("gpt2", "bloom", ...) to a model class.
        model_name: family requested on the command line.
        role: word inserted into the error message ("actor" or "reward").

    Raises:
        ValueError: ``Unsupported <role> model "<model_name>"`` when the name is not registered.
    """
    if model_name not in registry:
        raise ValueError(f'Unsupported {role} model "{model_name}"')
    return registry[model_name]


def _build_tokenizer(model_name, tokenizer_path):
    """Load the tokenizer for ``model_name`` and configure its special tokens.

    Args:
        model_name: model family ("gpt2", "bloom", "opt" or "llama").
        tokenizer_path: tokenizer name/path; when None a per-family default checkpoint is used.

    Returns:
        The tokenizer, with a pad token set and left-side padding.

    Raises:
        ValueError: if ``model_name`` is not a supported family.
    """
    if model_name == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if tokenizer_path is None else tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token
    elif model_name == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained(
            "bigscience/bloom-560m" if tokenizer_path is None else tokenizer_path
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif model_name == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if tokenizer_path is None else tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token
    elif model_name == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer" if tokenizer_path is None else tokenizer_path
        )
        # LLaMA's end-of-sequence token is "</s>". The former literal "<\s>" had a backslash
        # instead of a slash (and "\s" is an invalid string escape), so it named a different token.
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{model_name}"')

    # NOTE: generate() requires padding_side to be "left"
    tokenizer.padding_side = "left"
    return tokenizer


def _build_dataloader(dataset, batch_size):
    """Wrap ``dataset`` in a DataLoader, sharding it across ranks when running distributed.

    A ``DistributedSampler`` (shuffled, seed 42, ``drop_last=True``) is used only when
    ``torch.distributed`` is initialised with more than one rank; otherwise the loader
    shuffles on its own.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(dataset, shuffle=True, seed=42, drop_last=True)
    else:
        sampler = None
    # DataLoader rejects shuffle=True together with a sampler, so shuffle only when no sampler is set.
    return DataLoader(dataset, shuffle=(sampler is None), sampler=sampler, batch_size=batch_size)


def main(args):
    """Fine-tune an actor with PPO against a reward model and save the result.

    Steps:
    - build the strategy;
    - optionally load reward-model weights from ``args.rm_path``;
    - construct the initial model, reward model, actor and critic under
      ``strategy.model_init_context()``;
    - create the optimizers, the tokenizer and the prompt / pretraining dataloaders;
    - run ``PPOTrainer.fit``;
    - optionally merge LoRA weights;
    - save the actor (rank 0 only) and, if requested, every rank's actor optimizer state.

    Args:
        args: namespace produced by the argument parser in this script's ``__main__`` block.

    Raises:
        ValueError: for an unsupported strategy, model or reward-model name.
    """
    strategy = _build_strategy(args.strategy)

    if args.rm_path is not None:
        warnings.warn("LoRA weights should be merged with the model weights")
        # NOTE(review): torch.load unpickles the file, so --rm_path must point at a trusted checkpoint.
        state_dict = torch.load(args.rm_path, map_location="cpu")

    # The reward model and critic default to the same model family as the actor.
    rm_model_name = args.model if args.rm_model is None else args.rm_model
    actor_classes = {"gpt2": GPTActor, "bloom": BLOOMActor, "opt": OPTActor, "llama": LlamaActor}
    reward_classes = {"gpt2": GPTRM, "bloom": BLOOMRM, "opt": OPTRM, "llama": LlamaRM}
    critic_classes = {"gpt2": GPTCritic, "bloom": BLOOMCritic, "opt": OPTCritic, "llama": LlamaCritic}

    with strategy.model_init_context():
        # initial_model is built without lora_rank, unlike the trainable actor below.
        initial_model = _lookup_model_class(actor_classes, args.model, "actor")(pretrained=args.pretrain)
        reward_model = _lookup_model_class(reward_classes, rm_model_name, "reward")(
            pretrained=args.rm_pretrain, lora_rank=args.lora_rank
        )
        if args.rm_path is not None:
            reward_model.load_state_dict(state_dict, strict=False)

        initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
        reward_model.to(torch.bfloat16).to(torch.cuda.current_device())

        actor = _lookup_model_class(actor_classes, args.model, "actor")(
            pretrained=args.pretrain, lora_rank=args.lora_rank
        )
        # The critic starts from the same pretrained checkpoint / state dict as the reward model.
        critic = _lookup_model_class(critic_classes, rm_model_name, "reward")(
            pretrained=args.rm_pretrain, lora_rank=args.lora_rank
        )
        if args.rm_path is not None:
            critic.load_state_dict(state_dict, strict=False)
            # Both consumers have loaded it; drop the CPU copy.
            del state_dict

        actor.to(torch.bfloat16).to(torch.cuda.current_device())
        critic.to(torch.bfloat16).to(torch.cuda.current_device())

    # ColossalAI strategies use ColossalAI's HybridAdam; the remaining strategy (ddp) uses torch Adam.
    optim_cls = HybridAdam if args.strategy.startswith("colossalai") else Adam
    actor_optim = optim_cls(actor.parameters(), lr=args.lr)
    critic_optim = optim_cls(critic.parameters(), lr=args.lr)

    tokenizer = _build_tokenizer(args.model, args.tokenizer)

    prompt_dataset = PromptDataset(
        tokenizer=tokenizer,
        data_path=args.prompt_dataset,
        max_datasets_size=args.max_datasets_size,
        max_length=args.max_input_len,
    )
    prompt_dataloader = _build_dataloader(prompt_dataset, args.experience_batch_size)

    pretrain_dataset = SupervisedDataset(
        tokenizer=tokenizer,
        data_path=args.pretrain_dataset,
        max_datasets_size=args.max_datasets_size,
        max_length=args.max_input_len,
    )
    pretrain_dataloader = _build_dataloader(pretrain_dataset, args.ptx_batch_size)

    # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model
    )

    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        tokenizer=tokenizer,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        train_batch_size=args.train_batch_size,
        max_length=args.max_seq_len,
        use_cache=True,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        offload_inference_models=args.strategy != "colossalai_gemini",
    )

    trainer.fit(
        num_episodes=args.num_episodes,
        num_collect_steps=args.num_collect_steps,
        num_update_steps=args.num_update_steps,
        prompt_dataloader=prompt_dataloader,
        pretrain_dataloader=pretrain_dataloader,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    if args.lora_rank > 0 and args.merge_lora_weights:
        from coati.models.lora import LORA_MANAGER

        # NOTE: set model to eval to merge LoRA weights
        LORA_MANAGER.merge_weights = True
        actor.eval()

    # save model checkpoint after fitting
    strategy.save_model(actor, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks (one file per CUDA device index)
    if args.need_optim_ckpt:
        strategy.save_optimizer(
            actor_optim, f"actor_optim_checkpoint_prompts_{torch.cuda.current_device()}.pt", only_rank0=False
        )


def str2bool(value):
    """Convert a command-line string to a bool, for use as an argparse ``type=``.

    ``type=bool`` makes argparse call ``bool("False")``, which is ``True`` for every
    non-empty string. A flag declared that way can therefore never be switched off.
    This converter accepts the usual spellings instead.

    Args:
        value: raw argument text, or an already-boolean value.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
    parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
    parser.add_argument("--max_datasets_size", type=int, default=50000)
    parser.add_argument(
        "--strategy",
        choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
        default="colossalai_zero2",
        help="strategy to use",
    )
    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--rm_path", type=str, default=None)
    parser.add_argument("--rm_pretrain", type=str, default=None)
    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
    # Accepts "--need_optim_ckpt", "--need_optim_ckpt True" or "--need_optim_ckpt False".
    parser.add_argument("--need_optim_ckpt", type=str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--num_episodes", type=int, default=10)
    parser.add_argument("--num_collect_steps", type=int, default=10)
    parser.add_argument("--num_update_steps", type=int, default=5)
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--ptx_batch_size", type=int, default=1)
    parser.add_argument("--experience_batch_size", type=int, default=8)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    # Accepts "--merge_lora_weights", "--merge_lora_weights True" or "--merge_lora_weights False".
    parser.add_argument("--merge_lora_weights", type=str2bool, nargs="?", const=True, default=True)
    parser.add_argument("--lr", type=float, default=1e-7)
    parser.add_argument("--kl_coef", type=float, default=0.1)
    parser.add_argument("--ptx_coef", type=float, default=0.9)
    parser.add_argument("--max_input_len", type=int, default=96)
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--log_dir", default="logs", type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    args = parser.parse_args()
    main(args)