# training.py (9.23 KB) -- committed by cmx
import argparse
import math
import os

from dataclasses import _MISSING_TYPE
from dataclasses import dataclass

import datasets
import lightning.pytorch as pl
import torch
import transformers

from lightning.pytorch.strategies import DeepSpeedStrategy
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import MixedPrecision
from torch.utils.data import DataLoader
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from trl import DataCollatorForCompletionOnlyLM

from liger_kernel.transformers import AutoLigerKernelForCausalLM
from liger_kernel.utils import infer_device

# Columns that survive tokenization; DataModule.setup removes every other dataset column.
_RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"}
# Markers inserted before the question text and before the choice list in each formatted prompt.
QUESTION = "<Question>"
CHOICES = "<Choices>"


@dataclass
class Args:
    """Run configuration for the fine-tuning script.

    parse_args() registers one ``--<field>`` command-line option per field
    declared here, using the field annotation as the argparse ``type`` and the
    value below as the default, so annotations must remain plain callables.
    """

    # Model identifier passed to from_pretrained for both model and tokenizer.
    model: str = "Qwen/Qwen2-0.5B-Instruct"
    # Dataset identifier passed to datasets.load_dataset.
    data: str = "cais/mmlu"
    # Created at start-up and used as the trainer's default_root_dir.
    output_dir: str = "mmlu_finetuning"
    # Tokenizer truncation length.
    max_length: int = 2048
    # For the Llama 3 8B model, DeepSpeed will OOM with batch size 16 on 8x A100 80G,
    # and batch size 8 will OOM on 8x A100 40G.
    batch_size: int = 4
    lr: float = 6e-6
    weight_decay: float = 0.05
    # Fraction of the estimated stepping batches spent in linear warmup.
    warmup_ratio: float = 0.1
    seed: int = 42
    # "fsdp", "deepspeed" and "ddp" are handled explicitly in train(); anything else falls back to "auto".
    strategy: str = "auto"
    # None means "use every device reported by the inferred accelerator backend".
    num_gpu: int = None


def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
    """Build a step -> multiplier function: linear warmup, then cosine decay.

    The returned callable is suitable for ``torch.optim.lr_scheduler.LambdaLR``;
    its result scales the optimizer's base learning rate.

    Args:
        warmup_steps: Number of steps over which the multiplier rises linearly
            from 0 to 1. May be fractional.
        total_steps: Step at which the cosine decay reaches its lowest point.
        min_lr: Floor applied to the multiplier during the decay phase. Note
            that this bounds the multiplier, not the absolute learning rate.

    Returns:
        A function ``lr_lambda(current_step)`` returning the multiplier.
    """
    # Loop-invariant: length of the decay phase, guarded against zero/negative spans.
    decay_steps = float(max(1, total_steps - warmup_steps))

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup
            return float(current_step) / float(max(1, warmup_steps))
        # Cosine annealing. Progress is clamped to 1.0 so that any step past
        # total_steps stays at the floor; an unclamped cosine is periodic and
        # would push the multiplier back up toward 1.
        progress = min(1.0, float(current_step - warmup_steps) / decay_steps)
        return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))

    return lr_lambda


def parse_args() -> Args:
    """Parse command-line flags into an ``Args`` instance.

    One ``--<field>`` option is registered for every ``Args`` dataclass field,
    with the field annotation as the argparse converter and the dataclass
    default as the fallback value.
    """
    cli = argparse.ArgumentParser()
    for field_name, field_info in Args.__dataclass_fields__.items():
        cli.add_argument(
            f"--{field_name}",
            type=field_info.type,
            default=field_info.default,
        )
    namespace = vars(cli.parse_args())

    # A field declared without a default carries the dataclasses MISSING
    # sentinel; such entries are left out of the constructor call.
    supplied = {}
    for key, value in namespace.items():
        if isinstance(value, _MISSING_TYPE):
            continue
        supplied[key] = value
    return Args(**supplied)


class LanguageModel(pl.LightningModule):
    """Lightning module around a causal LM loaded via ``AutoLigerKernelForCausalLM``."""

    def __init__(self, args: Args, tokenizer):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        # Filled in lazily by configure_model().
        self.model = None

    def configure_model(self):
        """Load the pretrained model once; subsequent calls do nothing."""
        # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
        if self.model is None:
            self.model = AutoLigerKernelForCausalLM.from_pretrained(
                self.args.model,
                use_cache=False,
                ignore_mismatched_sizes=True,
            )
            if self.args.strategy == "deepspeed":
                self.model.train()
                self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Delegate to the wrapped model, forwarding any extra keyword arguments."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs,
        )

    def _batch_loss(self, batch):
        """Run the wrapped model on one collated batch and return its loss."""
        result = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        return result.loss

    def training_step(self, batch):
        """Compute and log the training loss for one batch."""
        train_loss = self._batch_loss(batch)
        self.log_dict(
            {"train_loss": train_loss},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            rank_zero_only=True,
            sync_dist=False,
        )
        return train_loss

    def validation_step(self, batch):
        """Compute and log the validation loss for one batch."""
        val_loss = self._batch_loss(batch)
        self.log_dict(
            {"val_loss": val_loss},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            rank_zero_only=True,
            sync_dist=True,
        )
        return val_loss

    def configure_optimizers(self):
        """Create fused AdamW plus a per-step warmup/cosine ``LambdaLR`` schedule."""
        adamw = torch.optim.AdamW(
            self.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay,
            fused=True,
        )
        planned_steps = self.trainer.estimated_stepping_batches
        schedule_fn = warmup_cosine_schedule(
            warmup_steps=planned_steps * self.args.warmup_ratio,
            total_steps=planned_steps,
            min_lr=0,
        )
        scheduler = torch.optim.lr_scheduler.LambdaLR(adamw, schedule_fn)
        # "interval": "step" makes Lightning advance the schedule every optimizer step.
        scheduler_config = {"scheduler": scheduler, "interval": "step"}
        return {"optimizer": adamw, "lr_scheduler": scheduler_config}


