"examples/vscode:/vscode.git/clone" did not exist on "e6b5a9528ee63c89fdf55d5ad6dd511647e895a0"
awq_train.py 1.81 KB
Newer Older
Casper's avatar
Casper committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
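"""awq_train.py: fine-tune an AWQ-quantized causal LM with LoRA adapters.

Loads the pre-quantized ybelkada/opt-125m-awq checkpoint, formats the
mhenrichsen/alpaca_2k_test split into instruction prompts, attaches LoRA
adapters with PEFT, and runs one epoch of training with the Hugging Face
Trainer.
"""
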
import datasets
from awq import AutoAWQForCausalLM
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import get_peft_model, LoraConfig, TaskType

def prepare_split(tokenizer):
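    """Load the Alpaca 2k test split, render each row into a prompt string, and tokenize it."""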
    data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
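    # [INST]-style instruction template; the system slot is left empty below.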
    prompt_template = "<s>[INST] {system} {prompt} [/INST] {output}</s>"

    def format_prompt(x):
        return prompt_template.format(
            system="",
            prompt=x["instruction"],
            output=x["output"]
        )

    data = data.map(
        lambda x: {"text": format_prompt(x)},
    ).select_columns(["text"])
    data = data.map(lambda x: tokenizer(x["text"]), batched=True)

    return data

# Pre-quantized AWQ checkpoint on the Hugging Face Hub
model_path = "ybelkada/opt-125m-awq"

# Load the quantized model; fuse_layers=False keeps the unfused module layout
# so LoRA adapters can be attached for training.
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)
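# Make sure a pad token is set for batch padding; reuse EOS here.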
tokenizer.pad_token = tokenizer.eos_token

# Prepare data
data_train = prepare_split(tokenizer)

# Configure LoRA
lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.5,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False
)

# Attach LoRA adapters to the underlying HF model (model.model), not the AWQ wrapper
model = get_peft_model(model.model, lora_config)

model.print_trainable_parameters()

training_arguments = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    optim="adamw_torch",
    num_train_epochs=1,
    learning_rate=1e-4,
    # fp16=True,
    evaluation_strategy="no",  # renamed to eval_strategy in newer transformers releases
    save_strategy="epoch",
    save_steps=100,  # only used when save_strategy="steps"
    logging_steps=50,
    eval_steps=None,
    load_best_model_at_end=False
)

trainer = Trainer(
    model=model,
    train_dataset=data_train,
    args=training_arguments,
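    # mlm=False -> causal language modeling: labels are a copy of the input ids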
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
trainer.save_model("output")  # with a PEFT-wrapped model, this saves the LoRA adapter