"googlemock/test/gmock-spec-builders_test.cc" did not exist on "5b61ce3ee5b15e6356487dd97236bf663a96a391"
smolvlm.py 6.58 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from typing import TYPE_CHECKING
from typing import Optional
from typing import Union

import torch

from transformers.models.smolvlm.modeling_smolvlm import SmolVLMCausalLMOutputWithPast
from transformers.processing_utils import Unpack
from transformers.utils.generic import can_return_tuple

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

if TYPE_CHECKING:
    from transformers.cache_utils import Cache
    from transformers.utils.generic import TransformersKwargs


# Forward adapted to enable fused Linear + CE without materializing logits.
# Mirrors the pattern used for other multimodal models (e.g., InternVL, LLaVA).
@can_return_tuple
def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional["Cache"] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_attention_mask: Optional[torch.BoolTensor] = None,
    image_hidden_states: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,  # Added argument for liger-kernel
    **lm_kwargs: Unpack["TransformersKwargs"],  # renamed from kwargs
) -> Union[tuple, SmolVLMCausalLMOutputWithPast]:
    r"""SmolVLM causal-LM forward with an optional fused linear + cross-entropy path.

    Behaves like the upstream `SmolVLMForConditionalGeneration.forward`, except
    that when `skip_logits` is true (or defaulted to true — see below) the loss
    is computed by `LigerForCausalLMLoss` directly from the hidden states and
    the `lm_head` weight, so the full logits tensor is never materialized.

    pixel_attention_mask (`torch.BoolTensor`, *optional*):
        Mask to avoid attending to padding pixel positions; forwarded to the
        backbone unchanged.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        Precomputed image-encoder states after modality projection; forwarded
        to the backbone unchanged.
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Targets for the language-modeling loss; forwarded to the chosen loss
        function (ignored-index handling happens there).
    skip_logits (`bool`, *optional*):
        Force (`True`) or forbid (`False`) the fused no-logits path. When left
        as `None`, the fused path is taken in training mode whenever `labels`
        or `shift_labels` is supplied.

    Returns a `SmolVLMCausalLMOutputWithPast` (or a plain tuple when
    `return_dict=False`); `logits` is `None` on the fused path.
    """
    # Resolve omitted flags from the model config, as upstream does.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Run the multimodal backbone; always request a structured output so that
    # the named fields consumed below are available.
    base_out = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        pixel_values=pixel_values,
        pixel_attention_mask=pixel_attention_mask,
        image_hidden_states=image_hidden_states,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        return_dict=True,
        **lm_kwargs,
    )

    last_hidden = base_out[0]
    # Restrict the head to the trailing positions we actually need; an integer
    # `logits_to_keep` of 0 yields slice(0, None), i.e. keep everything.
    keep = logits_to_keep if not isinstance(logits_to_keep, int) else slice(-logits_to_keep, None)
    trimmed_hidden = last_hidden[:, keep, :]

    shifted = lm_kwargs.pop("shift_labels", None)
    loss = None
    logits = None

    if skip_logits is None:
        # Default policy: during training, avoid materializing logits whenever
        # some form of target is available to compute a loss from.
        skip_logits = self.training and (labels is not None or shifted is not None)
    elif skip_logits and labels is None and shifted is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits:
        # Fused path: loss straight from hidden states + lm_head weight.
        loss = LigerForCausalLMLoss(
            hidden_states=trimmed_hidden,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shifted,
            hidden_size=self.config.text_config.hidden_size,
            **lm_kwargs,
        )
    else:
        # Standard path: materialize logits, then (optionally) the loss.
        logits = self.lm_head(trimmed_hidden)
        if labels is not None or shifted is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
            )

    if not return_dict:
        tail = (logits,) + base_out[1:]
        return tail if loss is None else (loss,) + tail

    return SmolVLMCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=base_out.past_key_values,
        hidden_states=base_out.hidden_states,
        attentions=base_out.attentions,
        image_hidden_states=base_out.image_hidden_states,
    )