train_prompts.py 10.3 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
import argparse
2
import warnings
3

Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
4
5
import torch
import torch.distributed as dist
6
from coati.dataset import PromptDataset, SupervisedDataset
7
8
9
10
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
11
from coati.trainer import PPOTrainer
12
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
13
14
15
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
16
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
17
18
19
20
21
22

from colossalai.nn.optimizer import HybridAdam


def main(args):
    """Build the models, data loaders and ``PPOTrainer`` from CLI options, then train.

    Steps, in order:
      1. Select the training strategy named by ``args.strategy``.
      2. Optionally load a reward-model state dict from ``args.rm_path``.
      3. Inside ``strategy.model_init_context()`` create the initial model,
         reward model, actor and critic. The initial and reward models are
         cast to fp16 and moved to the current CUDA device right away; the
         actor and critic are moved afterwards unless the strategy is
         ``colossalai_gemini``.
      4. Create optimizers, the tokenizer, and the prompt / pretrain loaders.
      5. Pass everything through ``strategy.prepare`` and run ``PPOTrainer.fit``.
      6. Save the actor via ``strategy.save_model`` and, if requested, the
         actor optimizer state on every rank.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``strategy``,
            ``model``, ``rm_model``, ``rm_path``, ``rm_pretrain``, ``pretrain``,
            ``tokenizer``, ``lora_rank``, ``prompt_dataset``,
            ``pretrain_dataset``, ``experience_batch_size``, ``ptx_batch_size``,
            ``max_input_len``, ``max_seq_len``, ``kl_coef``, ``ptx_coef``,
            ``train_batch_size``, ``num_episodes``, ``num_collect_steps``,
            ``num_update_steps``, ``save_path`` and ``need_optim_ckpt``.

    Raises:
        ValueError: If the strategy, actor model or reward model name is not
            one of the supported values.
    """
    # configure strategy
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    if args.rm_path is not None:
        warnings.warn('LoRA weights should be merged with the model weights')
        # NOTE(security): torch.load unpickles the file, so --rm_path must only
        # point at checkpoints from a trusted source.
        state_dict = torch.load(args.rm_path, map_location='cpu')

    # (actor, critic, reward model) classes for each supported architecture.
    model_classes = {
        'gpt2': (GPTActor, GPTCritic, GPTRM),
        'bloom': (BLOOMActor, BLOOMCritic, BLOOMRM),
        'opt': (OPTActor, OPTCritic, OPTRM),
        'llama': (LlamaActor, LlamaCritic, LlamaRM),
    }

    with strategy.model_init_context():
        # configure model
        if args.model not in model_classes:
            raise ValueError(f'Unsupported actor model "{args.model}"')
        actor_cls = model_classes[args.model][0]
        initial_model = actor_cls(pretrained=args.pretrain)

        # The reward model / critic architecture defaults to the actor's.
        rm_model_name = args.model if args.rm_model is None else args.rm_model
        if rm_model_name not in model_classes:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')
        _, critic_cls, reward_cls = model_classes[rm_model_name]

        reward_model = reward_cls(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
        if args.rm_path is not None:
            reward_model.load_state_dict(state_dict, strict=False)

        initial_model.to(torch.float16).to(torch.cuda.current_device())
        reward_model.to(torch.float16).to(torch.cuda.current_device())

        actor = actor_cls(pretrained=args.pretrain, lora_rank=args.lora_rank)
        critic = critic_cls(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)

        if args.rm_path is not None:
            # The critic is initialised from the same checkpoint as the reward model.
            critic.load_state_dict(state_dict, strict=False)
            del state_dict

    if args.strategy != 'colossalai_gemini':
        critic.to(torch.float16).to(torch.cuda.current_device())
        actor.to(torch.float16).to(torch.cuda.current_device())

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
        critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
    else:
        actor_optim = Adam(actor.parameters(), lr=1e-7)
        critic_optim = Adam(critic.parameters(), lr=1e-7)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained(
            'gpt2' if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(
            'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained(
            "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
        # LLaMA's end-of-sequence token is '</s>'. The previous value '<\s>'
        # (backslash) is not that token, so eos_token_id handed to the trainer
        # below did not identify the real EOS.
        tokenizer.eos_token = '</s>'
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        prompt_sampler = None
    # DataLoader does not allow shuffle=True together with a sampler, so
    # shuffling is only enabled when no distributed sampler is used.
    prompt_dataloader = DataLoader(prompt_dataset,
                                   shuffle=(prompt_sampler is None),
                                   sampler=prompt_sampler,
                                   batch_size=args.experience_batch_size)

    pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
                                         data_path=args.pretrain_dataset,
                                         max_datasets_size=16384,
                                         max_length=args.max_input_len)
    if dist.is_initialized() and dist.get_world_size() > 1:
        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        pretrain_sampler = None
    pretrain_dataloader = DataLoader(pretrain_dataset,
                                     shuffle=(pretrain_sampler is None),
                                     sampler=pretrain_sampler,
                                     batch_size=args.ptx_batch_size)

    # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
        strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

    # configure trainer
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        train_batch_size=args.train_batch_size,
        max_length=args.max_seq_len,
        use_cache=True,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        offload_inference_models=args.strategy != 'colossalai_gemini'
    )

    trainer.fit(prompt_dataloader=prompt_dataloader,
                pretrain_dataloader=pretrain_dataloader,
                num_episodes=args.num_episodes,
                num_collect_steps=args.num_collect_steps,
                num_update_steps=args.num_update_steps)

    # save model checkpoint after fitting
    strategy.save_model(actor, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(actor_optim,
                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as ``"true"``,
            ``"False"``, ``"1"``, ``"no"``; matching is case-insensitive and
            ignores surrounding whitespace. The empty string maps to ``False``,
            which is what ``bool("")`` gave before.

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset')
    parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
    parser.add_argument('--strategy',
                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='colossalai_zero2',
                        help='strategy to use')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--rm_path', type=str, default=None)
    parser.add_argument('--rm_pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
    # Accepts "--need_optim_ckpt True/False" as before (now parsed correctly)
    # and also the bare flag "--need_optim_ckpt", which means True.
    parser.add_argument('--need_optim_ckpt',
                        type=str2bool,
                        nargs='?',
                        const=True,
                        default=False,
                        help='also save the actor optimizer checkpoint on every rank')
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--num_collect_steps', type=int, default=10)
    parser.add_argument('--num_update_steps', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--ptx_batch_size', type=int, default=1)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--kl_coef', type=float, default=0.1)
    parser.add_argument('--ptx_coef', type=float, default=0.9)
    parser.add_argument('--max_input_len', type=int, default=96)
    parser.add_argument('--max_seq_len', type=int, default=128)
    args = parser.parse_args()
    main(args)