class DataModule(pl.LightningDataModule):
    """Builds prompt-formatted, tokenized train/validation splits and their dataloaders."""

    def __init__(self, tokenizer, args: Args):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.response_template_str = " <Answer>"
        # The collator is keyed on the token ids of the answer marker.
        template_ids = tokenizer.encode(self.response_template_str, add_special_tokens=False)
        self.collator = DataCollatorForCompletionOnlyLM(
            tokenizer=tokenizer,
            response_template=template_ids,
            pad_to_multiple_of=16,
        )

    def formatting_func(self, example):
        """Turn a batch of question/choices/answer columns into prompt strings."""
        instruction = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
        prompts = []
        for row, question in enumerate(example["question"]):
            # Choices are numbered from 1 and each is terminated by "; ".
            numbered = "".join(
                f"{position}. {option}; "
                for position, option in enumerate(example["choices"][row], start=1)
            )
            prompts.append(
                instruction
                + f"{QUESTION}{question} "
                + f"{CHOICES}{numbered} "
                + f"{self.response_template_str}{example['answer'][row]}"
            )
        return prompts

    def tokenize(self, example):
        """Tokenize the formatted prompts, truncating to ``max_length`` without padding."""
        encoded = self.tokenizer(
            self.formatting_func(example),
            truncation=True,
            padding=False,
            max_length=self.args.max_length,
        )
        return {key: encoded[key] for key in ("input_ids", "attention_mask")}

    def _tokenize_split(self, split):
        """Map ``tokenize`` over one split, dropping every column not in ``_RETAIN_COLUMNS``."""
        droppable = list(set(split.column_names).difference(_RETAIN_COLUMNS))
        return split.map(
            self.tokenize,
            remove_columns=droppable,
            batched=True,
            batch_size=1,
            num_proc=4,
        )

    def setup(self, stage) -> None:
        """Load the dataset, flatten it, split off 4096 validation rows and tokenize both splits."""
        raw = datasets.load_dataset(self.args.data, "auxiliary_train")
        # Each row nests its fields under a "train" key; lift them to the top level.
        wanted = ("answer", "choices", "question", "subject")
        rows = [{key: record["train"][key] for key in wanted} for record in raw["train"]]
        splits = datasets.Dataset.from_list(rows).train_test_split(test_size=4096, seed=self.args.seed)
        self.train_dataset = self._tokenize_split(splits["train"])
        self.val_dataset = self._tokenize_split(splits["test"])

    def _make_loader(self, dataset):
        """Wrap a tokenized split in a DataLoader using the configured batch size and collator."""
        return DataLoader(
            dataset,
            batch_size=self.args.batch_size,
            collate_fn=self.collator,
        )

    def train_dataloader(self):
        """Return the dataloader over the training split."""
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return self._make_loader(self.val_dataset)


def _wrap_layers_for(model_name):
    """Return the decoder-layer classes used as wrap/checkpoint policy for ``model_name``."""
    if "Meta-Llama-3-8B" in model_name:
        return {LlamaDecoderLayer}
    if "Qwen2" in model_name:
        return {Qwen2DecoderLayer}
    raise Warning(f"Unimplemented layer wrap policy for {model_name} in this example")


def _strategy_and_precision(strategy_name, layers):
    """Map a strategy name to the ``(strategy, precision)`` pair handed to the trainer."""
    if strategy_name == "fsdp":
        fsdp = FSDPStrategy(
            auto_wrap_policy=layers,
            sharding_strategy="FULL_SHARD",
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            sync_module_states=True,
            activation_checkpointing_policy=layers,
            mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
            forward_prefetch=True,
        )
        # Precision is configured through the strategy's MixedPrecision above.
        return fsdp, None
    if strategy_name == "deepspeed":
        return DeepSpeedStrategy(stage=3), "bf16-mixed"
    if strategy_name == "ddp":
        return "ddp", "bf16-true"
    return "auto", "bf16-true"


def train():
    """Script entry point: parse flags, build trainer, data and model, then fit for one epoch."""
    args = parse_args()
    pl.seed_everything(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    # Resolved before anything else is built, so an unsupported model name aborts early.
    layers = _wrap_layers_for(args.model)
    strategy, precision = _strategy_and_precision(args.strategy, layers)

    device = infer_device()
    if args.num_gpu is None:
        # Ask the matching torch backend module (torch.<device>) how many devices it sees.
        device_count = getattr(torch, device).device_count()
    else:
        device_count = args.num_gpu

    trainer = pl.Trainer(
        accelerator=device,
        strategy=strategy,
        devices=device_count,
        default_root_dir=args.output_dir,
        log_every_n_steps=1,
        max_epochs=1,
        precision=precision,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model,
        padding_side="left",
        truncation_side="left",
    )
    # Reuse the EOS token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    data_module = DataModule(tokenizer=tokenizer, args=args)
    model = LanguageModel(args=args, tokenizer=tokenizer)
    trainer.fit(model, datamodule=data_module)


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    train()