import argparse
import os
import math
import time
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.load_sft_dataset import SFTDataset
from utils.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.utils.data import DataLoader
from torch.optim import AdamW
from accelerate.utils import set_seed
from accelerate import Accelerator
from torch.utils.tensorboard import SummaryWriter
from peft import LoraConfig, TaskType, get_peft_model
from monkey_patch_packing import monkey_patch_packing_for_model

'''
Train an LLM with Hugging Face Accelerate + DeepSpeed.
'''
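
# A minimal launch sketch (the process count, config file name, and flag values
# below are illustrative assumptions, not part of this repo):
#
#   accelerate launch --num_processes 8 --config_file accelerate_config.yaml train.py \
#       --pretrained_model_name_or_path deepseek-ai/deepseek-coder-6.7b-base \
#       --sft_data_dir train_20240127.json \
#       --mode sft --epochs 1 --lr 5e-6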

def parse_option():
    parser = argparse.ArgumentParser()
    
    # global args
    parser.add_argument('--per_device_train_batch_size', type = int, default = 4,
                        help = 'batch size per gpu device.')
    parser.add_argument('--block_size', type = int, default = 8192,
                        help = 'block size, i.e., the length of training sequences.')
    parser.add_argument('--seed', type = int, default = 42)
    parser.add_argument('--pretrained_model_name_or_path', type = str, default = "deepseek-ai/deepseek-coder-6.7b-base")
    parser.add_argument('--epochs', type = int, default = 1)
    parser.add_argument('--lr', type = float, default = 5e-5, help = "5e-5 for pre-training, 5e-6 for fine-tuning.")
    parser.add_argument('--ckpt_num', type = int, default = 20, help = "The number of checkpoints saved during training (uniformly spaced).")
    parser.add_argument('--tensorboard_log_dir', type = str, default = "./train_logs")
    parser.add_argument('--output_ckpt_dir', type = str, default = "./ckpts")
    parser.add_argument('--mode', type = str, default = "pre-train")

    # args for supervised fine-tuning
    parser.add_argument('--sft_data_dir', type = str, default = "train_20240127.json")
    
    # args for lora tuning
    parser.add_argument('--use_lora', action = 'store_true', help = "Whether to use LoRA to fine-tune the model")
    parser.add_argument('--target_modules', type = str, help = "The names of the modules to apply the adapter to")
    parser.add_argument('--r', type = int, help = "LoRA attention dimension (the `rank`)")
    parser.add_argument('--lora_alpha', type = int, help = "The alpha parameter for LoRA scaling")
    parser.add_argument('--lora_dropout', type = float, help = "The dropout probability for LoRA layers")
    
    opt = parser.parse_args()

    return opt

def sanity_check(input_ids, labels, tokenizer):
    assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}"

    # print each (token, label) pair so the label masking can be eyeballed before training
    print("Start Sanity Check -------->")
    for t, m in zip(input_ids, labels):
        decoded = tokenizer.decode([t])
        print("%20s: %6d -> %6d" % (repr(decoded), t, m))
    print("<-------- End Sanity Check")
def checkpoint_model(accelerator, model, tokenizer, output_ckpt_dir, last_global_step):
    '''
    Utility function for checkpointing only the model state dict (i.e., only model parameters)
    '''
    ckpt_path = os.path.join(output_ckpt_dir, "ckpt-{}".format(last_global_step))
    accelerator.print("checkpointing model state dict at {}".format(ckpt_path))
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        ckpt_path, 
        is_main_process = accelerator.is_main_process, 
        save_function = accelerator.save,
        state_dict = accelerator.get_state_dict(model),
        max_shard_size = "100GB"
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(ckpt_path)
    
    return

def train(opt):
    set_seed(opt.seed)

    writer = SummaryWriter(opt.tensorboard_log_dir)
    accelerator = Accelerator()
    print("accelerator.is_main_process:", accelerator.is_main_process)
    print("accelerator.device:", accelerator.device)

    total_batch_size = opt.per_device_train_batch_size * accelerator.num_processes * accelerator.gradient_accumulation_steps
    
    accelerator.print(opt)
    accelerator.print("tokens per batch:", total_batch_size * opt.block_size)
    accelerator.print("sequences per batch:", total_batch_size)
    accelerator.print("using LLM from:", opt.pretrained_model_name_or_path)

    # pack multiple training sequences into each block without cross-contamination
    # between them in attention (requires Flash Attention)
    monkey_patch_packing_for_model(opt.pretrained_model_name_or_path)

    tokenizer = AutoTokenizer.from_pretrained(opt.pretrained_model_name_or_path, trust_remote_code=True)
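    # many decoder-only checkpoints ship without a dedicated pad token; fall back to EOS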
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is None:
            raise ValueError("the tokenizer has neither a pad token nor an EOS token; please set a valid eos_token_id")
        tokenizer.pad_token_id = tokenizer.eos_token_id
    
    model = AutoModelForCausalLM.from_pretrained(
        opt.pretrained_model_name_or_path,
        torch_dtype = torch.bfloat16,
        trust_remote_code = True,
        attn_implementation = "flash_attention_2"
    )

    if opt.use_lora:
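        # e.g. --target_modules "q_proj,k_proj,v_proj,o_proj" for Llama-style attention
        # (an illustrative example; module names vary by architecture)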
        target_modules = [target_module.strip() for target_module in opt.target_modules.split(',')]
        accelerator.print("Lora target_modules:", target_modules)
        peft_config = LoraConfig(
            task_type = TaskType.CAUSAL_LM, 
            target_modules = target_modules, 
            r = opt.r, 
            lora_alpha = opt.lora_alpha, 
            lora_dropout = opt.lora_dropout
        )
        model = get_peft_model(model, peft_config)
        if accelerator.is_main_process:
            model.print_trainable_parameters()
    
    # enable gradient checkpointing to save GPU memory; this slows training by roughly 20-30%.
    # in addition, gradient checkpointing cannot be enabled when using DeepSpeed ZeRO-3
    model.gradient_checkpointing_enable()

    dataset = SFTDataset(opt.sft_data_dir, tokenizer, opt.block_size, opt.mode)
    if accelerator.is_main_process:
        sanity_check(dataset[0]["input_ids"], dataset[0]["labels"], tokenizer)
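    # shuffle each epoch; drop_last=True keeps the number of batches identical across processes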
    dataloader = DataLoader(dataset, batch_size = opt.per_device_train_batch_size, shuffle = True, drop_last = True)

    num_total_batches = opt.epochs * math.ceil(len(dataset) / total_batch_size) # total number of optimizer steps
    # guard against opt.ckpt_num exceeding the step count, which would make the modulo check divide by zero
    checkpointing_steps = max(1, num_total_batches // opt.ckpt_num)
    accelerator.print("checkpointing_steps:", checkpointing_steps)
    optimizer = AdamW(model.parameters(), lr = opt.lr, betas = (0.9, 0.95), eps = 1e-8, weight_decay = 0.1)

    num_warm_up_batches = int(num_total_batches * 0.05) # reserve 5% of total steps for warm-up
    lr_scheduler = LinearWarmupCosineAnnealingLR(
        optimizer = optimizer, 
        warmup_epochs = num_warm_up_batches * accelerator.num_processes,
        max_epochs = num_total_batches * accelerator.num_processes,
        warmup_start_lr = 0.0, 
        eta_min = 0.1 * opt.lr
    )

    optimizer, model, dataloader, lr_scheduler = accelerator.prepare(optimizer, model, dataloader, lr_scheduler)

    accumulation_loss = 0
    global_completed_steps = 0
    model.train()

    st = time.time()
    for epoch in range(opt.epochs):
        print("This is epoch:", epoch+1)
        for batch_idx, batch in enumerate(dataloader):
            accelerator.print(batch["input_ids"].shape)
            
            # `accelerator.accumulate(model)` sets the right `sync_gradients` state based on the number of steps recorded so far
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accumulation_loss += loss.detach().float()
                # when DeepSpeed is enabled, `accelerator.backward(loss)` handles optimizer.step(), optimizer.zero_grad(), and gradient accumulation internally;
                # see the `if self.is_gradient_accumulation_boundary():` branch in path-to-env/site-packages/deepspeed/runtime/engine.py
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            
            # `accelerator.sync_gradients` is True once an optimization step has been performed on a full `total_batch_size` worth of examples
            if accelerator.sync_gradients:
                global_completed_steps += 1
                accelerator.print("GPU 0, step {}, loss {}".format(global_completed_steps, accumulation_loss / accelerator.gradient_accumulation_steps))
                accelerator.print("GPU 0, step {}, lr state dict:".format(global_completed_steps), lr_scheduler.state_dict())
                accelerator.print(time.time()-st)
                st = time.time()

                writer.add_scalar(
                    'train-loss/gpu-{}'.format(accelerator.process_index), 
                    accumulation_loss / accelerator.gradient_accumulation_steps, 
                    global_completed_steps
                )
                writer.add_scalar(
                    'learning-rate/gpu-{}'.format(accelerator.process_index), 
                    lr_scheduler.get_last_lr()[0], 
                    global_completed_steps
                )
                # reset accumulation_loss to 0
                accumulation_loss = 0

                # save a checkpoint every `checkpointing_steps` optimizer steps
                if global_completed_steps % checkpointing_steps == 0:
                    accelerator.print("after {} global training steps, save a checkpoint".format(global_completed_steps))
                    accelerator.wait_for_everyone()
                    checkpoint_model(accelerator, model, tokenizer, opt.output_ckpt_dir, global_completed_steps)

        accelerator.print("in the end of an epoch, save a checkpoint")
        accelerator.wait_for_everyone()
        checkpoint_model(accelerator, model, tokenizer, opt.output_ckpt_dir, global_completed_steps)

if __name__ == "__main__":
    opt = parse_option()
    train(opt)
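
# For reference, an Accelerate+DeepSpeed config for this script might look like the
# YAML sketch below (illustrative values; not shipped with this repo):
#
#   compute_environment: LOCAL_MACHINE
#   distributed_type: DEEPSPEED
#   deepspeed_config:
#     zero_stage: 2
#     gradient_accumulation_steps: 4
#   mixed_precision: bf16
#   num_processes: 8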