train_sft.py
import argparse
import math
import warnings

import torch
import torch.distributed as dist
from coati.dataset import SFTDataset, SupervisedDataset
from coati.models.bloom import BLOOMActor
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam


def train(args):
    # configure strategy
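    # "ddp" is vanilla PyTorch DDP; the "colossalai_*" strategies shard optimizer
    # state (ZeRO-2, optionally CPU-offloaded) or manage parameter placement
    # dynamically (Gemini) to cut per-GPU memory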
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="auto")
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    elif args.strategy == "colossalai_zero2_cpu":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
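    # LoRA trains low-rank adapters on top of frozen base weights; this script
    # does not combine it with gradient checkpointing, so the flag is reset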
    if args.lora_rank > 0:
        warnings.warn("Gradient checkpointing is disabled when using LoRA")
        args.grad_checkpoint = False
    with strategy.model_init_context():
        if args.model == "bloom":
            model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
        elif args.model == "opt":
            model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
        elif args.model == "gpt2":
            model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
        elif args.model == "llama":
            model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
        elif args.model == "chatglm":
            model = ChatGLMActor(pretrained=args.pretrain)
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

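        # train in bfloat16 and move the model onto this rank's GPU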
        model.to(torch.bfloat16).to(torch.cuda.current_device())

    # configure tokenizer
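    # most of these models ship without a pad token, so one is borrowed from an
    # existing special token (EOS, or UNK for llama) to enable batched padding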
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained(
            "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
        )
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    elif args.model == "chatglm":
        tokenizer = ChatGLMTokenizer.from_pretrained(
            "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True
        )
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure optimizer
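    # HybridAdam is ColossalAI's fused Adam that can also update CPU-offloaded
    # parameters; plain torch Adam suffices for the DDP strategy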
    if args.strategy.startswith("colossalai"):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
    else:
        optim = Adam(model.parameters(), lr=args.lr)

    # configure dataset
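    # two data paths: stream the self-instruct dataset from the HF Hub, or load
    # a local instruction-following dataset (e.g. a JSON file) via SupervisedDataset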
    if args.dataset == "yizhongw/self_instruct":
        train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
        eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")

        if args.max_datasets_size is not None:
            train_data = train_data.select(range(min(args.max_datasets_size, len(train_data))))
            eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data))))

        train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
        eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
    else:
        train_dataset = SupervisedDataset(
            tokenizer=tokenizer,
            data_path=args.dataset,
            max_datasets_size=args.max_datasets_size,
            max_length=args.max_len,
        )
        eval_dataset = None

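    # under multi-GPU training, give each rank a disjoint shard of the data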
    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(
            train_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        if eval_dataset is not None:
            eval_sampler = DistributedSampler(
                eval_dataset,
                shuffle=False,
                seed=42,
                drop_last=False,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
            )
    else:
        train_sampler = None
        eval_sampler = None

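    # shuffle in the DataLoader only when no sampler is set;
    # DistributedSampler handles shuffling itself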
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        batch_size=args.batch_size,
        pin_memory=True,
    )
    if eval_dataset is not None:
        eval_dataloader = DataLoader(
            eval_dataset,
            shuffle=(eval_sampler is None),
            sampler=eval_sampler,
            batch_size=args.batch_size,
            pin_memory=True,
        )
    else:
        eval_dataloader = None

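    # one optimizer step per `accumulation_steps` micro-batches; warm up over the
    # first ~3% of steps, then follow a cosine decay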
    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
    lr_scheduler = get_scheduler(
        "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps
    )
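    # hand the model/optimizer/scheduler to the strategy so it can apply its
    # wrapping (DDP buckets, ZeRO sharding, Gemini placement)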
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model = strategy_dict["model"]
    optim = strategy_dict["optimizer"]
    lr_scheduler = strategy_dict["lr_scheduler"]
    trainer = SFTTrainer(
        model=model,
        strategy=strategy,
        optim=optim,
        lr_scheduler=lr_scheduler,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
    )

    logger = get_dist_logger()
    trainer.fit(
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        logger=logger,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    # save the model checkpoint after fitting (rank 0 only)
    strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(
            trainer.optimizer, "sft_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--strategy",
        choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"],
        default="colossalai_zero2",
    )
    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--max_datasets_size", type=int, default=None)
    parser.add_argument("--save_path", type=str, default="output")
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument("--max_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_len", type=int, default=512)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--log_dir", default="logs", type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    args = parser.parse_args()
    train(args)
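
# Example launch (illustrative only; the paths, dataset file and GPU count
# below are assumptions to adapt to your setup):
#   torchrun --standalone --nproc_per_node=4 train_sft.py \
#       --model llama --strategy colossalai_zero2 \
#       --pretrain /path/to/base_model --dataset /path/to/sft_data.json \
#       --save_path output --batch_size 4 --accumulation_steps 8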