import argparse

import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer

from colossalai.nn.optimizer import HybridAdam

def main(args):
    """Run PPO training (RLHF stage 3) of an actor model on a prompt dataset.

    Builds the actor, critic, reward and initial (frozen reference) models,
    wraps them in the selected distributed strategy, fits with ``PPOTrainer``,
    then saves the actor (and optionally the actor optimizer) to disk.

    Args:
        args: parsed command-line arguments (see the ``__main__`` parser).

    Raises:
        ValueError: if ``args.strategy``, ``args.model`` or the resolved
            reward-model architecture is not one of the supported choices.
    """
    # configure strategy
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # Load the reward-model checkpoint once on CPU; the same state dict is
    # reused below for both the reward model and the critic.
    if args.rm_path is not None:
        state_dict = torch.load(args.rm_path, map_location='cpu')

    with strategy.model_init_context():
        # configure model
        if args.model == 'gpt2':
            initial_model = GPTActor(pretrained=args.pretrain)
        elif args.model == 'bloom':
            initial_model = BLOOMActor(pretrained=args.pretrain)
        elif args.model == 'opt':
            initial_model = OPTActor(pretrained=args.pretrain)
        elif args.model == 'llama':
            initial_model = LlamaActor(pretrained=args.pretrain)
        elif args.model == 'roberta':
            initial_model = RoBERTaActor(pretrained=args.pretrain)
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

        # The reward/critic architecture defaults to the actor architecture
        # unless --rm_model is given explicitly.
        if args.rm_model is None:
            rm_model_name = args.model
        else:
            rm_model_name = args.rm_model

        if rm_model_name == 'gpt2':
            reward_model = GPTRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'bloom':
            reward_model = BLOOMRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'opt':
            reward_model = OPTRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'llama':
            reward_model = LlamaRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'roberta':
            reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

        if args.rm_path is not None:
            reward_model.load_state_dict(state_dict)

        # The frozen models only run inference, so fp16 is safe for them.
        initial_model.to(torch.float16).to(torch.cuda.current_device())
        reward_model.to(torch.float16).to(torch.cuda.current_device())

        if args.model == 'gpt2':
            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'bloom':
            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'opt':
            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'llama':
            actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        elif args.model == 'roberta':
            actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

        if rm_model_name == 'gpt2':
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'bloom':
            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'opt':
            critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'llama':
            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'roberta':
            critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

        if args.rm_path is not None:
            critic.load_state_dict(state_dict)
            del state_dict    # free the CPU copy once both models are loaded

    # Gemini manages device placement and precision itself; for the other
    # strategies the trainable models are moved/cast manually.
    if args.strategy != 'colossalai_gemini':
        critic.to(torch.float16).to(torch.cuda.current_device())
        actor.to(torch.float16).to(torch.cuda.current_device())

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
        critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
    else:
        actor_optim = Adam(actor.parameters(), lr=1e-7)
        critic_optim = Adam(critic.parameters(), lr=1e-7)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
        # BUGFIX: the original assigned '<\s>' (a literal backslash-s string,
        # and an invalid escape sequence); the LLaMA sentencepiece EOS token
        # is '</s>'.
        tokenizer.eos_token = '</s>'
    elif args.model == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    if args.model == 'llama':
        # Adds/resizes special tokens and the actor's embedding accordingly.
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor)
    else:
        tokenizer.pad_token = tokenizer.eos_token

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        prompt_sampler = None
    prompt_dataloader = DataLoader(prompt_dataset,
                                   shuffle=(prompt_sampler is None),
                                   sampler=prompt_sampler,
                                   batch_size=args.experience_batch_size)

    pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
                                         data_path=args.pretrain_dataset,
                                         max_datasets_size=16384,
                                         max_length=args.max_input_len)
    if dist.is_initialized() and dist.get_world_size() > 1:
        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        pretrain_sampler = None
    pretrain_dataloader = DataLoader(pretrain_dataset,
                                     shuffle=(pretrain_sampler is None),
                                     sampler=pretrain_sampler,
                                     batch_size=args.ptx_batch_size,
                                     collate_fn=data_collator)

    # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
        strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

    # configure trainer
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        train_batch_size=args.train_batch_size,
        max_length=args.max_seq_len,
        use_cache=True,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # Gemini keeps the inference models resident; otherwise offload them
        # between experience-collection phases to save GPU memory.
        offload_inference_models=args.strategy != 'colossalai_gemini'
    )

    trainer.fit(prompt_dataloader=prompt_dataloader,
                pretrain_dataloader=pretrain_dataloader,
                num_episodes=args.num_episodes,
                num_collect_steps=args.num_collect_steps,
                num_update_steps=args.num_update_steps)

    # save model checkpoint after fitting
    strategy.save_model(actor, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(actor_optim,
                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset')
    parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
    parser.add_argument('--strategy',
                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='colossalai_zero2',
                        help='strategy to use')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
    parser.add_argument('--rm_path', type=str, default=None)
    parser.add_argument('--rm_pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
    # BUGFIX: the original used `type=bool`, which turns ANY non-empty string
    # (including "False") into True. This converter keeps `--need_optim_ckpt
    # True` working while making "False"/"0"/"no" actually parse as False.
    parser.add_argument('--need_optim_ckpt',
                        type=lambda s: s.lower() in ('true', '1', 'yes'),
                        default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--num_collect_steps', type=int, default=10)
    parser.add_argument('--num_update_steps', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--ptx_batch_size', type=int, default=1)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--kl_coef', type=float, default=0.1)
    parser.add_argument('--ptx_coef', type=float, default=0.9)
    parser.add_argument('--max_input_len', type=int, default=96)
    parser.add_argument('--max_seq_len', type=int, default=128)
    args = parser.parse_args()
    main(args)