train_peft_sft.py
import argparse
import os

import torch
import torch.distributed as dist
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter
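
# Example launch (hypothetical paths and model name; assumes a standard torchrun setup):
#   torchrun --standalone --nproc_per_node=4 train_peft_sft.py \
#       --model bloom --pretrain bigscience/bloom-560m \
#       --dataset ./data/train.json --save_path ./output --strategy colossalai_zero2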


def train(args):
    # configure strategy
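    # ddp wraps vanilla PyTorch DistributedDataParallel; colossalai_gemini uses ColossalAI's
    # chunk-based Gemini memory management; colossalai_zero2 uses ZeRO stage-2 sharding of
    # optimizer states and gradients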
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="static")
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        print("Warning: currently only bloom is tested; gpt2, llama and opt are not tested")
        model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
        # if args.save_path already contains a saved adapter (adapter_config.json and
        # adapter_model.bin), resume from it instead of creating a fresh LoRA adapter
        if (
            os.path.exists(args.save_path)
            and os.path.exists(os.path.join(args.save_path, "adapter_config.json"))
            and os.path.exists(os.path.join(args.save_path, "adapter_model.bin"))
        ):
            print("loading from saved peft model ", args.save_path)
            model = PeftModel.from_pretrained(model, args.save_path)
        else:
            # otherwise wrap the base model with a new LoRA adapter via the peft library
            lora_rank = args.lora_rank if args.lora_rank > 0 else 32
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
            )
            model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # configure tokenizer
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrain,
            padding_side="right",
            use_fast=False,
        )
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    if args.model == "llama" and args.strategy == "colossalai_gemini":
        # this is a hack to deal with the resized embedding
        # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
        for name, param in model.named_parameters():
            if not isinstance(param, ColoParameter):
                sub_module_name = ".".join(name.split(".")[:-1])
                weight_name = name.split(".")[-1]
                sub_module = model.get_submodule(sub_module_name)
                setattr(sub_module, weight_name, ColoParameter(param))

    # configure optimizer
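    # HybridAdam is ColossalAI's Adam variant that can keep optimizer states on CPU and GPU as
    # required by the Gemini/ZeRO strategies; plain torch Adam is used for the DDP strategy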
    if args.strategy.startswith("colossalai"):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
    else:
        optim = Adam(model.parameters(), lr=args.lr)

    logger = get_dist_logger()
    logger.set_level("WARNING")

    # configure dataset
    train_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    print(train_dataset)
    eval_dataset = None
    if args.eval_dataset is not None:
        eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    data_collator = default_collate
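    # with more than one process, DistributedSampler shards the dataset so each rank trains on
    # a distinct subset; the fixed seed keeps the shuffled shards consistent across ranks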
    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(
            train_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        if eval_dataset is not None:
            eval_sampler = DistributedSampler(
                eval_dataset,
                shuffle=False,
                seed=42,
                drop_last=False,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
            )
    else:
        train_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        pin_memory=True,
    )
    if eval_dataset is not None:
        eval_dataloader = DataLoader(
            eval_dataset,
            shuffle=(eval_sampler is None),
            sampler=eval_sampler,
            batch_size=args.batch_size,
            collate_fn=data_collator,
            pin_memory=True,
        )
    else:
        eval_dataloader = None

    trainer = SFTTrainer(
        model=model,
        strategy=strategy,
        optim=optim,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        batch_size=args.batch_size,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
    )
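    # gradients are accumulated, so the effective global batch size is
    # batch_size * accumulation_steps * world_size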

    trainer.fit(logger=logger, log_interval=args.log_interval)

    # save model checkpoint after fitting on only rank0
    trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(
            trainer.optimizer, "sft_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--eval_dataset", type=str, default=None)
    parser.add_argument("--save_path", type=str, default="output")
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument("--max_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--enable_peft_lora", action="store_true", default=False)
    parser.add_argument("--is_short_text", action="store_true", default=False)
    args = parser.parse_args()
    train(args)
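
# A minimal sketch (assumption: peft's merge_and_unload API) for exporting the trained LoRA
# adapter as a standalone model once training has finished:
#   base = AutoModelForCausalLM.from_pretrained(args.pretrain)
#   merged = PeftModel.from_pretrained(base, args.save_path).merge_and_unload()
#   merged.save_pretrained("merged_model")  # hypothetical output directory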