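# train_peft_sft.py
# Supervised fine-tuning (SFT) of a causal LM with a PEFT/LoRA adapter on top of
# Coati's SFTTrainer, using either PyTorch DDP or a ColossalAI ZeRO/Gemini strategy.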
import argparse
import os

import torch
import torch.distributed as dist
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter


def train(args):
    # configure strategy
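    # 'ddp' is plain PyTorch DistributedDataParallel; 'colossalai_gemini' manages
    # parameters and optimizer states with ColossalAI's Gemini heterogeneous memory
    # manager; 'colossalai_zero2' shards optimizer states and gradients (ZeRO stage 2)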
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = GeminiStrategy(placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
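    # model_init_context() lets the chosen strategy hook parameter creation
    # (e.g. Gemini can place weights directly into its managed memory), so the
    # base model is constructed inside it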
    with strategy.model_init_context():
        print('Warning: currently only bloom has been tested; gpt2, llama and opt are untested.')
        model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
        # resume from a previously saved PEFT adapter if one exists at args.save_path
        if os.path.exists(args.save_path) and os.path.exists(os.path.join(args.save_path, 'adapter_config.json')) \
                and os.path.exists(os.path.join(args.save_path, 'adapter_model.bin')):
            print("loading from saved peft model ", args.save_path)
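            # note: with recent peft releases the adapter loads frozen by default;
            # pass is_trainable=True to PeftModel.from_pretrained if you want to keep
            # fine-tuning the loaded LoRA weights (depends on your peft version)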
            model = PeftModel.from_pretrained(model, args.save_path)
        else:
            # otherwise build a fresh LoRA adapter with the peft library
            lora_rank = args.lora_rank if args.lora_rank > 0 else 32
            # configure LoRA with rank lora_rank
            lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                                     inference_mode=False,
                                     r=lora_rank,
                                     lora_alpha=32,
                                     lora_dropout=0.1)
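            # get_peft_model freezes the base weights and injects trainable low-rank
            # adapter matrices into peft's default target modules for this architecture;
            # print_trainable_parameters reports the resulting trainable fraction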
            model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'llama':
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrain,
            padding_side="right",
            use_fast=False,
        )
        tokenizer.eos_token = '</s>'
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    if args.model == 'llama' and args.strategy == 'colossalai_gemini':
        # this is a hack to deal with the resized embedding
        # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
        for name, param in model.named_parameters():
            if not isinstance(param, ColoParameter):
                sub_module_name = '.'.join(name.split('.')[:-1])
                weight_name = name.split('.')[-1]
                sub_module = model.get_submodule(sub_module_name)
                setattr(sub_module, weight_name, ColoParameter(param))

    # configure optimizer
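    # HybridAdam is ColossalAI's fused Adam that can update parameters placed on
    # either GPU or CPU, as required by the Gemini/ZeRO strategies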
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
    else:
        optim = Adam(model.parameters(), lr=args.lr)

    logger = get_dist_logger()
    logger.set_level('WARNING')

    # configure dataset
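    # EasyDataset is the local helper module in this example; unless --is_short_text
    # is set, it is expected to concatenate and chunk samples into fixed-length
    # blocks (is_group_texts=True) for causal-LM training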
    law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    train_dataset = law_dataset
    print(train_dataset)
    eval_dataset = None
    if args.eval_dataset is not None:
        eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    data_collator = default_collate
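    # with more than one rank, give each process a disjoint shard of the data;
    # shuffling is then handled by the sampler, so DataLoader-level shuffle is only
    # enabled when no sampler is used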
    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           shuffle=True,
                                           seed=42,
                                           drop_last=True,
                                           rank=dist.get_rank(),
                                           num_replicas=dist.get_world_size())
        if eval_dataset is not None:
            eval_sampler = DistributedSampler(eval_dataset,
                                              shuffle=False,
                                              seed=42,
                                              drop_last=False,
                                              rank=dist.get_rank(),
                                              num_replicas=dist.get_world_size())
    else:
        train_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(train_dataset,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=data_collator,
                                  pin_memory=True)
    if eval_dataset is not None:
        eval_dataloader = DataLoader(eval_dataset,
                                     shuffle=(eval_sampler is None),
                                     sampler=eval_sampler,
                                     batch_size=args.batch_size,
                                     collate_fn=data_collator,
                                     pin_memory=True)
    else:
        eval_dataloader = None

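    # the trainer hands model/optimizer/dataloaders to the strategy and runs the
    # supervised fine-tuning loop, stepping the optimizer every accumulation_steps batches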
    trainer = SFTTrainer(model=model,
                         strategy=strategy,
                         optim=optim,
                         train_dataloader=train_dataloader,
                         eval_dataloader=eval_dataloader,
                         batch_size=args.batch_size,
                         max_epochs=args.max_epochs,
                         accumulation_steps=args.accumulation_steps)

    trainer.fit(logger=logger, log_interval=args.log_interval)

    # save model checkpoint after fitting on only rank0
    trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                'sft_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='ddp')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default=None)
    parser.add_argument('--eval_dataset', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='output')
    parser.add_argument('--need_optim_ckpt', action='store_true', default=False)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--accumulation_steps', type=int, default=8)
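    # effective global batch size = batch_size * accumulation_steps * number of ranks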
    parser.add_argument('--enable_peft_lora', action='store_true', default=False)
    parser.add_argument("--is_short_text", action='store_true', default=False)
    args = parser.parse_args()
    train(args)