train_prompts.py

import argparse
import warnings

import torch
import torch.distributed as dist
from coati.dataset import PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer

from colossalai.nn.optimizer import HybridAdam


def main(args):
    # configure strategy
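    # DDPStrategy is plain PyTorch data parallelism; GeminiStrategy manages and
    # offloads parameters across CPU/GPU memory, and LowLevelZeroStrategy applies
    # ZeRO stage 2 (optimizer-state and gradient sharding).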
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    if args.rm_path is not None:
        warnings.warn("LoRA weights should be merged with the model weights")
        state_dict = torch.load(args.rm_path, map_location="cpu")
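        # NOTE: this state dict is reused below: it is loaded into both the
        # reward model and the critic (with strict=False), then freed.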

    if args.lora_rank > 0:
        warnings.warn("LoRA is not supported yet.")
        args.lora_rank = 0

    with strategy.model_init_context():
        # configure model
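        # PPO needs four models: a frozen initial (reference) copy of the actor
        # for the KL penalty, a reward model to score full responses, and the
        # trainable actor (policy) and critic (value function) built below.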
        if args.model == "gpt2":
            initial_model = GPTActor(pretrained=args.pretrain)
        elif args.model == "bloom":
            initial_model = BLOOMActor(pretrained=args.pretrain)
        elif args.model == "opt":
            initial_model = OPTActor(pretrained=args.pretrain)
        elif args.model == "llama":
            initial_model = LlamaActor(pretrained=args.pretrain)
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

        if args.rm_model is None:
            rm_model_name = args.model
        else:
            rm_model_name = args.rm_model

        if rm_model_name == "gpt2":
            reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        elif rm_model_name == "bloom":
            reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        elif rm_model_name == "opt":
            reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        elif rm_model_name == "llama":
            reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

        if args.rm_path is not None:
            reward_model.load_state_dict(state_dict, strict=False)

        initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
        reward_model.to(torch.bfloat16).to(torch.cuda.current_device())

        if args.model == "gpt2":
            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == "bloom":
            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == "opt":
            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == "llama":
            actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

        if rm_model_name == "gpt2":
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        elif rm_model_name == "bloom":
            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        elif rm_model_name == "opt":
            critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        elif rm_model_name == "llama":
            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        else:
            raise ValueError(f'Unsupported critic model "{rm_model_name}"')

        if args.rm_path is not None:
            critic.load_state_dict(state_dict, strict=False)
            del state_dict

        actor.to(torch.bfloat16).to(torch.cuda.current_device())
        critic.to(torch.bfloat16).to(torch.cuda.current_device())

    # configure optimizer
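    # HybridAdam is ColossalAI's fused Adam, able to update parameters resident
    # on either CPU or GPU, which Gemini/ZeRO offloading relies on.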
    if args.strategy.startswith("colossalai"):
        actor_optim = HybridAdam(actor.parameters(), lr=args.lr)
        critic_optim = HybridAdam(critic.parameters(), lr=args.lr)
    else:
        actor_optim = Adam(actor.parameters(), lr=args.lr)
        critic_optim = Adam(critic.parameters(), lr=args.lr)

    # configure tokenizer
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained(
            "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
        )
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    # NOTE: generate() requires padding_side to be "left"
    tokenizer.padding_side = "left"

    prompt_dataset = PromptDataset(
        tokenizer=tokenizer,
        data_path=args.prompt_dataset,
        max_datasets_size=args.max_datasets_size,
        max_length=args.max_input_len,
    )
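    # When running distributed, shard the prompts across ranks; drop_last keeps
    # the number of batches identical on every rank.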
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        prompt_sampler = None
    prompt_dataloader = DataLoader(
        prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size
    )

    pretrain_dataset = SupervisedDataset(
        tokenizer=tokenizer,
        data_path=args.pretrain_dataset,
        max_datasets_size=args.max_datasets_size,
        max_length=args.max_input_len,
    )
    if dist.is_initialized() and dist.get_world_size() > 1:
        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        pretrain_sampler = None
    pretrain_dataloader = DataLoader(
        pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size
    )

    # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
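    # strategy.prepare wraps each model (and each model-optimizer pair) for the
    # chosen strategy, e.g. sharding parameters and optimizer states for ZeRO-2.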
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model
    )

    # configure trainer
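    # kl_coef scales the KL penalty that keeps the actor close to the frozen
    # initial model; ptx_coef weights the auxiliary pretraining (PTX) language
    # modeling loss computed on pretrain_dataloader, as in InstructGPT.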
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        tokenizer=tokenizer,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        train_batch_size=args.train_batch_size,
        max_length=args.max_seq_len,
        use_cache=True,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        offload_inference_models=args.strategy != "colossalai_gemini",
    )

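    # Each episode alternates experience collection (num_collect_steps rollouts
    # with the current actor) with num_update_steps of PPO optimization on the
    # collected experience.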
    trainer.fit(
        num_episodes=args.num_episodes,
        num_collect_steps=args.num_collect_steps,
        num_update_steps=args.num_update_steps,
        prompt_dataloader=prompt_dataloader,
        pretrain_dataloader=pretrain_dataloader,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    if args.lora_rank > 0 and args.merge_lora_weights:
        from coati.models.lora import LORA_MANAGER

        # NOTE: set model to eval to merge LoRA weights
        LORA_MANAGER.merge_weights = True
        actor.eval()
    # save model checkpoint after fitting
    strategy.save_pretrained(actor, path=args.save_path)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(
            actor_optim, f"actor_optim_checkpoint_prompts_{torch.cuda.current_device()}.pt", only_rank0=False
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
    parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
    parser.add_argument("--max_datasets_size", type=int, default=50000)
    parser.add_argument(
        "--strategy",
        choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
        default="colossalai_zero2",
        help="strategy to use",
    )
    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--rm_path", type=str, default=None)
    parser.add_argument("--rm_pretrain", type=str, default=None)
    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument("--num_episodes", type=int, default=10)
    parser.add_argument("--num_collect_steps", type=int, default=10)
    parser.add_argument("--num_update_steps", type=int, default=5)
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--ptx_batch_size", type=int, default=1)
    parser.add_argument("--experience_batch_size", type=int, default=8)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--merge_lora_weights", type=bool, default=True)
    parser.add_argument("--lr", type=float, default=1e-7)
    parser.add_argument("--kl_coef", type=float, default=0.1)
    parser.add_argument("--ptx_coef", type=float, default=0.9)
    parser.add_argument("--max_input_len", type=int, default=96)
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--log_dir", default="logs", type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    args = parser.parse_args()
    main(args)
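
# Example launch (illustrative only, not part of the original script): adjust the
# GPU count, model choice, and dataset paths to your environment, e.g.
#
#   torchrun --standalone --nproc_per_node=4 train_prompts.py \
#       --strategy colossalai_zero2 \
#       --model gpt2 --pretrain gpt2 \
#       --prompt_dataset /path/to/prompt_data \
#       --pretrain_dataset /path/to/pretrain_data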