train_prompts.py 9.87 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
import argparse
2
import warnings
3

Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
4
5
import torch
import torch.distributed as dist
6
from coati.dataset import PromptDataset, SupervisedDataset
7
8
9
10
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
11
from coati.trainer import PPOTrainer
12
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
13
14
15
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
16
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
17
18
19
20
21
22

from colossalai.nn.optimizer import HybridAdam


def _build_strategy(strategy_name):
    """Instantiate the training strategy named by ``--strategy``.

    Raises:
        ValueError: if ``strategy_name`` is not a supported strategy.
    """
    if strategy_name == "ddp":
        return DDPStrategy()
    if strategy_name == "colossalai_gemini":
        return GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
    if strategy_name == "colossalai_zero2":
        return LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    raise ValueError(f'Unsupported strategy "{strategy_name}"')


def _resolve_model_classes(actor_name, rm_name):
    """Return the ``(actor_cls, critic_cls, reward_model_cls)`` for the requested families.

    Args:
        actor_name: model family of the actor / reference policy (``--model``).
        rm_name: model family of the reward model and critic.

    Raises:
        ValueError: if either family is unsupported.
    """
    # Tables are built here (not at module level) so importing this file never
    # touches the model classes.
    actor_classes = {"gpt2": GPTActor, "bloom": BLOOMActor, "opt": OPTActor, "llama": LlamaActor}
    critic_classes = {"gpt2": GPTCritic, "bloom": BLOOMCritic, "opt": OPTCritic, "llama": LlamaCritic}
    rm_classes = {"gpt2": GPTRM, "bloom": BLOOMRM, "opt": OPTRM, "llama": LlamaRM}
    if actor_name not in actor_classes:
        raise ValueError(f'Unsupported actor model "{actor_name}"')
    if rm_name not in rm_classes:
        raise ValueError(f'Unsupported reward model "{rm_name}"')
    return actor_classes[actor_name], critic_classes[rm_name], rm_classes[rm_name]


def _build_tokenizer(model_name, tokenizer_path):
    """Load the tokenizer for ``model_name`` and configure its pad/eos tokens.

    Args:
        model_name: model family (``--model``); selects the tokenizer class and default checkpoint.
        tokenizer_path: checkpoint to load instead of the family default, or ``None``.

    Raises:
        ValueError: if ``model_name`` is not a supported model family.
    """
    # family -> (tokenizer class, default checkpoint when --tokenizer is not given)
    specs = {
        "gpt2": (GPT2Tokenizer, "gpt2"),
        "bloom": (BloomTokenizerFast, "bigscience/bloom-560m"),
        "opt": (AutoTokenizer, "facebook/opt-350m"),
        "llama": (LlamaTokenizer, "hf-internal-testing/llama-tokenizer"),
    }
    if model_name not in specs:
        raise ValueError(f'Unsupported model "{model_name}"')
    tokenizer_cls, default_path = specs[model_name]
    tokenizer = tokenizer_cls.from_pretrained(default_path if tokenizer_path is None else tokenizer_path)
    if model_name == "llama":
        # Llama's end-of-sequence token is "</s>". The previous literal "<\s>" had the
        # slash reversed (and "\s" is not a valid Python escape), so it named a different token.
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        # Reuse EOS as the padding token.
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _build_dataloader(dataset, batch_size):
    """Wrap ``dataset`` in a DataLoader, sharding it with a DistributedSampler on >1 rank."""
    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(dataset, shuffle=True, seed=42, drop_last=True)
    else:
        sampler = None
    # DataLoader does not allow shuffle=True together with a sampler; the sampler shuffles instead.
    return DataLoader(dataset, shuffle=(sampler is None), sampler=sampler, batch_size=batch_size)


