# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import typing

import torch


def compute_retained_tokens_count(
    tokens_per_frame: int, num_frames: int, q: float
) -> int:
    """
    Compute the number of retained tokens for a given video.

    The count is ``int(tokens_per_frame * num_frames * (1 - q))`` but never
    less than ``tokens_per_frame``. This guarantees that every token of the
    first frame can be retained regardless of the pruning rate.

    Args:
        tokens_per_frame: The number of tokens per frame.
        num_frames: The total number of frames.
        q: The pruning rate (fraction of tokens to drop).

    Returns:
        The number of retained tokens.
    """
    total_tokens = tokens_per_frame * num_frames
    # int() truncates toward zero, so fractional tokens are dropped.
    evs_num_tokens = int(total_tokens * (1 - q))
    # Floor at one full frame: the first frame is always kept in full.
    min_num_tokens = tokens_per_frame
    return max(min_num_tokens, evs_num_tokens)


def compute_retention_mask(
    video_embeds: torch.Tensor,
    video_size_thw: torch.LongTensor | tuple[int, int, int],
    spatial_merge_size: int,
    q: float,
) -> torch.Tensor:
    """
    Computes the retention mask for input video embeddings.

    Each token is scored by its cosine dissimilarity (``1 - cos``) to the
    token at the same spatial location in the previous frame. Tokens of the
    first frame get a sentinel score of 255, which is above any achievable
    dissimilarity (at most 2), so they are always kept. The highest-scoring
    ``compute_retained_tokens_count(...)`` tokens are retained. Ties are
    broken by original order because the sort is stable.

    Args:
        video_embeds (`torch.Tensor`): The input video embeddings of shape
            `(T * (H // spatial_merge_size) * (W // spatial_merge_size),
            hidden_size)`.
        video_size_thw (`torch.LongTensor` of shape `(3)`, or a 3-tuple):
            The temporal, height and width of the video before spatial
            merging.
        spatial_merge_size: Size reduction for rows & cols dimensions.
        q (`float`): Pruning rate factor [0,1)

    Returns:
        `torch.Tensor`: Boolean retention mask of shape
            `(T * (H // spatial_merge_size) * (W // spatial_merge_size),)`,
            in the same frame-major order as ``video_embeds``.
    """
    T, H, W = map(int, video_size_thw)
    rows = H // spatial_merge_size
    cols = W // spatial_merge_size

    # Use reshape instead of einops to avoid graph breaks
    video_embeds = video_embeds.reshape(T, rows, cols, video_embeds.size(-1))
    tokens_per_frame = rows * cols

    # Core EVS: dissimilarity of each frame w.r.t. the previous frame.
    similarity = torch.nn.functional.cosine_similarity(
        video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
    )
    dissimilarity = 1 - similarity

    # Always ensure we include all tokens from the first frame
    dissimilarity = torch.cat(
        [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0
    )

    dissimilarity_flat = dissimilarity.view(-1)
    # A stable sort keeps the selection deterministic when scores tie.
    order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
    retain_num_tokens = compute_retained_tokens_count(
        tokens_per_frame=tokens_per_frame, num_frames=T, q=q
    )
    topk_indices = order[:retain_num_tokens]

    # The flat mask is already in "T H W -> (T H W)" order; no reshape needed.
    retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
    retention_mask[topk_indices] = True
    return retention_mask


def compute_mrope_for_media(
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    tokens_per_second: float = 1.0,
    video_second_per_grid: float = 1.0,
) -> torch.Tensor:
    """
    Computes the mrope for video embeddings based on the grid dimensions.

    Computed mrope positions match the original qwen 2.5 implementation, but
    positions are built as if the media were the first element in the
    sequence.

    Args:
        video_size_thw: Media size (num frames, rows, cols) before spatial
            merging. Only the first three elements are read.
        spatial_merge_size: Size reduction for rows & cols dimensions.
        tokens_per_second: Number of tokens per second.
        video_second_per_grid: Number of seconds per temporal grid step.

    Returns:
        Long tensor of shape `(T * H' * W', 4)`, where
        `H' = H // spatial_merge_size` and `W' = W // spatial_merge_size`.
        Columns [0:3) hold the (t, h, w) mrope positions. Column 3 holds
        `W'` repeated for every row.
    """
    # Normalize to Python ints so tuples and tensors behave identically.
    llm_grid_t = int(video_size_thw[0])
    llm_grid_h = int(video_size_thw[1]) // spatial_merge_size
    llm_grid_w = int(video_size_thw[2]) // spatial_merge_size

    # Temporal index is scaled by the time per grid step, then truncated.
    t_index = (
        torch.arange(llm_grid_t)
        .view(-1, 1)
        .expand(-1, llm_grid_h * llm_grid_w)
        .mul(tokens_per_second * video_second_per_grid)
        .long()
        .flatten()
    )
    h_index = (
        torch.arange(llm_grid_h)
        .view(1, -1, 1)
        .expand(llm_grid_t, -1, llm_grid_w)
        .flatten()
    )
    w_index = (
        torch.arange(llm_grid_w)
        .view(1, 1, -1)
        .expand(llm_grid_t, llm_grid_h, -1)
        .flatten()
    )
    # Extra channel carrying the merged grid width for every token.
    max_width = torch.full_like(w_index, llm_grid_w)

    return torch.stack([t_index, h_index, w_index, max_width], dim=1)


def recompute_mrope_positions(
    input_ids: torch.LongTensor,
    multimodal_positions: list[torch.Tensor],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
    vision_start_token_id: int,
    image_token_id: int,
    video_token_id: int,
) -> tuple[torch.LongTensor, int]:
    """
    Update part of input mrope positions.

    The incoming mrope_positions do not account for pruned media tokens.
    Once media tokens are pruned, the mrope positions for the LLM must be
    recomputed to reflect this.

    This method supports the chunked prefill approach, where multimodal
    embeddings are passed to the LLM in chunks. The input
    multimodal_positions may therefore describe zero media, some media, or
    only part of a media of a given prompt.

    Each element of multimodal_positions has 4 or 5 channels. The first 3
    channels are the original t/h/w mrope positions; the remaining channels
    vary by model (see below). The provided multimodal_positions do not
    reflect where the media sits in the sequence. They are computed as if
    the media were at the 0-th position of the sequence.

    The method works as follows. For each media, it locates the matching
    vision-start token and writes the shifted media positions right after it.
    For 5-channel inputs starting a new media, writing begins at the first
    timestamp token instead. It then rewrites every position after the media
    from a running count of text tokens, which shifts the trailing text.

    It also handles the case where a media is split across prefill stages.

    Args:
        input_ids: (N,) All input tokens of the prompt (entire sequence).
        multimodal_positions: List of mrope positions for each media.
            If a given element is of shape (4, N), it is assumed to only describe
            positions for video / image embeddings. This is the case of e.g. Qwen2.5 VL,
            where each multimodal input is a contiguous chunk of embeddings.
            The expected channels are [t, h, w, max_width].
            If it is of shape (5, N), it is assumed to possibly describe positions for
            both video / image embeddings, as well as text embeddings. This is the case
            of e.g. Qwen3 VL, where each video input is comprised of individual
            frames' embeddings, interleaved with embeddings for timestamp tokens,
            and vision start / end tokens. The expected channels are
            [t, h, w, is_vision_start, is_vision].
        mrope_positions: Existing mrope positions (3, N) for the entire
            sequence. The input is cloned, not modified.
        num_computed_tokens: A number of computed tokens so far.
        vision_start_token_id: Token indicating start of vision media.
        image_token_id: Image token id
        video_token_id: Video token id

    Returns:
        Tuple of (mrope_positions, mrope_position_delta), where the delta is
        ``max(positions) + 1 - N``.
    """

    # Tensors
    positions: torch.LongTensor = typing.cast(
        torch.LongTensor, mrope_positions.clone()
    )  # (3, N)
    N = input_ids.numel()

    image_mask = input_ids.eq(image_token_id)
    video_mask = input_ids.eq(video_token_id)
    media_mask = image_mask | video_mask
    text_mask = ~media_mask

    # Early exit: no media in this chunk
    if len(multimodal_positions) == 0:
        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
        return positions, delta

    total_mm_tokens = torch.count_nonzero(media_mask)
    seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])

    # Early exit: we've updated positions for all media tokens
    # (and consequently - for all remaining text tokens)
    if seen_mm_tokens == total_mm_tokens:
        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
        return positions, delta

    vision_start_indices = (input_ids == vision_start_token_id).nonzero(as_tuple=True)[
        0
    ]

    for mm_pos in multimodal_positions:
        # Each mm_pos can be a complete embedding for single media
        # or it can be a part of a single media (due to chunked prefill)

        # Cases to cover
        # - Current prefill chunk has no vision start indexes at all
        # - Vision start token appeared in previous prefill round
        # - Regular case
        has_video_tokens = False
        num_timestamp_tokens = 0
        if mm_pos.shape[0] == 5 and mm_pos.shape[1] > 0:
            # mm_pos[4, :] indicates which positions are for video embeddings.
            # If there are no video embeddings, skip timestamp adjustment.
            has_video_tokens = torch.any(mm_pos[4, :]).item()
            if has_video_tokens:
                # Channel 3 flags VISION_START tokens.  Timestamp tokens
                # precede the first VISION_START, so its index gives us the
                # exact timestamp count.  This is robust even when early
                # frames have all their video tokens pruned (which would
                # push argmax(channel 4) far into a later frame).
                first_vs = (mm_pos[3, :] == 1).nonzero(as_tuple=True)[0]
                num_timestamp_tokens = first_vs[0].item() if len(first_vs) > 0 else 0

        seen_vision_start_indices = vision_start_indices[
            vision_start_indices < num_computed_tokens
        ]

        if len(seen_vision_start_indices):
            # If we have encountered some vision start indexes,
            # then we should check the condition:
            # | --- prefill 1 ------| ---- prefill 2 ----- |
            # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
            last_vision_start_token = seen_vision_start_indices[-1]
            seen_mm_tokens_before_last_vision_start = torch.count_nonzero(
                media_mask[:last_vision_start_token]
            )
            # NOTE(review): seen_mm_tokens is computed once before the loop
            # and is not refreshed as num_computed_tokens advances below.
            in_the_middle_of_media = (
                seen_mm_tokens > seen_mm_tokens_before_last_vision_start
            )

            # For Qwen3 VL, we can be inside a media segment even before any
            # video tokens appear (timestamp tokens are text). If we've passed
            # the last vision_start token but haven't reached the first video
            # embedding, treat this as "in the middle of media".
            if (
                not in_the_middle_of_media
                and has_video_tokens
                and num_computed_tokens > last_vision_start_token
                and num_computed_tokens
                <= last_vision_start_token + num_timestamp_tokens + 1
            ):
                in_the_middle_of_media = True

            if in_the_middle_of_media:
                mm_embeddings_seen = (
                    seen_mm_tokens - seen_mm_tokens_before_last_vision_start
                )
                global_mm_start = last_vision_start_token
            else:
                # We have completed previous mm_embedding part and
                # ready to start a new one
                next_vision_start_token = vision_start_indices[
                    vision_start_indices >= num_computed_tokens
                ][0]
                mm_embeddings_seen = 0
                global_mm_start = next_vision_start_token

        else:
            # If there were no vision start indexes so far,
            # let's find first vision start index
            next_vision_start_token = vision_start_indices[
                vision_start_indices >= num_computed_tokens
            ][0]

            mm_embeddings_seen = 0
            global_mm_start = next_vision_start_token

        # For Qwen3 VL, mm_pos includes timestamp tokens before vision_start
        # when starting a new media. Adjust global_mm_start to point to where
        # the sequence actually begins (before timestamp tokens).
        adjusted_for_timestamps = False
        if mm_pos.shape[0] == 5 and mm_embeddings_seen == 0 and has_video_tokens:
            # The timestamp tokens are immediately followed by a vision start
            # token, which is why num_timestamp_tokens equals the index of the
            # first vision start inside mm_pos.

            # Adjust global_mm_start to point to the first timestamp token
            # instead of the vision_start token.
            global_mm_start -= num_timestamp_tokens
            adjusted_for_timestamps = True

        # Offset calculation depends on whether we adjusted for timestamp tokens
        if adjusted_for_timestamps:
            # Start from the position of the token before the first timestamp
            # token. If the media opens the sequence there is no such token.
            # Indexing with -1 would silently wrap to the LAST token, so start
            # from 0 instead.
            if global_mm_start > 0:
                base = positions[-1, global_mm_start - 1] + 1
            else:
                base = positions.new_zeros(())
            local_start = global_mm_start + mm_embeddings_seen
        else:
            # Original logic: start after vision_start_token
            base = positions[-1, global_mm_start] + 1
            local_start = global_mm_start + 1 + mm_embeddings_seen

        local_end = local_start + mm_pos.shape[1]
        positions[:, local_start:local_end] = mm_pos[0:3] + base

        # For Qwen3 VL (5-channel), use the maximum position reached across
        # all tokens (both video and text) in all dimensions (t, h, w).
        # For Qwen2.5 VL (4-channel), mm_pos[3, 0] is the max width.
        if mm_pos.shape[0] == 5:
            offset = mm_pos[0:3, :].max() + base + 1
        else:
            offset = mm_pos[3, 0] + base

        # Renumber everything after the media: text tokens get consecutive
        # positions starting at `offset`. Non-text tokens in this tail are not
        # counted here; presumably they are overwritten when their own mm_pos
        # is processed (TODO confirm for media outside this call).
        text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)

        positions[:, local_end:N] = text_pos_sum + offset - 1

        # NOTE(review): this advances by the length of mm_pos only, not to
        # local_end. If several media are passed in one call and the text
        # before the first one is longer than that media, the next search may
        # find the same vision start again. Confirm whether callers pass at
        # most one new media per call.
        num_computed_tokens += mm_pos.shape[1]

    mrope_positions_delta = (positions.max() + 1 - N).item()
    return positions, mrope_positions_delta