train_prompts.py
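# PPO training (RLHF stage 3): optimize an actor on a prompt dataset against a
# fixed reward model, with a frozen copy of the initial model anchoring the KL
# penalty and a supervised dataloader supplying the auxiliary PTX loss.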
import argparse

import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer

from colossalai.nn.optimizer import HybridAdam


def main(args):
    # configure the distributed training strategy (DDP, ColossalAI Gemini, or ColossalAI ZeRO-2)
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    if args.rm_path is not None:
        state_dict = torch.load(args.rm_path, map_location='cpu')
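        # the checkpoint is reused below to initialize both the reward model and the critic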

    with strategy.model_init_context():
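        # build all four PPO models inside the strategy's init context so that
        # strategies such as Gemini can manage parameter placement from the start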
        # frozen initial model: the reference policy for PPO's KL penalty
        if args.model == 'gpt2':
            initial_model = GPTActor(pretrained=args.pretrain)
        elif args.model == 'bloom':
            initial_model = BLOOMActor(pretrained=args.pretrain)
        elif args.model == 'opt':
            initial_model = OPTActor(pretrained=args.pretrain)
        elif args.model == 'llama':
            initial_model = LlamaActor(pretrained=args.pretrain)
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

        if args.rm_model is None:
            rm_model_name = args.model
        else:
            rm_model_name = args.rm_model

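        # reward model: scores complete responses; defaults to the actor's architecture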
        if rm_model_name == 'gpt2':
            reward_model = GPTRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'bloom':
            reward_model = BLOOMRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'opt':
            reward_model = OPTRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'llama':
            reward_model = LlamaRM(pretrained=args.rm_pretrain)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

        if args.rm_path is not None:
            reward_model.load_state_dict(state_dict)

        initial_model.to(torch.float16).to(torch.cuda.current_device())
        reward_model.to(torch.float16).to(torch.cuda.current_device())

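        # trainable actor; lora_rank > 0 swaps in low-rank adapters to cut memory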
        if args.model == 'gpt2':
            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'bloom':
            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'opt':
            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'llama':
            actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

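        # critic: estimates per-token values; built from the reward model's
        # architecture and (when --rm_path is given) its pretrained weights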
        if rm_model_name == 'gpt2':
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'bloom':
            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'opt':
            critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'llama':
            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

        if args.rm_path is not None:
            critic.load_state_dict(state_dict)
            del state_dict

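    # Gemini manages device placement itself; cast and move manually otherwise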
    if args.strategy != 'colossalai_gemini':
        critic.to(torch.float16).to(torch.cuda.current_device())
        actor.to(torch.float16).to(torch.cuda.current_device())

    # configure optimizer (HybridAdam supports hybrid CPU/GPU updates under ColossalAI strategies)
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
        critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
    else:
        actor_optim = Adam(actor.parameters(), lr=1e-7)
        critic_optim = Adam(critic.parameters(), lr=1e-7)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
        tokenizer.eos_token = '</s>'  # LLaMA's end-of-sequence token
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

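    # prompt data: responses are sampled from these prompts to build experience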
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        prompt_sampler = None
    prompt_dataloader = DataLoader(prompt_dataset,
                                   shuffle=(prompt_sampler is None),
                                   sampler=prompt_sampler,
                                   batch_size=args.experience_batch_size)

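    # SFT-style data reused during PPO for the pretraining (PTX) auxiliary loss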
    pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
                                         data_path=args.pretrain_dataset,
                                         max_datasets_size=16384,
                                         max_length=args.max_input_len)
    if dist.is_initialized() and dist.get_world_size() > 1:
        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        pretrain_sampler = None
    pretrain_dataloader = DataLoader(pretrain_dataset,
                                     shuffle=(pretrain_sampler is None),
                                     sampler=pretrain_sampler,
                                     batch_size=args.ptx_batch_size,
                                     collate_fn=data_collator)

    # NOTE: for small models like opt-1.3b, the reward model and initial model
    # do not need to be parallelized.
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
        strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

    # configure trainer; the sampling kwargs below control generation during experience collection
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        train_batch_size=args.train_batch_size,
        max_length=args.max_seq_len,
        use_cache=True,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # offload the frozen initial and reward models to CPU when idle;
        # Gemini manages placement itself, so skip offloading there
        offload_inference_models=(args.strategy != 'colossalai_gemini')
    )

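    # each episode collects num_collect_steps batches of experience, then runs
    # num_update_steps PPO updates over the collected buffer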
    trainer.fit(prompt_dataloader=prompt_dataloader,
                pretrain_dataloader=pretrain_dataloader,
                num_episodes=args.num_episodes,
                num_collect_steps=args.num_collect_steps,
                num_update_steps=args.num_update_steps)

    # save model checkpoint after fitting
    strategy.save_model(actor, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(actor_optim,
                                f'actor_optim_checkpoint_prompts_{torch.cuda.current_device()}.pt',
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset')
    parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretraining dataset (for the PTX loss)')
    parser.add_argument('--strategy',
                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='colossalai_zero2',
                        help='strategy to use')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--rm_path', type=str, default=None)
    parser.add_argument('--rm_pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
    parser.add_argument('--need_optim_ckpt', action='store_true', default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--num_collect_steps', type=int, default=10)
    parser.add_argument('--num_update_steps', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--ptx_batch_size', type=int, default=1)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--kl_coef', type=float, default=0.1)
    parser.add_argument('--ptx_coef', type=float, default=0.9)
    parser.add_argument('--max_input_len', type=int, default=96)
    parser.add_argument('--max_seq_len', type=int, default=128)
    args = parser.parse_args()
    main(args)