import os

from dataclasses import dataclass

import datasets
import torch
import transformers

from callback import EfficiencyCallback
from datasets import Image as ImageFeature
from trl import SFTTrainer

from liger_kernel.transformers import monkey_patch


@dataclass
class CustomArguments:
    model_name: str = "Qwen/Qwen2-VL-2B-Instruct"
    dataset: str = "HuggingFaceM4/the_cauldron"
    dataset_subset: str = "ai2d"
    dataset_split: str = "train"
    max_seq_length: int = 512
    dataset_text_field: str = "texts"
    use_liger: bool = False

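# Example invocation (illustrative; paths and hyperparameters are placeholders):
#   python training_multimodal.py \
#       --model_name Qwen/Qwen2-VL-2B-Instruct \
#       --use_liger True \
#       --output_dir ./qwen2-vl-sft \
#       --bf16 True \
#       --per_device_train_batch_size 2 \
#       --num_train_epochs 1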

def construct_model_and_processor(
    model_name: str, use_liger: bool
) -> tuple[torch.nn.Module, transformers.ProcessorMixin, int]:
    if "Qwen2-VL" in model_name:
        from transformers import Qwen2VLForConditionalGeneration

        # These settings reduce the memory footprint of the Qwen2-VL model, which
        # supports training and inference on images at their native resolution. Large
        # images -> many visual tokens (up to 16384) -> large memory consumption.
        # If fine-tuning for a real-world application, choose these values carefully.
        min_visual_tokens_per_image = 256
        max_visual_tokens_per_image = 256

        processor = transformers.AutoProcessor.from_pretrained(
            model_name,
            padding_side="left",
            truncation_side="left",
            min_pixels=min_visual_tokens_per_image * 28 * 28,  # patch size is 14x14
            max_pixels=max_visual_tokens_per_image * 28 * 28,  # 4 patches / token
        )
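        # Sizing sanity check (illustrative): each visual token covers a 28x28 pixel
        # area (a 2x2 merge of 14x14 patches), so 256 tokens corresponds to roughly
        # 256 * 784 = 200,704 pixels, i.e. images resized to about 448x448.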
        processor.tokenizer.pad_token = processor.tokenizer.eos_token
        image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")

        if use_liger:
            print("Applying Liger Kernel to Qwen2-VL model")
            monkey_patch.apply_liger_kernel_to_qwen2_vl(
                # These args can be used to override the default Liger settings
                # cross_entropy=True,
                # fused_linear_cross_entropy=False,
            )

        model = Qwen2VLForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=model_name,
            use_cache=False,
            dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            attn_implementation="sdpa",
        )
        return model, processor, image_token_id

    raise NotImplementedError(f"Model {model_name} not supported")


def _validate_and_extract_the_cauldron(examples) -> dict[str, list]:
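    """
    Flatten each the_cauldron example (one image plus a list of user/assistant text
    dicts) into one row per text, repeating the image. Illustratively, an example
    with 1 image and 3 Q/A pairs becomes 3 rows that all share the same image.
    """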
    batch_texts = []
    batch_images = []
    for images, texts in zip(examples["images"], examples["texts"]):
        if not images:
            raise ValueError("No image found in example from the_cauldron dataset")
        if len(images) > 1:
            raise ValueError("Only one image per example is supported")
        batch_texts.extend(texts)
        batch_images.extend([images[0]] * len(texts))
    return {"texts": batch_texts, "images": batch_images}


def _format_for_convo(example, tokenizer):
    # cauldron data is already in message format {"user": ..., "assistant": ...}
    text = example["texts"]
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": text["user"]}],
        },
        {"role": "assistant", "content": [{"type": "text", "text": text["assistant"]}]},
    ]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"texts": formatted}


def train():
    parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments))
    training_args, custom_args = parser.parse_args_into_dataclasses()
    training_args.remove_unused_columns = False  # required to not drop the image column
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    model, processor, image_token_id = construct_model_and_processor(custom_args.model_name, custom_args.use_liger)

    dataset = (
        datasets.load_dataset(
            custom_args.dataset,
            custom_args.dataset_subset,
            split=custom_args.dataset_split,
        )
        .map(
            _validate_and_extract_the_cauldron,
            batched=True,
            num_proc=min(os.cpu_count(), 16),
            desc="Extracting text and images",
        )
        .map(
            _format_for_convo,
            fn_kwargs={"tokenizer": processor.tokenizer},
            desc="Formatting for convo",
        )
        .cast_column("images", ImageFeature())
        .train_test_split(test_size=0.1)
    )

    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    def collate_fn(examples):
        """
        Taken directly from the TRL documentation with minor modifications:
        https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data

        Modifications:
        1. `apply_chat_template` is used to preprocess the texts before training begins (see above)
        2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema
        3. Ignoring image tokens in the loss computation
        """
        # Get the texts and images
        texts = [example["texts"] for example in examples]
        images = [example["images"] for example in examples]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
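        # -100 is the default ignore_index of PyTorch's cross-entropy loss, so these
        # positions are excluded from the loss computation.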

        # Ignore the image token index in the loss computation
        labels[labels == image_token_id] = -100
        batch["labels"] = labels

        return batch

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        max_seq_length=custom_args.max_seq_length,
        dataset_text_field=custom_args.dataset_text_field,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=processor.tokenizer,
        callbacks=[EfficiencyCallback()],
    )
    trainer.train()


if __name__ == "__main__":
    train()