from dataclasses import dataclass, field
from functools import partial
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Trainer,
)
from transformers import TrainingArguments as _TrainingArguments


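# The dataclasses below define the script's command-line interface;
# HfArgumentParser turns every field into a matching "--<field_name>" flag.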
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(_TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)


def dataset_transform(batch, tokenizer: AutoTokenizer = None):
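    # Batched transform applied lazily via dataset.set_transform: tokenize the
    # "prompt" column and derive causal-LM labels from the token ids.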
    outputs = tokenizer(
        batch["prompt"],
        padding="longest",
        truncation=True,
        # Honor the configured model_max_length rather than a hard-coded 512.
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    # For causal LM training the labels start as a copy of the input ids;
    # the model shifts them internally when computing the loss.
    labels = outputs.input_ids.clone()

    # Mask padding positions with -100 so they are ignored by the loss.
    labels[outputs.attention_mask == 0] = -100

    return {
        "input_ids": outputs.input_ids,
        "attention_mask": outputs.attention_mask,
        "labels": labels,
    }


def train():
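    # Parse CLI flags into the dataclasses above. Example invocation (paths are placeholders):
    #   python fine-tune.py --data_path ./my_dataset --output_dir ./output --use_lora True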
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
        model_max_length=training_args.model_max_length,
        cache_dir=training_args.cache_dir,
    )
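    # Reuse the EOS token for padding in case the base tokenizer defines no pad token.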
    tokenizer.pad_token_id = tokenizer.eos_token_id

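    # Optionally wrap the model in LoRA adapters; "W_pack" is the fused QKV
    # projection used by Baichuan's attention blocks.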
    if training_args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack"],
            inference_mode=False,
            r=16,
            lora_alpha=64,
            lora_dropout=0.1,
        )
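        # Make the input embeddings produce gradients, which is typically
        # required when combining PEFT adapters with gradient checkpointing.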
        model.enable_input_require_grads()
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

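    # Load the training data: prefer a dataset saved with save_to_disk, otherwise
    # fall back to datasets.load_dataset (e.g. a Hub dataset name).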
    try:
        dataset = load_from_disk(data_args.data_path)
        if "train" in dataset:
            dataset = dataset["train"]
    except Exception:
        dataset = load_dataset(data_args.data_path, split="train")

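    # Tokenization happens lazily at access time; hold out 1000 examples for evaluation.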
    dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
    dataset = dataset.train_test_split(test_size=1000, seed=42)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
        # DataCollatorForSeq2Seq pads the labels (with -100) together with the
        # inputs; DataCollatorWithPadding would leave variable-length labels unpadded.
        data_collator=DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100),
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()