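"""Supervised fine-tuning (SFT) entry point for Coati actor models.

Example launch (illustrative only; the paths and flag values below are
assumptions, adjust them to your setup):
    torchrun --standalone --nproc_per_node=4 train_sft.py \
        --model bloom --strategy colossalai_zero2 \
        --pretrain bigscience/bloom-560m --dataset /path/to/data.json \
        --save_path output --batch_size 4 --max_epochs 3
"""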
import argparse
import math
import warnings

import torch
import torch.distributed as dist
from coati.dataset import SFTDataset, SupervisedDataset
from coati.models.bloom import BLOOMActor
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter


def train(args):
    # configure strategy
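    # DDP replicates the full model on every GPU; the ColossalAI strategies
    # shard training state to save memory (Gemini manages parameter placement
    # dynamically, ZeRO-2 shards gradients and optimizer states)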
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = GeminiStrategy(placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2_cpu':
        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    if args.lora_rank > 0:
        warnings.warn("Gradient checkpointing is disabled when using LoRA")
        args.grad_checkpoint = False
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMActor(pretrained=args.pretrain,
                               lora_rank=args.lora_rank,
                               checkpoint=args.grad_checkpoint)
        elif args.model == 'opt':
            model = OPTActor(pretrained=args.pretrain,
                             lora_rank=args.lora_rank,
                             checkpoint=args.grad_checkpoint)
        elif args.model == 'gpt2':
            model = GPTActor(pretrained=args.pretrain,
                             lora_rank=args.lora_rank,
                             checkpoint=args.grad_checkpoint)
        elif args.model == 'llama':
            model = LlamaActor(pretrained=args.pretrain,
                               lora_rank=args.lora_rank,
                               checkpoint=args.grad_checkpoint)
        elif args.model == 'chatglm':
            model = ChatGLMActor(pretrained=args.pretrain)
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

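        # keep weights in half precision to reduce memory and move them to
        # this rank's CUDA device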
        model.to(torch.float16).to(torch.cuda.current_device())

    # configure tokenizer
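    # most of these tokenizers ship without a pad token, so one is aliased
    # from an existing special token below to make batch padding work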
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained(
            'gpt2' if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(
            'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained(
            "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
        tokenizer.eos_token = '</s>'
        tokenizer.pad_token = tokenizer.unk_token
    elif args.model == 'chatglm':
        tokenizer = ChatGLMTokenizer.from_pretrained(
            "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    if args.model == 'llama' and args.strategy == 'colossalai_gemini':
        # this is a hack to deal with the resized embedding
        # to make sure all parameters are ColoParameter for Colossal-AI Gemini compatibility
        for name, param in model.named_parameters():
            if not isinstance(param, ColoParameter):
                sub_module_name = '.'.join(name.split('.')[:-1])
                weight_name = name.split('.')[-1]
                sub_module = model.get_submodule(sub_module_name)
                setattr(sub_module, weight_name, ColoParameter(param))

    # configure optimizer
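    # HybridAdam is ColossalAI's fused Adam that can update parameters on
    # both CPU and GPU, which the offloading strategies rely on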
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
    else:
        optim = Adam(model.parameters(), lr=args.lr)
    logger = get_dist_logger()

    # configure dataset
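    # 'yizhongw/self_instruct' is fetched from the Hugging Face hub; any other
    # value is passed to SupervisedDataset as a local data path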
    if args.dataset == 'yizhongw/self_instruct':
        train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
        eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')

        train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
        eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
    else:
        train_dataset = SupervisedDataset(tokenizer=tokenizer,
                                          data_path=args.dataset,
                                          max_datasets_size=args.max_datasets_size,
                                          max_length=args.max_len)
        eval_dataset = None

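    # in multi-process training each rank gets a disjoint shard of the data;
    # a fixed seed keeps the shuffle consistent across ranks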
    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           shuffle=True,
                                           seed=42,
                                           drop_last=True,
                                           rank=dist.get_rank(),
                                           num_replicas=dist.get_world_size())
        if eval_dataset is not None:
            eval_sampler = DistributedSampler(eval_dataset,
                                              shuffle=False,
                                              seed=42,
                                              drop_last=False,
                                              rank=dist.get_rank(),
                                              num_replicas=dist.get_world_size())
        else:
            eval_sampler = None
    else:
        train_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(train_dataset,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  pin_memory=True)
    if eval_dataset is not None:
        eval_dataloader = DataLoader(eval_dataset,
                                     shuffle=(eval_sampler is None),
                                     sampler=eval_sampler,
                                     batch_size=args.batch_size,
                                     pin_memory=True)
    else:
        eval_dataloader = None

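    # one optimizer step per accumulation_steps micro-batches; cosine decay
    # with ~3% of the total steps used for linear warmup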
    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
    lr_scheduler = get_scheduler("cosine",
                                 optim,
                                 num_warmup_steps=math.ceil(max_steps * 0.03),
                                 num_training_steps=max_steps)
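
    # strategy.prepare wraps the model, optimizer and scheduler for the
    # selected distributed strategy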
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model = strategy_dict['model']
    optim = strategy_dict['optimizer']
    lr_scheduler = strategy_dict['lr_scheduler']

    trainer = SFTTrainer(model=model,
                         strategy=strategy,
                         optim=optim,
                         lr_scheduler=lr_scheduler,
                         max_epochs=args.max_epochs,
                         accumulation_steps=args.accumulation_steps)

    trainer.fit(train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                logger=logger,
                use_wandb=args.use_wandb)

    # save the model checkpoint after fitting (rank 0 only)
    strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                'sft_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
                        default='colossalai_zero2')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default=None)
    parser.add_argument('--max_datasets_size', type=int, default=None)
    parser.add_argument('--save_path', type=str, default='output')
    parser.add_argument('--need_optim_ckpt', default=False, action='store_true')
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--max_len', type=int, default=512)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--accumulation_steps', type=int, default=8)
    parser.add_argument('--use_wandb', default=False, action='store_true')
    parser.add_argument('--grad_checkpoint', default=False, action='store_true')
    args = parser.parse_args()
    train(args)