"...diffusion/configs/Inference/v2-inpainting-inference.yaml" did not exist on "e99edfcb51df48dec17498c60bbbd06baa293c22"
train_peft_prompts.py 9.62 KB
Newer Older
1
2
3
4
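"""PPO training for a PEFT/LoRA BLOOM actor with Coati.

Loads an SFT LoRA actor, a reward model, and a critic, then runs PPO via
Coati's PPOTrainer under a DDP or ColossalAI (Gemini / ZeRO-2) strategy.

Illustrative launch (all paths are placeholders):
    torchrun --standalone --nproc_per_node 1 train_peft_prompts.py \
        --model bloom --strategy colossalai_zero2 \
        --pretrain <sft_model_dir> --sft_lora_path <sft_lora_dir> \
        --rm_pretrain <rm_base_dir> --rm_path <rm_checkpoint.pt> \
        --prompt_path <prompt_dataset> --pretrain_dataset <ptx_dataset>
"""
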
import argparse

import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from coati.models.gpt import GPTRM, GPTCritic
from coati.models.llama import LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer

from colossalai.nn.optimizer import HybridAdam


def main(args):
    # configure strategy
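    # DDP keeps a full model replica on every rank; the ColossalAI strategies
    # shard memory (Gemini chunks / ZeRO-2) and offload to CPU as configured below.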
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    if args.rm_path is not None:
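        # Keep the reward-model weights on CPU for now; they are loaded into
        # both the reward model and the critic below.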
        state_dict = torch.load(args.rm_path, map_location="cpu")

    # configure model
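    # Only BLOOM has a PEFT/LoRA actor in this script; initial_model is the
    # frozen reference policy that PPO's KL penalty compares the actor against.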
    if args.model == "bloom":
        # initial_model = BLOOMActor(pretrained=args.pretrain)
        print("Using peft lora to load Bloom model as initial_model")
        initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
        print("Using peft lora to load Bloom model as initial_model (Done)")
    else:
        raise ValueError(f'Unsupported actor model "{args.model}"')

    if args.rm_model is None:
        rm_model_name = args.model
    else:
        rm_model_name = args.rm_model

    if rm_model_name == "gpt2":
        reward_model = GPTRM(pretrained=args.rm_pretrain)
    elif rm_model_name == "bloom":
        print("load bloom reward model ", args.rm_pretrain)
        reward_model = BLOOMRM(pretrained=args.rm_pretrain)
    elif rm_model_name == "opt":
        reward_model = OPTRM(pretrained=args.rm_pretrain)
    elif rm_model_name == "llama":
        reward_model = LlamaRM(pretrained=args.rm_pretrain)
    else:
        raise ValueError(f'Unsupported reward model "{rm_model_name}"')

    if args.rm_path is not None:
        print("Loading reward model from", args.rm_path)
        reward_model.load_state_dict(state_dict)

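    # Gemini manages parameter placement on its own; cast to fp16 and move to
    # the GPU manually only for the other strategies.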
    if args.strategy != "colossalai_gemini":
        initial_model.to(torch.float16).to(torch.cuda.current_device())
        reward_model.to(torch.float16).to(torch.cuda.current_device())

    with strategy.model_init_context():
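        # Build the trainable actor and critic inside the strategy's init
        # context so sharded strategies can place parameters as they are created.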
        if args.model == "bloom":
            # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
            print("Using peft lora to load Bloom model as Actor")
            actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
            print("Using peft lora to load Bloom model as Actor (Done)")
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

        if rm_model_name == "gpt2":
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == "bloom":
            print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
            print("load bloom critic (Done) ")
        elif rm_model_name == "opt":
            critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == "llama":
            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

        if args.rm_path is not None:
            print("Loading reward model from", args.rm_path)
            critic.load_state_dict(state_dict)
            del state_dict

    if args.strategy != "colossalai_gemini":
        critic.to(torch.float16).to(torch.cuda.current_device())
        actor.to(torch.float16).to(torch.cuda.current_device())

    # configure optimizer
    if args.strategy.startswith("colossalai"):
        actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
        critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
    else:
        actor_optim = Adam(actor.parameters(), lr=1e-7)
        critic_optim = Adam(critic.parameters(), lr=1e-7)

    # configure tokenizer
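    # Note: gpt2/bloom/opt take the tokenizer from the reward-model checkpoint
    # (rm_pretrain), while llama takes it from the actor checkpoint (pretrain).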
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer)
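    # Shard the prompts across ranks only in multi-process runs; a single
    # process shuffles through the DataLoader instead.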
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        prompt_sampler = None
    prompt_dataloader = DataLoader(
        prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
    )

    pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
    if dist.is_initialized() and dist.get_world_size() > 1:
        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        pretrain_sampler = None
    pretrain_dataloader = DataLoader(
        pretrain_dataset,
        shuffle=(pretrain_sampler is None),
        sampler=pretrain_sampler,
        batch_size=args.ptx_batch_size,
        collate_fn=data_collator,
    )

    def tokenize_fn(texts):
        # Must pad to max length so inputs on all ranks have the same shape;
        # mismatched lengths can hang Gemini, since ranks would run different
        # numbers of generation steps.
        batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
        return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}

    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

    # configure trainer
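    # The sampling arguments (max_length, do_sample, temperature, top_k, ...)
    # are forwarded to the actor's generate() during experience collection.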
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        train_batch_size=args.train_batch_size,
        experience_batch_size=args.experience_batch_size,
        tokenizer=tokenize_fn,
        max_length=512,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

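    # Each episode alternates experience collection (num_collect_steps) with
    # PPO updates (num_update_steps).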
    trainer.fit(
        prompt_dataloader=prompt_dataloader,
        pretrain_dataloader=pretrain_dataloader,
        num_episodes=args.num_episodes,
        num_update_steps=args.num_update_steps,
        num_collect_steps=args.num_collect_steps,
    )

    # save model checkpoint after fitting
    trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
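        # Each rank writes its own optimizer checkpoint, suffixed with its CUDA
        # device index.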
        strategy.save_optimizer(
            actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
    parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
    parser.add_argument(
        "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
    )
    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--sft_lora_path", type=str, default=None)
    parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--rm_path", type=str, default=None)
    parser.add_argument("--rm_pretrain", type=str, default=None)
    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument("--num_episodes", type=int, default=10)
    parser.add_argument("--num_collect_steps", type=int, default=10)
    parser.add_argument("--num_update_steps", type=int, default=5)
    parser.add_argument("--train_batch_size", type=int, default=2)
    parser.add_argument("--ptx_batch_size", type=int, default=1)
    parser.add_argument("--experience_batch_size", type=int, default=8)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--kl_coef", type=float, default=0.1)
    parser.add_argument("--ptx_coef", type=float, default=0.9)
    args = parser.parse_args()
    main(args)