def main(args):
    """Fine-tune an actor/critic pair with PPO on a prompt dataset.

    Builds the strategy, the reference (initial) policy, the reward model, the actor
    and the critic. Runs ``PPOTrainer.fit`` on the prompt and pretraining dataloaders,
    then saves the actor and, optionally, the actor optimizer state.

    Args:
        args: parsed command-line namespace produced by this script's argument parser.

    Raises:
        ValueError: if the strategy or a model family is unsupported.
    """
    # configure strategy (first: model construction and preparation below go through it)
    strategy = _build_strategy(args.strategy)

    # The reward model and critic share one family, defaulting to the actor's family.
    rm_model_name = args.model if args.rm_model is None else args.rm_model
    actor_cls, critic_cls, rm_cls = _resolve_model_classes(args.model, rm_model_name)

    state_dict = None
    if args.rm_path is not None:
        warnings.warn("LoRA weights should be merged with the model weights")
        # Loaded once on CPU; applied to both the reward model and the critic below.
        state_dict = torch.load(args.rm_path, map_location="cpu")

    with strategy.model_init_context():
        # Reference policy (built without lora_rank) and reward model; neither gets an optimizer.
        initial_model = actor_cls(pretrained=args.pretrain)
        reward_model = rm_cls(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        if state_dict is not None:
            reward_model.load_state_dict(state_dict, strict=False)
        initial_model.to(torch.float16).to(torch.cuda.current_device())
        reward_model.to(torch.float16).to(torch.cuda.current_device())

        # Trainable actor and critic, both built with lora_rank=args.lora_rank.
        actor = actor_cls(pretrained=args.pretrain, lora_rank=args.lora_rank)
        critic = critic_cls(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        if state_dict is not None:
            # The critic starts from the same reward-model checkpoint.
            critic.load_state_dict(state_dict, strict=False)
            del state_dict  # release the CPU copy of the checkpoint

    if args.strategy != "colossalai_gemini":
        # Gemini is skipped here; presumably it handles placement itself — TODO confirm.
        critic.to(torch.float16).to(torch.cuda.current_device())
        actor.to(torch.float16).to(torch.cuda.current_device())

    # configure optimizer: ColossalAI strategies use HybridAdam, DDP uses torch Adam
    optim_cls = HybridAdam if args.strategy.startswith("colossalai") else Adam
    actor_optim = optim_cls(actor.parameters(), lr=1e-7)
    critic_optim = optim_cls(critic.parameters(), lr=1e-7)

    # configure tokenizer
    tokenizer = _build_tokenizer(args.model, args.tokenizer)

    prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
    prompt_dataloader = _build_dataloader(prompt_dataset, args.experience_batch_size)

    pretrain_dataset = SupervisedDataset(
        tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len
    )
    pretrain_dataloader = _build_dataloader(pretrain_dataset, args.ptx_batch_size)

    # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model
    )

    # configure trainer
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        train_batch_size=args.train_batch_size,
        max_length=args.max_seq_len,
        use_cache=True,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        offload_inference_models=args.strategy != "colossalai_gemini",
    )

    trainer.fit(
        prompt_dataloader=prompt_dataloader,
        pretrain_dataloader=pretrain_dataloader,
        num_episodes=args.num_episodes,
        num_collect_steps=args.num_collect_steps,
        num_update_steps=args.num_update_steps,
    )

    # save model checkpoint after fitting
    strategy.save_model(actor, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        # Name each file by global rank: the local CUDA device index repeats across
        # nodes, so multi-node runs sharing a filesystem would overwrite each other.
        rank = dist.get_rank() if dist.is_initialized() else torch.cuda.current_device()
        strategy.save_optimizer(actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % rank, only_rank0=False)


def parse_bool_flag(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True`` for any
    non-empty string, so that form could never switch an option off.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y", "on"):
        return True
    # "" stays False, matching what bool("") produced before.
    if lowered in ("false", "0", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Return the argument parser for this script's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
    parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
    parser.add_argument(
        "--strategy",
        choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
        default="colossalai_zero2",
        help="strategy to use",
    )
    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--rm_path", type=str, default=None)
    parser.add_argument("--rm_pretrain", type=str, default=None)
    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
    # Accepts a bare flag (-> True) as well as an explicit value, e.g. "--need_optim_ckpt False".
    parser.add_argument(
        "--need_optim_ckpt",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=False,
        help="also save the actor optimizer state on every rank (bare flag or true/false)",
    )
    parser.add_argument("--num_episodes", type=int, default=10)
    parser.add_argument("--num_collect_steps", type=int, default=10)
    parser.add_argument("--num_update_steps", type=int, default=5)
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--ptx_batch_size", type=int, default=1)
    parser.add_argument("--experience_batch_size", type=int, default=8)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--kl_coef", type=float, default=0.1)
    parser.add_argument("--ptx_coef", type=float, default=0.9)
    parser.add_argument("--max_input_len", type=int, default=96)
    parser.add_argument("--max_seq_len", type=int, default=128)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)