peft_lora_clm_instruction_tuning.py 5.56 KB
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import json
import os
from argparse import ArgumentParser
from functools import partial

import torch
from datasets import Dataset
from peft import TaskType
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

from auto_gptq import AutoGPTQForCausalLM, get_gptq_peft_model
from auto_gptq.utils.data_utils import collate_data, make_data_block
from auto_gptq.utils.peft_utils import GPTQLoraConfig


# Command-line interface of the example script.
parser = ArgumentParser()
# The model path is used unconditionally further down (model loading and the
# adapter output directory), so require it up front and let argparse report a
# clear usage error instead of failing later with a None path.
parser.add_argument(
    "--model_name_or_path",
    type=str,
    required=True,
    help="name or path of the quantized model to fine-tune",
)
parser.add_argument("--lr", type=float, default=3e-5, help="learning rate of the optimizer")
parser.add_argument("--num_epochs", type=int, default=1, help="number of training epochs")
parser.add_argument("--sample_max_length", type=int, default=1024, help="max length of sample")
parser.add_argument(
    "--block_max_length",
    type=int,
    default=1024,
    help="max length of data block(bunch of samples)",
)
parser.add_argument(
    "--tokenizer_name_or_path",
    type=str,
    default=None,
    help="tokenizer name or path; falls back to --model_name_or_path when omitted",
)
parser.add_argument(
    "--use_fast_tokenizer",
    action="store_true",
    help="request the fast tokenizer implementation",
)
args = parser.parse_args()

# Set TOKENIZERS_PARALLELISM to "false" before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Unpack the parsed CLI options into the module-level names used by the rest
# of the script.
model_name_or_path = args.model_name_or_path
if args.tokenizer_name_or_path:
    tokenizer_name_or_path = args.tokenizer_name_or_path
else:
    # No (or an empty) tokenizer path was given: reuse the model path.
    tokenizer_name_or_path = model_name_or_path

lr, num_epochs = args.lr, args.num_epochs

# creating model
# LoRA adapter hyper-parameters for a causal-LM task, configured for training
# (inference_mode=False).
peft_config = GPTQLoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
)

# Load the tokenizer from the dedicated tokenizer path, which already falls
# back to the model path when --tokenizer_name_or_path is not supplied.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=args.use_fast_tokenizer)
# Only fall back to EOS when no pad token is defined at all; a pad token id of
# 0 is a valid id and must not be overwritten.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the quantized model in trainable mode; the Triton warm-up is skipped at
# load time and triggered explicitly right after.
model = AutoGPTQForCausalLM.from_quantized(
    model_name_or_path,
    use_triton=True,
    warmup_triton=False,
    trainable=True,
    inject_fused_attention=True,
    inject_fused_mlp=False,
)
model.warmup_triton()
# Remember the device before the model is wrapped by the PEFT adapter; batches
# are moved to it during training and evaluation.
device = model.device
model = get_gptq_peft_model(model, peft_config=peft_config, auto_find_all_linears=True, train_mode=True)
model.print_trainable_parameters()

# loading dataset
# Prompt layouts: one for records that carry an extra "input" field, one for
# records that only have an instruction.
WITH_INPUT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Output:\n"
WITHOUT_INPUT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Output:\n"


def ds_refactor_fn(samples):
    """Turn a batch of instruction records into prompt/output columns.

    Args:
        samples: mapping with the column keys "instruction", "input" and
            "output", each holding one entry per record.

    Returns:
        A dict with two parallel lists: "prompt" (the formatted prompt of each
        record) and "output" (the record's output, unchanged). Records are
        paired with ``zip``, so the result is as long as the shortest column.
    """

    def build_prompt(instruction_txt, input_txt):
        # A falsy input (empty string, None, ...) selects the shorter template.
        if not input_txt:
            return WITHOUT_INPUT_TEMPLATE.format(instruction=instruction_txt)
        return WITH_INPUT_TEMPLATE.format(instruction=instruction_txt, input=input_txt)

    records = list(zip(samples["instruction"], samples["input"], samples["output"]))
    return {
        "prompt": [build_prompt(instruction_txt, input_txt) for instruction_txt, input_txt, _ in records],
        "output": [output_txt for _, _, output_txt in records],
    }


def _iter_alpaca_records(path="../quantization/dataset/alpaca_data_cleaned.json"):
    """Yield the records stored in the JSON file at ``path`` one at a time.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the JSON document has been parsed, instead of being left open.
    """
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    yield from records


ds = Dataset.from_generator(_iter_alpaca_records)
# Build the training blocks in a single map call whose batch spans the whole
# dataset (presumably so make_data_block can pack samples into blocks across
# the full set -- confirm against make_data_block). The raw columns are dropped
# and nothing is read from or written to the on-disk cache.
ds = ds.map(
    make_data_block,
    batched=True,
    batch_size=len(ds),
    num_proc=1,
    remove_columns=ds.column_names,
    keep_in_memory=True,
    load_from_cache_file=False,
    fn_kwargs={
        "prompt_col_name": "prompt",
        "label_col_name": "output",
        "tokenizer": tokenizer,
        "preprocess_fn": ds_refactor_fn,
        "sample_max_len": args.sample_max_length,
        "block_max_len": args.block_max_length,
        "add_eos_token": True,
        "truncate_prompt": False,
        "merge_prompt_label": True,
    },
)
# Hold out one tenth of the blocks (integer division) for evaluation.
ds = ds.train_test_split(test_size=len(ds) // 10)
train_ds, eval_ds = ds["train"], ds["test"]
# Both loaders share the same collate function, bound to the tokenizer's pad id.
collate_fn = partial(collate_data, pad_token_id=tokenizer.pad_token_id)
train_dataloader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=collate_fn)
eval_dataloader = DataLoader(eval_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)

# optimizer and lr scheduler
# The linear schedule (no warm-up) is sized to one step per training batch
# over all epochs.
total_training_steps = len(train_dataloader) * num_epochs
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=total_training_steps,
)

# training and evaluation
# Each epoch runs one pass over the training loader followed by one pass over
# the evaluation loader, then prints the mean loss and perplexity of both.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_dataloader)
    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}
        # Autocast wraps only the forward pass and the loss computation; the
        # backward pass and the optimizer step run outside of it, as the
        # PyTorch AMP documentation recommends.
        with torch.cuda.amp.autocast():
            outputs = model(**batch)
            loss = outputs.loss
        # Accumulate a detached float copy so no autograd graph is kept alive.
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        optimizer.zero_grad()

        progress_bar.set_postfix(loss=loss.item())

    model.eval()
    eval_loss = 0
    # Decoded greedy (argmax) predictions of the evaluation batches; collected
    # here but not consumed further in this script.
    eval_preds = []
    for batch in tqdm(eval_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad(), torch.cuda.amp.autocast():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(
                torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                skip_special_tokens=True,
            )
        )

    # Perplexity is the exponential of the mean per-batch loss.
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

# Write the model via save_pretrained into a sub-directory of the model path
# whose name embeds the PEFT type of the adapter config.
adapter_dir_name = f"gptq_{peft_config.peft_type.value}_adapter"
model.save_pretrained(os.path.join(model_name_or_path, adapter_dir_name))