rotary_embedding.py 84.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
26
import itertools
Antoni Baum's avatar
Antoni Baum committed
27
import math
28
from typing import Any, Optional, Union
29

30
import numpy as np
31
32
import torch
import torch.nn as nn
33
34
35
import triton
import triton.language as tl

Roger Wang's avatar
Roger Wang committed
36
from transformers import PretrainedConfig
37

38
from vllm.model_executor.custom_op import CustomOp
39
from vllm.platforms import current_platform
40
41
42
import vllm.envs as envs
from vllm.utils import direct_register_custom_op

43

44
if current_platform.is_cuda():
45
    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
46
47
if current_platform.is_rocm():
    from flash_attn.layers.rotary import apply_rotary_emb
48

49

50
51
52
53
54
55
56
57
58
59
60
61
62
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


63
def _apply_rotary_emb_torch(
64
    x: torch.Tensor,
65
66
    cos: torch.Tensor,
    sin: torch.Tensor,
67
    is_neox_style: bool,
68
) -> torch.Tensor:
69
70
71
72
73
74
75
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
76
77
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
78
79
80
81
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)
82
83


84
85
86
87
88
89
90
91
92
93
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      is_neox_style: bool) -> torch.Tensor:
    """Apply rotary embedding, using the flash-attn kernel on CUDA.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if not current_platform.is_cuda():
        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)

    # The kernel is called with a leading batch dim of 1; its fourth
    # argument (presumably flash-attn's `interleaved` flag) selects the
    # GPT-J layout, hence `not is_neox_style`.
    batched = x.unsqueeze(0)
    rotated = apply_rotary_emb(batched, cos, sin, not is_neox_style)
    return rotated.squeeze(0)


101
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding.

    A cos/sin cache (one row per position: cosines in the first half of the
    row, sines in the second) is built once at construction and indexed by
    position in every ``forward_*`` variant. Only the first ``rotary_dim``
    features of each head are rotated; the remaining
    ``head_size - rotary_dim`` features pass through unchanged. Subclasses
    customize the embedding by overriding ``_compute_inv_freq`` and/or
    ``_compute_cos_sin_cache``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        # Non-persistent buffer: moves with the module but is excluded from
        # state_dict (it is always recomputed from the config).
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _match_cos_sin_cache(self, device: torch.device,
                             dtype: torch.dtype) -> None:
        """Move/cast ``cos_sin_cache`` to ``device``/``dtype`` if needed.

        __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) is
        expensive, so the assignment is skipped when the cache already
        matches (in which case ``.to()`` would be a no-op anyway).
        """
        if (self.cos_sin_cache.device != device
                or self.cos_sin_cache.dtype != dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(device, dtype=dtype)

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
                                            self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key may be None in some cases, e.g. cross-layer KV sharing
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)
            key_rot = key[..., :self.rotary_dim]
            key_pass = key[..., self.rotary_dim:]
            key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
                                              self.is_neox_style)
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """CUDA implementation of forward() using vLLM custom ops."""
        from vllm import _custom_ops as ops

        self._match_cos_sin_cache(query.device, query.dtype)

        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                         self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """XPU implementation of forward() using IPEX ops."""
        from vllm._ipex_ops import ipex_ops as ops

        # Same guarded move as forward_cuda: avoid nn.Module.__setattr__ on
        # every call when the cache already has the right device/dtype.
        self._match_cos_sin_cache(positions.device, query.dtype)
        if key is None:
            # XPU kernel doesn't support key=None so fall back to native impl
            # TODO(sarckk): add support for optional key in
            # ipex.llm.functional.rotary_embedding_batched
            return self.forward_native(positions, query, key, offsets)

        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key,
                                         self.head_size,
                                         self.cos_sin_cache,
                                         self.is_neox_style,
                                         self.rotary_dim, offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_neuron(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Neuron implementation of forward().

        Mirrors ``forward_native``. For the GPT-J layout and for partial
        rotation it slices through 2-D ``view``s rather than strided slicing
        on the 3-D tensor (presumably to suit the Neuron compiler — TODO
        confirm).
        """

        def _apply_rotary_emb_neuron(
            x: torch.Tensor,
            cos: torch.Tensor,
            sin: torch.Tensor,
            is_neox_style: bool,
        ) -> torch.Tensor:
            cos = cos.unsqueeze(-2).to(x.dtype)
            sin = sin.unsqueeze(-2).to(x.dtype)
            if is_neox_style:
                x1, x2 = torch.chunk(x, 2, dim=-1)
            else:
                # Same as x1 = x[..., ::2]; x2 = x[..., 1::2], but sliced
                # through a 2-D view.
                d = x.shape[-1] // 2
                x_reshaped = x.view(-1, x.shape[-1])
                x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
                x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
            o1 = x1 * cos - x2 * sin
            o2 = x2 * cos + x1 * sin
            if is_neox_style:
                return torch.cat((o1, o2), dim=-1)
            else:
                return torch.stack((o1, o2), dim=-1).flatten(-2)

        if offsets is not None:
            positions = positions + offsets

        self._match_cos_sin_cache(query.device, query.dtype)

        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)

        if self.rotary_dim == self.head_size:
            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
            query = query.reshape(query_shape)
            if key is not None:
                key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
                key = key.reshape(key_shape)
        else:
            head_size = query.shape[-1]
            query_reshaped = query.view(-1, head_size)
            query_pass = query_reshaped[:, self.rotary_dim:].view(
                *query.shape[:-1], head_size - self.rotary_dim)
            query_rot = query_reshaped[:, :self.rotary_dim].view(
                *query.shape[:-1], self.rotary_dim)
            query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
                                                 self.is_neox_style)
            query = torch.cat((query_rot, query_pass),
                              dim=-1).reshape(query_shape)

            if key is not None:
                key_reshaped = key.view(-1, head_size)
                key_pass = key_reshaped[:, self.rotary_dim:].view(
                    *key.shape[:-1], head_size - self.rotary_dim)
                key_rot = key_reshaped[:, :self.rotary_dim].view(
                    *key.shape[:-1], self.rotary_dim)
                key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
                                                   self.is_neox_style)
                key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s

325
326
327
328

class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may have
    different scaling factors, we need multiple cos/sin caches. In this way,
    instead of running rotary embedding kernel per lora, we can run multiple
    lora in a batched way.

    The scaling factor of 1 (the default) is expected to be included in
    ``scaling_factors`` by the caller so that its cache is always available;
    this class does not add it on its own.

    For each factor ``f`` a cache of ``max_position_embeddings * f`` rows
    (one per position, positions divided by ``f``) is built, and the caches
    are concatenated along the position dim. For example, with factors
    x=1, y and z and per-factor caches (shown transposed, one inner list
    per feature)
    [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
    [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
    [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],

    we construct the cos/sin cache as follows:
    [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
        ...
     [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]

    We then use offsets to index into the cos/sin cache for
    the respective scaling factors.

    The offset to cache can be accessed via `scaling_factor_to_offset` API.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factors: Union[list[float], float],
        dtype: torch.dtype,
    ) -> None:
        # A single number (configs may store an integral factor such as 2)
        # is treated as a one-element list; previously a bare int fell
        # through and the cache loop tried to iterate it.
        if isinstance(scaling_factors, (int, float)):
            scaling_factors = [scaling_factors]
        # Must be set before super().__init__(), which builds the cache.
        self.scaling_factors: list[float] = scaling_factors  # noqa
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)
        # Lazy initialized.
        self._scaling_factor_to_offset: dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build and concatenate one cos/sin cache per scaling factor.

        Also records, per factor, the row offset of its cache within the
        concatenated tensor in ``self._scaling_factor_to_offset``.
        """
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: list[torch.Tensor] = []
        # offsets[i] is the first row of self.scaling_factors[i]'s cache.
        offsets: list[int] = []
        next_offset = 0
        for scaling_factor in self.scaling_factors:
            # NOTE(woosuk): self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

            offsets.append(next_offset)
            next_offset += cache.shape[0]
            cache_list.append(cache)

        self._scaling_factor_to_offset = {
            float(scaling_factor): offset
            for scaling_factor, offset in zip(self.scaling_factors, offsets)
        }
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> dict[float, int]:
        """Map each scaling factor to its row offset in ``cos_sin_cache``."""
        return self._scaling_factor_to_offset

411

412
413
414
415
416
417
418
419
class NTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with fixed and mixed NTK scaling.
    https://kexue.fm/archives/9706

    * Fixed NTK (``mixed_b is None``): the base is multiplied by
      ``scaling_factor`` and every inverse frequency is additionally divided
      by ``scaling_factor ** (2 / rotary_dim)``.
    * Mixed NTK: the base is left unchanged and the m-th inverse frequency
      (m = 1 .. rotary_dim // 2) is divided by ``exp(a * m ** mixed_b)``,
      where ``a = log(scaling_factor) / (rotary_dim / 2) ** mixed_b``.
    """

    def __init__(self,
                 head_size: int,
                 rotary_dim: int,
                 max_position_embeddings: int,
                 base: float,
                 is_neox_style: bool,
                 scaling_factor: float,
                 dtype: torch.dtype,
                 mixed_b: Optional[float] = None) -> None:
        # Read by _compute_inv_freq, which runs inside the base-class
        # constructor, so both must be set first.
        self.scaling_factor = scaling_factor
        self.mixed_b = mixed_b
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        # NOTE: the `base` argument is ignored; the effective base is always
        # derived from self.base.
        fixed_ntk = self.mixed_b is None
        effective_base = self.base * (self.scaling_factor if fixed_ntk else 1)
        inv_freq = super()._compute_inv_freq(effective_base)

        if fixed_ntk:
            return inv_freq / self.scaling_factor**(2 / self.rotary_dim)

        half_dim = self.rotary_dim / 2
        a = torch.tensor(self.scaling_factor).log() / half_dim**self.mixed_b
        m = torch.arange(1, self.rotary_dim // 2 + 1).float()
        lambda_1_m = (a * m**self.mixed_b).exp()
        return inv_freq / lambda_1_m


446
447
448
449
450
451
452
453
454
455
456
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        # Read while the base-class constructor builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # self.max_position_embeddings is the pre-scaling maximum length, so
        # the cache covers max_position_embeddings * scaling_factor rows.
        max_len = self.max_position_embeddings * self.scaling_factor
        # NTK-aware base, evaluated at the full scaled length.
        ntk_ratio = ((self.scaling_factor * max_len /
                      self.max_position_embeddings) -
                     (self.scaling_factor - 1))
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        inv_freq = self._compute_inv_freq(self.base * ntk_ratio**exponent)

        positions = torch.arange(max_len, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", positions, inv_freq)
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1)
Antoni Baum's avatar
Antoni Baum committed
484
485


486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK alpha.

    Based on the original RotaryEmbedding implementation. The base is
    multiplied by ``scaling_alpha ** (rotary_dim / (rotary_dim - 2))``; the
    cache length stays ``max_position_embeddings``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
    ) -> None:
        # Read while the base-class constructor builds the cache.
        self.scaling_alpha = scaling_alpha
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # For Hunyuan DynamicNTKAlphaRotaryEmbedding
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        inv_freq = self._compute_inv_freq(self.base *
                                          self.scaling_alpha**exponent)
        positions = torch.arange(self.max_position_embeddings,
                                 dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", positions, inv_freq)
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1)


Antoni Baum's avatar
Antoni Baum committed
521
522
523
524
525
526
527
528
529
530
531
# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


# Find dim range bounds based on rotations
532
533
534
535
536
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> tuple[int, int]:
    """Return integer frequency-index bounds for two rotation counts.

    ``low`` is the floor of the index for ``low_rot`` rotations, clamped
    below at 0; ``high`` is the ceiling of the index for ``high_rot``
    rotations, clamped above at ``dim - 1``.
    """
    low_dim, high_dim = (_yarn_find_correction_dim(rot, dim, base,
                                                   max_position_embeddings)
                         for rot in (low_rot, high_rot))
    # Clamp values just in case
    return max(math.floor(low_dim), 0), min(math.ceil(high_dim), dim - 1)


def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
547
                           dtype: torch.dtype) -> torch.Tensor:
Antoni Baum's avatar
Antoni Baum committed
548
549
550
    if low == high:
        high += 0.001  # Prevent singularity

551
    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
Antoni Baum's avatar
Antoni Baum committed
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn

    Each inverse frequency is a blend of its interpolated
    (``/ scaling_factor``) and extrapolated (unscaled) value, weighted by a
    linear ramp over the correction range derived from ``beta_fast`` and
    ``beta_slow``. The cache covers
    ``max_position_embeddings * scaling_factor`` positions and cos/sin are
    pre-multiplied by ``mscale``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        # The scaling settings and mscale are read while the base-class
        # constructor builds the cache, so set them first.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(_yarn_get_mscale(scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        # base ** (2i / rotary_dim) for i = 0, 1, ...
        exponents = torch.arange(0, self.rotary_dim, 2,
                                 dtype=torch.float) / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation: weight of
        # the extrapolated frequency per index.
        ramp = _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2,
                                      dtype=torch.float)
        extrapolation_weight = (1 - ramp) * self.extrapolation_factor
        return (interpolated * (1 - extrapolation_weight) +
                extrapolated * extrapolation_weight)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        positions = torch.arange(self.max_position_embeddings *
                                 self.scaling_factor,
                                 dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", positions, inv_freq)
        # Fold the attention magnitude correction into cos and sin.
        return torch.cat(
            (freqs.cos() * self.mscale, freqs.sin() * self.mscale), dim=-1)
621
622


623
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi3 family of models scaled rotary embedding.

    Based on the original RotaryEmbedding implementation.

    Two cos/sin caches are stored back to back in
    ``long_short_cos_sin_cache``: rows ``[0, original_max_position_embeddings)``
    use ``short_factor``/``short_mscale``, and the following
    ``max_position_embeddings`` rows use ``long_factor``/``long_mscale``.
    A whole batch is served from the long cache as soon as any position
    reaches ``original_max_position_embeddings``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: list[float],
        long_factor: list[float],
        short_mscale: Optional[float] = None,
        long_mscale: Optional[float] = None,
    ):
        super().__init__()

        if not is_neox_style:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )

        self.rotary_dim = rotary_dim
        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor

        # Default magnitude correction, used for whichever mscale is unset.
        scale = (self.max_position_embeddings /
                 self.original_max_position_embeddings)
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1 + math.log(scale) /
                math.log(self.original_max_position_embeddings))
        if short_mscale is None:
            short_mscale = scaling_factor
        if long_mscale is None:
            long_mscale = scaling_factor

        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(dtype)

        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                 long_factor, long_mscale)
        long_cache = long_cache.to(dtype)

        # Short rows first, so the long cache starts at row
        # original_max_position_embeddings.
        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
        """Return ``1 / (rescale_factors[i] * base ** (2i / rotary_dim))``."""
        factors = torch.tensor(rescale_factors, dtype=torch.float32)
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)
        return 1.0 / (factors * pos_freqs)

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: list[float],
        mscale: float,
    ) -> torch.Tensor:
        """Return a ``[max_position_embeddings, 2 * len(rescale_factors)]``
        cache of ``mscale``-scaled cosines followed by sines."""
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the Neox-style scaled rotary embedding.

        Args:
            positions: Position ids. ``cos.repeat(1, 2)`` below assumes the
                cache lookup is 2-D, i.e. ``positions`` is 1-D.
            query: Last dim is ``num_heads * head_size``.
            key: Same layout as ``query``, or ``None`` (then only ``query``
                is rotated and ``None`` is returned for the key).
            offsets: Optional extra row offsets added to the cache index.

        Returns:
            ``(query, key)`` in their input shapes.
        """
        k = self.original_max_position_embeddings
        # The short cache has rows [0, k) only, so a batch reaching position
        # k must already use the long cache (HF switches when
        # max(position) + 1 > k). Testing `positions > k` would map
        # position k onto long-cache row 0.
        needs_long_cache = torch.any(positions >= k)
        long_prompt_offset = (needs_long_cache.float() *
                              torch.full_like(positions, k)).long()
        idx = torch.add(positions, long_prompt_offset)
        if offsets is not None:
            idx = torch.add(idx, offsets)
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        # Duplicate to cover both Neox halves; add a broadcast heads dim.
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        def _rotate(x: torch.Tensor) -> torch.Tensor:
            # Rotate the first rotary_dim features of every head.
            x = x.view(*x.shape[:-1], -1, self.head_size)
            x_rot = x[..., :self.rotary_dim]
            x_pass = x[..., self.rotary_dim:]
            half = x_rot.shape[-1] // 2
            # Neox "rotate half" ([a, b] -> [-b, a]), as in _rotate_neox.
            x_rot_half = torch.cat((-x_rot[..., half:], x_rot[..., :half]),
                                   dim=-1)
            x_rot = x_rot * cos + x_rot_half * sin
            return torch.cat((x_rot, x_pass), dim=-1).flatten(-2)

        query = _rotate(query)
        if key is not None:
            key = _rotate(key)
        return query, key


wangding zeng's avatar
wangding zeng committed
743
744
745
746
747
748
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention-magnitude factor for a context scale.

    A ``scale`` of at most 1 (no context extension) gives exactly ``1.0``.
    Larger scales grow logarithmically as ``0.1 * mscale * ln(scale) + 1``.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
# In-place GPT-J-style (interleaved pair) rotary embedding kernel.
#
# Program ids:
#   pid0 - index into q's stride1 dimension and the row of cos_sin to use
#          (cos_sin rows are addressed with stride_cs).
#   pid1 - index into q's stride2 dimension.
#   pid2 - tile of BLOCK_SIZE rotary pairs (2 * BLOCK_SIZE elements of q).
# Each cos_sin row is laid out as [cos (dim3 values) | sin (dim3 values)].
# The last dim of q is addressed without a stride, so it is assumed to be
# unit-stride. dim1 and dim2 are accepted but not read by the body.
# For each pair (x1, x2) = (q[2i], q[2i+1]) the kernel writes
#   q[2i]   = x1 * cos_i - x2 * sin_i
#   q[2i+1] = x2 * cos_i + x1 * sin_i
@triton.jit
def deepseek_scaling_rotary_emb_kernel_gptj(cos_sin, q, stride1: int,
                                            stride2: int, stride_cs: int,
                                            dim1: int, dim2: int, dim3: int,
                                            BLOCK_SIZE: tl.constexpr):
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pid2 = tl.program_id(2)
    # Pair indices handled by this program (into the cos / sin halves).
    offsets_cs = tl.arange(0, BLOCK_SIZE) + pid2 * BLOCK_SIZE
    # Element indices handled by this program (two per pair).
    offsets_q = tl.arange(0, BLOCK_SIZE * 2) + pid2 * BLOCK_SIZE * 2

    offsets = pid0 * stride1 + pid1 * stride2 + offsets_q
    # Lanes past the rotary width (dim3 pairs / 2 * dim3 elements) are
    # masked off for both loads and the store.
    mask = offsets_cs < dim3
    mask2 = offsets_q < dim3 * 2

    # Load cos/sin for this tile and duplicate each value so it lines up with
    # both elements of its pair: [c0, c0, c1, c1, ...].
    v_cos = tl.load(cos_sin + pid0 * stride_cs + offsets_cs, mask=mask)
    v_cos2 = tl.interleave(v_cos, v_cos)
    v_sin = tl.load(cos_sin + pid0 * stride_cs + dim3 + offsets_cs, mask=mask)
    v_sin2 = tl.interleave(v_sin, v_sin)
    x12 = tl.load(q + offsets, mask=mask2)
    # x1 = even-indexed elements, x2 = odd-indexed elements of each pair.
    x1, x2 = tl.split(x12.reshape([BLOCK_SIZE, 2]))
    # we are both reading and writing 'q'; make sure all warps are in sync
    tl.debug_barrier()
    # Rotated partner vector: [-x2_0, x1_0, -x2_1, x1_1, ...].
    x12_ = tl.ravel(tl.join(-x2, x1))
    x12 = x12 * v_cos2 + x12_ * v_sin2
    # Write the rotated values back into q in place.
    tl.store(q + offsets, x12, mask=mask2)


wangding zeng's avatar
wangding zeng committed
777
778
779
780
781
782
783
784
785
786
787
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
        reference: bool = False,
    ) -> None:
        """Store the YaRN hyper-parameters, then build the base embedding.

        Args:
            scaling_factor: context-extension factor; the cos/sin cache covers
                ``max_position_embeddings * scaling_factor`` positions.
            extrapolation_factor: multiplier on the extrapolation ramp mask.
            attn_factor: extra multiplier folded into ``self.mscale``.
            beta_fast, beta_slow: passed to ``_yarn_find_correction_range``.
            mscale, mscale_all_dim: numerator/denominator ``mscale`` values
                given to ``yarn_get_mscale``.
            reference: when True, ``forward`` never takes the Triton fast
                path and always uses the PyTorch implementation.
        """
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.reference = reference
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        # The attributes above are set first, presumably because the base
        # constructor builds the cache via _compute_cos_sin_cache.
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated and extrapolated inverse frequencies (YaRN)."""
        pos_freqs = self.base**(
            torch.arange(0,
                         self.rotary_dim,
                         2,
                         dtype=torch.float,
                         device=current_platform.device_type) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Return the ``[cos | sin]`` table, scaled by ``self.mscale``.

        There is one row per position in
        ``range(max_position_embeddings * scaling_factor)``.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         device=current_platform.device_type,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    # The two functions below live in the class body and take no ``self``.
    # They are handed to direct_register_custom_op once, when the class is
    # created, and are invoked through
    # torch.ops.vllm.rotary_embedding_deepseek_fuse.
    def rotary_embedding_deepseek_fuse(positions: torch.Tensor,
                                       query: torch.Tensor, key: torch.Tensor,
                                       head_size: int,
                                       cos_sin_cache: torch.Tensor,
                                       is_neox_style: bool) -> None:
        # Delegates to the external ``lightop`` kernel. It returns nothing, so
        # it is expected to rotate ``query`` and ``key`` in place.
        from lightop import op
        op.rotary_embedding_deepseek_fuse(positions, query, key, head_size,
                                          cos_sin_cache, is_neox_style)

    def rotary_embedding_deepseek_fuse_fake(positions: torch.Tensor,
                                            query: torch.Tensor,
                                            key: torch.Tensor, head_size: int,
                                            cos_sin_cache: torch.Tensor,
                                            is_neox_style: bool) -> None:
        # Fake implementation for tracing: the op has no outputs to describe.
        pass

    direct_register_custom_op(
        op_name="rotary_embedding_deepseek_fuse",
        op_func=rotary_embedding_deepseek_fuse,
        # The op returns None and forward() returns query/key afterwards, so
        # it mutates them in place. Declaring this stops torch.compile /
        # functionalization from treating the call as side-effect free.
        mutates_args=["query", "key"],
        fake_impl=rotary_embedding_deepseek_fuse_fake,
    )

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply YaRN-scaled rotary embeddings to ``query`` and ``key``.

        On 'cuda'-typed devices with GPT-J (non-NeoX) layout, and with
        ``reference`` unset, a Triton kernel rotates the first ``rotary_dim``
        channels of ``query`` and ``key`` in place and returns those same
        tensors. Otherwise a PyTorch path builds and returns new tensors.

        Args:
            positions: position ids used to index ``cos_sin_cache``.
            query: query tensor; the Triton path requires it to be 3-D.
            key: key tensor (required).
            offsets: optional offsets added to ``positions`` before indexing.
        """
        assert key is not None
        if self.cos_sin_cache.device != positions.device:
            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
                positions.device)
        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                     if offsets is not None else positions]

        if query.device.type == 'cuda' and not self.is_neox_style \
            and not self.reference:
            assert len(query.shape) == 3

            def call(q):
                # One program per (dim -3, dim -2) index and per BLOCK_SIZE
                # rotary pairs; the kernel rewrites q in place.
                BLOCK_SIZE = 64
                grid = (
                    q.shape[-3],
                    q.shape[-2],
                    triton.cdiv(self.rotary_dim // 2, BLOCK_SIZE),
                )
                deepseek_scaling_rotary_emb_kernel_gptj[grid](
                    cos_sin,
                    q,
                    stride1=q.stride()[-3],
                    stride2=q.stride()[-2],
                    stride_cs=cos_sin.stride()[-2],
                    dim1=q.shape[0],
                    dim2=q.shape[1],
                    dim3=self.rotary_dim // 2,
                    BLOCK_SIZE=BLOCK_SIZE,
                    num_warps=1)

            # if envs.VLLM_USE_LIGHTOP:
            # NOTE(review): the lightop fused path is disabled. It also passes
            # the whole cache and raw ``positions``, so it would ignore
            # ``offsets``; fix that before re-enabling it.
            if False:
                torch.ops.vllm.rotary_embedding_deepseek_fuse(
                    positions, query, key, self.head_size, self.cos_sin_cache,
                    self.is_neox_style)
            else:
                call(query)
                call(key)
            return query, key

        # PyTorch implementation.
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NeoX layout: tile the half-width cos/sin end to end along the
            # last dim. torch.cat keeps the rank of ``positions`` intact,
            # unlike repeat(1, 1, 2), which prepends a dim for 1-D positions
            # (breaking the final cat with the pass-through channels) and
            # raises for rank > 2.
            cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
            sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        else:
            # GPT-J layout: repeat each value for both elements of its pair.
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key


942
943
944
945
946
947
948
class Llama3RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are rescaled by wavelength.

    Frequencies whose wavelength is below ``orig_max_position /
    high_freq_factor`` are kept unchanged. Those above ``orig_max_position /
    low_freq_factor`` are divided by ``scaling_factor``. The band in between
    blends the two, weighted linearly in ``orig_max_position / wavelength``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        # These are assigned before the base constructor runs, presumably
        # because it builds the cache through _compute_inv_freq.
        self.orig_max_position = orig_max_position
        self.scaling_factor = scaling_factor
        self.high_freq_factor = high_freq_factor
        self.low_freq_factor = low_freq_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        inv_freqs = super()._compute_inv_freq(base)
        orig_max = self.orig_max_position
        lo_factor = self.low_freq_factor
        hi_factor = self.high_freq_factor

        low_freq_wavelen = orig_max / lo_factor
        high_freq_wavelen = orig_max / hi_factor
        wave_len = 2 * math.pi / inv_freqs

        # Equal factors would divide by zero; the blend weight is then 0.
        smooth = 0 if lo_factor == hi_factor else (
            (orig_max / wave_len - lo_factor) / (hi_factor - lo_factor))

        downscaled = inv_freqs / self.scaling_factor
        blended = ((1 - smooth) * inv_freqs / self.scaling_factor +
                   smooth * inv_freqs)
        long_or_mid = torch.where(wave_len > low_freq_wavelen, downscaled,
                                  blended)
        return torch.where(wave_len < high_freq_wavelen, inv_freqs,
                           long_or_mid)
986
987


988
989
990
991
992
993
994
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
    """2-D rotary embedding over a square grid of image patches.

    The cache is complex-valued, with one row per patch plus a final
    CLS-token row whose angles are zeroed. The first half of each row's
    frequencies encodes the patch's x coordinate and the second half its
    y coordinate.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ):
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        half_dim = self.rotary_dim // 2
        return super()._compute_inv_freq(base)[:half_dim]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)

        # max_position_embeddings is interpreted here as the number of image
        # patches, i.e. (image_size // patch_size) ** 2.
        num_patches = self.max_position_embeddings
        patch_ids = torch.arange(num_patches,
                                 dtype=torch.int32).reshape(num_patches, 1)
        # Append one extra row for the CLS token, tagged with sentinel id -2.
        patch_ids = torch.cat([patch_ids, patch_ids[:1]], dim=0)
        patch_ids[-1, -1] = -2  # set to ID_CLS_TOKEN
        side = int(math.sqrt(num_patches))

        def axis_angles(coords: torch.Tensor) -> torch.Tensor:
            # 1-based coordinate times every inverse frequency:
            # (rows, 1) -> (rows, 1, len(inv_freq)).
            return (coords + 1)[..., None] * inv_freq[None, None, :]

        angles = torch.cat(
            [axis_angles(patch_ids % side),
             axis_angles(patch_ids // side)],
            dim=-1).float()
        # Rows with a negative id (the CLS row) get zero rotation.
        angles = angles.masked_fill(patch_ids.reshape(-1, 1, 1) < 0, 0)
        return torch.view_as_complex(
            torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1))

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert key is not None
        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)

        def as_complex_pairs(x: torch.Tensor) -> torch.Tensor:
            # Treat adjacent (even, odd) channels as one complex number.
            return torch.view_as_complex(
                x.float().reshape(*x.shape[:-1], -1, 2))

        q_complex = as_complex_pairs(query)
        k_complex = as_complex_pairs(key)

        # Keep dim 1 and the last dim of the query; broadcast over the rest.
        kept_dims = (1, q_complex.ndim - 1)
        freqs_ci = self.cos_sin_cache.view(*[
            size if dim in kept_dims else 1
            for dim, size in enumerate(q_complex.shape)
        ])

        query_out = torch.view_as_real(q_complex * freqs_ci).flatten(3)
        key_out = torch.view_as_real(k_complex * freqs_ci).flatten(3)
        return query_out.type_as(query), key_out.type_as(key)


1053
1054
1055
1056
1057
1058
1059
1060
class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[list[int]] = None,
    ) -> None:
        # For Qwen2.5-VL the largest position index grows with the length of
        # the input video, so the cos/sin cache is sized at four times
        # max_position_embeddings.
        cache_len = 4 * max_position_embeddings
        self.cache_max_position_num = cache_len
        super().__init__(head_size, rotary_dim, cache_len, base,
                         is_neox_style, dtype)

        self.mrope_section = mrope_section
        if mrope_section:
            # The sections must exactly partition the rotary frequency axis.
            assert sum(mrope_section) == rotary_dim // 2

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply multimodal rotary embeddings to ``query`` and ``key``.

        Args:
            positions: ``[num_tokens]`` for text-only input, or
                ``[3, num_tokens]`` holding T/H/W position ids for
                multimodal input.
            query: ``[num_tokens, num_heads * head_size]``.
            key: ``[num_tokens, num_kv_heads * head_size]`` (required).

        Returns:
            ``(query, key)`` reshaped back to their input shapes, with the
            first ``rotary_dim`` channels of each head rotated.
        """
        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

        if positions.ndim == 2:
            assert self.mrope_section

            def pick_sections(table: torch.Tensor) -> torch.Tensor:
                # Frequency section i is taken from row i (T, H, W) of the
                # position grid.
                parts = table.split(self.mrope_section, dim=-1)
                return torch.cat(
                    [part[row] for row, part in enumerate(parts)], dim=-1)

            cos = pick_sections(cos)
            sin = pick_sections(sin)

        def rotate(x: torch.Tensor) -> torch.Tensor:
            original_shape = x.shape
            heads = x.view(num_tokens, -1, self.head_size)
            rotated = _apply_rotary_emb(heads[..., :self.rotary_dim], cos,
                                        sin, self.is_neox_style)
            untouched = heads[..., self.rotary_dim:]
            return torch.cat((rotated, untouched),
                             dim=-1).reshape(original_shape)

        return rotate(query), rotate(key)

1127
    @classmethod
    def get_input_positions(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        second_per_grid_ts: Optional[list[float]],
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[list[list[int]], int]:
        """Get mrope input positions and delta value as plain Python lists.

        ``None`` grids and timings are normalized to empty lists before
        delegating to ``get_input_positions_tensor``. The resulting position
        tensor is converted with ``tolist()``.
        """
        positions, delta = cls.get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=([] if image_grid_thw is None else
                            image_grid_thw),
            video_grid_thw=([] if video_grid_thw is None else
                            video_grid_thw),
            second_per_grid_ts=([] if second_per_grid_ts is None else
                                second_per_grid_ts),
            context_len=context_len,
            seq_len=seq_len,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )
        return positions.tolist(), delta

1162
    @classmethod
    def get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: list[float],
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Dispatch to the model-family-specific mrope position builder.

        Configs for which ``thinker_uses_mrope`` holds go to the Omni
        builder. Configs whose ``model_type`` contains ``"glm4v"`` go to the
        GLM4V builder. Everything else uses the generic vision-language
        builder. Only the Omni builder receives the audio arguments, and the
        GLM4V builder does not receive ``second_per_grid_ts``.
        """
        from vllm.transformers_utils.config import thinker_uses_mrope

        if thinker_uses_mrope(hf_config):
            return cls._omni_get_input_positions_tensor(
                input_tokens=input_tokens,
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                context_len=context_len,
                seq_len=seq_len,
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
            )

        if "glm4v" in hf_config.model_type:
            return cls._glm4v_get_input_positions_tensor(
                input_tokens=input_tokens,
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                context_len=context_len,
                seq_len=seq_len,
            )

        return cls._vl_get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
        )

1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
    @classmethod
    def _glm4v_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value for GLM4V.

        Tokens equal to ``image_token_id`` are classified as "video" when
        they fall between ``video_start_token_id`` and
        ``video_end_token_id``, and as "image" otherwise. All other tokens
        are "text". Runs of consecutive same-type tokens form groups. Each
        group's (t, h, w) position ids start one past the maximum id emitted
        so far.

        Returns:
            ``(positions, delta)``. ``positions`` is the ``(3, N)`` tensor
            sliced to ``[:, context_len:seq_len]``. ``delta`` is
            ``max(position) + 1 - len(input_tokens)`` over the full, unsliced
            sequence.
        """
        image_token_id = hf_config.image_token_id
        video_start_token_id = hf_config.video_start_token_id
        video_end_token_id = hf_config.video_end_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        llm_pos_ids_list: list = []

        if not (image_grid_thw is None and video_grid_thw is None):
            if isinstance(image_grid_thw, torch.Tensor):
                image_grid_thw = image_grid_thw.tolist()

            # Classify every token as text / image / video.
            input_token_type: list[str] = []
            inside_video = False
            for token in input_tokens:
                if token == video_start_token_id:
                    inside_video = True
                elif token == video_end_token_id:
                    inside_video = False

                if token != image_token_id:
                    input_token_type.append("text")
                elif inside_video:
                    input_token_type.append("video")
                else:
                    input_token_type.append("image")

            # Collapse into (type, start, end) runs of identical type.
            input_type_group: list[tuple[str, int, int]] = []
            for token_type, group_iter in itertools.groupby(
                    enumerate(input_token_type), lambda x: x[1]):
                group_list = list(group_iter)
                start_index = group_list[0][0]
                end_index = group_list[-1][0] + 1
                input_type_group.append((token_type, start_index, end_index))

            video_frame_num = 1
            mm_data_idx = 0
            for modality_type, start_idx, end_idx in input_type_group:
                st_idx = (llm_pos_ids_list[-1].max() + 1
                          if llm_pos_ids_list else 0)
                if modality_type == "image":
                    t, h, w = (
                        image_grid_thw[mm_data_idx][0],
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = \
                        t, h // spatial_merge_size, w // spatial_merge_size

                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                        llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                        llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)
                    mm_data_idx += 1

                elif modality_type == "video":
                    # NOTE(review): h/w are read from image_grid_thw (not
                    # video_grid_thw) with the index shared with images;
                    # confirm this matches how GLM4V grids are passed in.
                    t, h, w = (
                        video_frame_num,
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = \
                        t, h // spatial_merge_size, w // spatial_merge_size

                    for t_idx in range(llm_grid_t):
                        t_index = torch.tensor(t_idx).view(-1, 1).expand(
                            -1, llm_grid_h * llm_grid_w).flatten()
                        h_index = torch.arange(llm_grid_h).view(
                            1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                        w_index = torch.arange(llm_grid_w).view(
                            1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                        llm_pos_ids_list.append(
                            torch.stack([t_index, h_index, w_index]) + st_idx)

                    mm_data_idx += 1
                    video_frame_num += 1

                else:
                    text_len = end_idx - start_idx
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) +
                        st_idx)
                    video_frame_num = 1

        else:
            text_len = len(input_tokens)
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1))

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        # Compute the delta over the whole sequence BEFORE slicing, as
        # _vl_get_input_positions_tensor does. Slicing first makes the delta
        # depend on the chunk bounds and calls max() on an empty tensor when
        # context_len == seq_len.
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]
        return llm_positions, mrope_position_delta

1317
1318
1319
    @classmethod
    def _vl_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: list[float],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value.

        Vision items are visited in token order. Each item contributes the
        text run before it (identical t/h/w ids), followed by a (t, h, w)
        grid for its ``t * (h // spatial_merge_size) *
        (w // spatial_merge_size)`` tokens. Any trailing text is appended
        last. For videos the t id is ``frame * second_per_grid_t *
        tokens_per_second``, truncated by ``.long()``, where
        ``second_per_grid_t`` is 1.0 when ``second_per_grid_ts`` is empty.
        For images the t id is 0.

        Returns:
            ``(positions, delta)``. ``positions`` is the ``(3, N)`` tensor
            sliced to ``[:, context_len:seq_len]``. ``delta`` is
            ``max(position) + 1 - len(input_tokens)`` over the full, unsliced
            sequence.
        """
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config,
                                    "tokens_per_second", 1.0)

        # Count items by the token that follows each vision-start token.
        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        # input_tokens never changes inside the loop, so test membership once
        # instead of rescanning the whole list for every vision item.
        has_image_token = image_token_id in input_tokens
        has_video_token = video_token_id in input_tokens

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            # Locate the next image / video token at or after st
            # (len + 1 acts as "none left").
            if has_image_token and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if has_video_token and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            # Text preceding this item.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            # The item's (t, h, w) grid, offset past the preceding text.
            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
                       tokens_per_second).long().flatten()

            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision item.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta
1415

1416
1417
1418
    @classmethod
    def _omni_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value (Qwen2.5-Omni version).

        Differences from MRotaryEmbedding:
            1. Add audio support (and related `audio_feature_lengths`).
            2. Add `use_audio_in_video` option to read audio from video inputs.
                In this case, audio and vision position ids will be split into
                chunks and interleaved.

        Example:

            (V_i are vision position ids, A_i are audio position ids)

            |V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
            |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...

        Returns:
            ``(positions, delta)``. ``positions`` has three rows (for vision
            tokens: temporal, height and width ids; for other tokens the
            three rows are equal), sliced to columns
            ``[context_len:seq_len]``. ``delta`` is a Python int equal to
            the largest position id + 1 - ``len(input_tokens)``.
        """

        # TODO(fyabc): refactor and share more code with
        #  _vl_get_input_positions_tensor.

        thinker_config = hf_config.thinker_config
        audio_token_id = thinker_config.audio_token_index
        image_token_id = thinker_config.image_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        vision_start_token_id = thinker_config.vision_start_token_id
        vision_end_token_id = thinker_config.vision_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(thinker_config.vision_config,
                                    "tokens_per_second", 25)

        if isinstance(image_grid_thw, list):
            image_grid_thw = torch.tensor(image_grid_thw)
        if isinstance(video_grid_thw, list):
            video_grid_thw = torch.tensor(video_grid_thw)

        src_item = input_tokens
        audio_seqlens = audio_feature_lengths
        if not second_per_grid_ts:
            second_per_grid_ts = [1] * video_grid_thw.shape[0]
        audio_idx = 0
        video_idx = 0
        image_idx = 0
        new_src_item: list[int] = []
        llm_pos_ids_list: list[torch.Tensor] = []

        idx = 0
        while idx < len(src_item):
            new_src_item_len = len(new_src_item)
            start_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            if src_item[idx] not in [
                    audio_token_id, video_token_id, image_token_id
            ]:
                if use_audio_in_video and idx > 0:
                    if src_item[idx] == vision_end_token_id and \
                        src_item[idx - 1] == audio_end_token_id:
                        # processing the <|audio_eos|> before <|vision_eos|>
                        start_idx -= 1
                    elif src_item[idx] == audio_start_token_id and \
                        src_item[idx - 1] == vision_start_token_id:
                        # processing the <|audio_bos|> after <|vision_bos|>
                        start_idx -= 1
                new_src_item.append(src_item[idx])
                llm_pos_ids = torch.tensor([start_idx],
                                           dtype=torch.long).expand(3, -1)
                llm_pos_ids_list.append(llm_pos_ids)
            elif src_item[idx] == audio_token_id:
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
                place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
                new_src_item.extend([audio_token_id] * place_num)
                llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
                llm_pos_ids_list.append(llm_pos_ids)
                audio_idx += 1
            elif src_item[idx] == image_token_id:
                grid_t = image_grid_thw[image_idx][0]
                grid_hs = image_grid_thw[:, 1]
                grid_ws = image_grid_thw[:, 2]
                t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
                    grid_ws)
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = image_grid_thw[image_idx].prod() // (
                    spatial_merge_size**2)
                new_src_item.extend([image_token_id] * vision_seqlen)
                image_idx += 1
            elif src_item[idx] == video_token_id and not use_audio_in_video:
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_index = (torch.arange(grid_t) *
                           second_per_grid_ts[video_idx] *
                           tokens_per_second).long()
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
                    grid_ws)
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2)
                new_src_item.extend([video_token_id] * vision_seqlen)
                video_idx += 1
            else:
                # read audio from video
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
                grid_t = video_grid_thw[video_idx][0]
                grid_h = video_grid_thw[video_idx][1]
                grid_w = video_grid_thw[video_idx][2]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
                t_index = (torch.arange(grid_t) *
                           second_per_grid_ts[video_idx] *
                           tokens_per_second).long()
                t_index_split_chunk = cls._split_list_into_ranges(
                    t_index, t_ntoken_per_chunk)
                place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
                pure_audio_len = place_num - 2
                added_audio_len = 0
                audio_llm_pos_ids_list: list[torch.Tensor] = []
                for t_chunk in t_index_split_chunk:
                    vision_ntoken_per_chunk = len(
                        t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                    new_src_item.extend([video_token_id] *
                                        vision_ntoken_per_chunk)
                    vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
                        start_idx, video_idx, spatial_merge_size, t_chunk,
                        grid_hs, grid_ws).split(1, dim=1)
                    llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                    # Audio placed after this vision chunk: at most one
                    # chunk's worth, never more than the audio still left.
                    audio_chunk_size = min(t_ntoken_per_chunk,
                                           pure_audio_len - added_audio_len)
                    new_src_item.extend(audio_chunk_size * [audio_token_id])
                    # Audio ids continue from the previous audio chunk (or
                    # from start_idx for the first one).
                    audio_start_idx = start_idx if len(
                        audio_llm_pos_ids_list
                    ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
                    if audio_chunk_size > 0:
                        audio_llm_pos_ids_list = (
                            torch.arange(audio_chunk_size).expand(3, -1) +
                            audio_start_idx).split(1, dim=1)
                    else:
                        audio_llm_pos_ids_list = []
                    added_audio_len += audio_chunk_size
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                if added_audio_len < pure_audio_len:
                    # Audio that outlasts the video frames goes at the end.
                    new_src_item.extend(
                        (pure_audio_len - added_audio_len) * [audio_token_id])
                    audio_llm_pos_ids_list = (
                        torch.arange(pure_audio_len - added_audio_len).expand(
                            3, -1) + llm_pos_ids_list[-1].max() + 1).split(
                                1, dim=1)
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                audio_idx += 1
                video_idx += 1
            # move to the next token
            idx += len(new_src_item) - new_src_item_len

        llm_positions = torch.cat(llm_pos_ids_list, dim=1)
        # Like _vl_get_input_positions_tensor, return the delta as a Python
        # int (as annotated) instead of a 0-dim tensor, computed from the
        # full (unsliced) positions.
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(src_item)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @staticmethod
    def _get_llm_pos_ids_for_vision(
        start_idx: int,
        vision_idx: int,
        spatial_merge_size: int,
1606
        t_index: list[int],
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
        grid_hs: torch.Tensor,
        grid_ws: torch.Tensor,
    ) -> torch.Tensor:
        llm_pos_ids_list = []
        llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
        llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
        h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
            len(t_index), -1, llm_grid_w).flatten())
        w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
            len(t_index), llm_grid_h, -1).flatten())
        t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
            -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
        _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
        llm_pos_ids_list.append(_llm_pos_ids + start_idx)
        llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
        return llm_pos_ids

    @staticmethod
    def _split_list_into_ranges(lst: torch.Tensor,
1626
1627
                                interval: int) -> list[list[int]]:
        ranges: list[list[int]] = [[]
1628
1629
1630
1631
1632
1633
                                   for _ in range((max(lst) // interval) + 1)]
        for num in lst:
            index = num // interval
            ranges[index].append(num)
        return ranges

1634
1635
1636
1637
1638
    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
1639
    ) -> list[list[int]]:
1640
1641
1642
1643
1644
1645
        return [
            list(
                range(context_len + mrope_position_delta,
                      seq_len + mrope_position_delta)) for _ in range(3)
        ]

1646
    @staticmethod
1647
1648
1649
1650
1651
1652
1653
1654
    def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
                                        mrope_position_delta: int,
                                        context_len: int, num_new_tokens: int):

        values = np.arange(mrope_position_delta + context_len,
                           mrope_position_delta + context_len + num_new_tokens,
                           dtype=out.dtype)
        out[:, out_offset:out_offset + num_new_tokens] = values
1655

1656
1657
1658
1659
1660
    @classmethod
    def omni_get_updates_use_audio_in_video(
        cls,
        thinker_config: PretrainedConfig,
        audio_len: int,
        video_grid_thw: Union[list[int], torch.Tensor],
        video_second_per_grid_t: float,
    ) -> list[int]:
        """Build the prompt-token expansion for a video whose audio is used.

        With `use_audio_in_video`, the video placeholder is expanded into
        interleaved chunks. Each temporal chunk contributes its vision tokens
        followed by at most one chunk's worth of audio tokens. Audio left
        over after the last chunk is appended at the end. The chunking is
        the same as in `_omni_get_input_positions_tensor`.

        <|video_bos|><|VIDEO|><|video_eos|> =>
        <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

        Returns:
            Token ids from the audio-start token through the audio-end token
            (the surrounding video bos/eos tokens are not included).
        """
        vision_config = thinker_config.vision_config
        merge_area = vision_config.spatial_merge_size**2
        tokens_per_second = getattr(vision_config, "tokens_per_second", 25)
        chunk_width = int(tokens_per_second *
                          thinker_config.seconds_per_chunk)
        audio_tok = thinker_config.audio_token_index
        video_tok = thinker_config.video_token_index

        grid_t = video_grid_thw[0]
        grid_h = video_grid_thw[1]
        grid_w = video_grid_thw[2]
        # Temporal id of every frame, bucketed into chunks of chunk_width.
        frame_times = (torch.arange(grid_t) * video_second_per_grid_t *
                       tokens_per_second).long()
        chunks = cls._split_list_into_ranges(frame_times, chunk_width)

        updates = [thinker_config.audio_start_token_id]
        audio_used = 0
        for chunk in chunks:
            updates += [video_tok] * (len(chunk) * grid_h * grid_w //
                                      merge_area)
            take = min(chunk_width, audio_len - audio_used)
            updates += take * [audio_tok]
            audio_used += take
        if audio_used < audio_len:
            updates += (audio_len - audio_used) * [audio_tok]
        updates.append(thinker_config.audio_end_token_id)
        return updates

1708

1709
1710
1711
1712
1713
1714
1715
1716
1717
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
    """Rotary positional embedding for Dual Chunk Attention.

    Five cos/sin caches are precomputed, with
    ``chunk_len = chunk_size - local_size``:

    * ``cos_sin_q_cache``: positions ``0 .. chunk_len - 1``.
    * ``cos_sin_qc_cache``: positions ``chunk_len + i``, clamped to at most
      ``chunk_size``.
    * ``cos_sin_k_cache``: positions ``0 .. max_position_embeddings - 1``
      taken modulo ``chunk_len``.
    * ``cos_sin_qc_no_clamp_cache``: positions ``chunk_len + i``, unclamped.
    * ``cos_sin_q_inter_cache``: positions ``chunk_size + i``.

    ``forward`` rotates the key once and the query five times. It returns
    the five query variants concatenated along the last dimension.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        chunk_size: int,
        local_size: int,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.chunk_size = chunk_size
        self.local_size = local_size
        self.dtype = dtype
        # NOTE(review): caches are materialized on the current CUDA device,
        # so construction requires a CUDA-visible device.
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
        (q_cache, qc_cache, k_cache, qc_no_clamp_cache,
         q_inter_cache) = self._compute_cos_sin_cache()

        # Non-persistent: derived from constructor arguments, so they are
        # not part of the state dict.
        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
        self.register_buffer("cos_sin_qc_no_clamp_cache",
                             qc_no_clamp_cache,
                             persistent=False)
        self.register_buffer("cos_sin_q_inter_cache",
                             q_inter_cache,
                             persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(
        self,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Compute the five cos/sin caches.

        Returns:
            ``(q_cache, qc_cache, k_cache, qc_no_clamp_cache,
            q_inter_cache)``. Each row holds the cos values followed by the
            sin values for one position, cast to ``self.dtype`` on
            ``self.device``.
        """
        inv_freq = self._compute_inv_freq(self.base)
        chunk_len = self.chunk_size - self.local_size
        q_t = torch.arange(chunk_len, dtype=torch.float)
        qc_t = (torch.arange(chunk_len, dtype=torch.float) +
                chunk_len).clamp(max=self.chunk_size)
        k_t = torch.arange(self.max_position_embeddings,
                           dtype=torch.float) % chunk_len
        # count from chunk_len, no clamp(self.chunk_size) restriction
        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
        # count from self.chunk_size for q_inter's rope
        q_inter_t = torch.arange(chunk_len,
                                 dtype=torch.float) + self.chunk_size

        def to_cache(t: torch.Tensor) -> torch.Tensor:
            # Angles for every (position, frequency) pair -> [cos | sin].
            freqs = torch.outer(t, inv_freq)
            return torch.cat((freqs.cos(), freqs.sin()),
                             dim=-1).to(dtype=self.dtype, device=self.device)

        return (to_cache(q_t), to_cache(qc_t), to_cache(k_t),
                to_cache(qc_no_clamp_t), to_cache(q_inter_t))

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the dual-chunk rotary embedding.

        Args:
            positions: Token positions. A 1-D tensor is assumed:
                ``query_inter`` repeats a single cache row
                ``positions.shape[0]`` times.
            query: Query tensor whose last dim is ``num_heads * head_size``.
            key: Key tensor whose last dim is ``num_kv_heads * head_size``.
            offsets: Optional offsets added to ``positions``.

        Returns:
            ``(query, key)``. ``key`` is rotated with ``cos_sin_k_cache``.
            ``query`` concatenates, along the last dim, the variants rotated
            with the q, qc, fixed ``qc[chunk_len - 1]``, qc_no_clamp and
            q_inter caches, in that order.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]
        else:
            query_pass = None
            key_pass = None

        positions_with_offsets = (torch.add(positions, offsets)
                                  if offsets is not None else positions)
        key = self._apply_rotary_embedding(
            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
        chunk_len = self.chunk_size - self.local_size
        # All query caches are indexed by the position within the chunk.
        pos_in_chunk = positions_with_offsets % chunk_len
        query = self._apply_rotary_embedding(
            self.cos_sin_q_cache[pos_in_chunk], query_rot, query_pass)
        query_succ = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[pos_in_chunk], query_rot, query_pass)
        query_inter = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
            query_rot, query_pass)
        query_succ_critical = self._apply_rotary_embedding(
            self.cos_sin_qc_no_clamp_cache[pos_in_chunk], query_rot,
            query_pass)
        query_inter_critical = self._apply_rotary_embedding(
            self.cos_sin_q_inter_cache[pos_in_chunk], query_rot, query_pass)

        # merge query into one tensor to simplify the interfaces
        query = torch.cat((
            query,
            query_succ,
            query_inter,
            query_succ_critical,
            query_inter_critical,
        ),
                          dim=-1)
        return query, key

    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
        """Rotate ``hidden_rot`` by ``cos_sin`` and re-attach ``hidden_pass``.

        Args:
            cos_sin: Gathered cache rows (cos values then sin values along
                the last dim), one row per token.
            hidden_rot: The first ``rotary_dim`` dims of every head.
            hidden_pass: The remaining head dims, or ``None`` when
                ``rotary_dim == head_size``.

        Returns:
            The rotated tensor with the head and head-size dims flattened,
            keeping the token dim even for a single token.
        """
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # Duplicate cos/sin end to end. Unlike cos.repeat(1, 1, 2),
            # torch.cat keeps the rank of 2-D cos/sin, so the result can be
            # concatenated with hidden_pass (partial rotary) and needs no
            # compensating squeeze afterwards.
            cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
            sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin

        if self.rotary_dim < self.head_size:
            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
        else:
            hidden = hidden_rot
        # No squeeze(0): it would drop the token dim when there is only one
        # token.
        return hidden.flatten(-2)

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
        return s


1887
# Module-level cache of the modules built by get_rope, keyed by its
# normalized arguments; calls with an identical configuration return the
# same instance. Besides RotaryEmbedding modules, get_rope also stores
# DualChunkRotaryEmbedding instances here.
_ROPE_DICT: dict[tuple, Union[RotaryEmbedding, DualChunkRotaryEmbedding]] = {}
1888
1889


1890
1891
1892
1893
def _hashable_items(
    config: dict[str, Any],
    exclude: tuple[str, ...] = (),
) -> tuple[tuple[str, Any], ...]:
    """Turn a config dict into a hashable tuple of ``(key, value)`` pairs.

    Top-level list values become tuples so the result can be part of a
    ``_ROPE_DICT`` cache key. Keys listed in ``exclude`` are skipped. Nested
    containers are left untouched.
    """
    return tuple((k, tuple(v) if isinstance(v, list) else v)
                 for k, v in config.items() if k not in exclude)


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> Union[RotaryEmbedding, DualChunkRotaryEmbedding]:
    """Return a (cached) rotary embedding module for a configuration.

    Modules are memoized in ``_ROPE_DICT``. Calls whose normalized
    arguments match return the same instance.

    Args:
        head_size: Size of each attention head.
        rotary_dim: Number of rotated head dims. Multiplied by
            ``partial_rotary_factor`` when that factor is below 1.0.
        max_position: Maximum position. The "yarn" and "deepseek_yarn" types
            instead receive ``original_max_position_embeddings``.
        base: RoPE base.
        is_neox_style: Passed through to the embedding constructors.
        rope_scaling: Optional scaling config. Its ``"rope_type"`` selects
            the implementation; ``None`` or an empty dict gives a plain
            ``RotaryEmbedding``.
        dtype: Cache dtype, defaulting to ``torch.get_default_dtype()``.
        partial_rotary_factor: Fraction of ``rotary_dim`` that is rotated.
        dual_chunk_attention_config: If given, a ``DualChunkRotaryEmbedding``
            is built from its ``chunk_size`` / ``local_size`` entries. This
            takes precedence over ``rope_scaling``.

    Raises:
        KeyError: If a non-empty ``rope_scaling`` is used and lacks
            ``"rope_type"`` or a key required by the selected type.
        ValueError: For an unknown ``rope_type``, or a ``"dynamic"`` config
            with neither ``"alpha"`` nor ``"factor"``.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()

    if rope_scaling is not None:
        rope_scaling_args = _hashable_items(rope_scaling)
    else:
        rope_scaling_args = None

    if dual_chunk_attention_config is not None:
        # sparse_attention_config is never passed to
        # DualChunkRotaryEmbedding, so it is left out of the cache key.
        dual_chunk_attention_args = _hashable_items(
            dual_chunk_attention_config,
            exclude=("sparse_attention_config", ))
    else:
        dual_chunk_attention_args = None

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dual_chunk_attention_args, dtype)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style, dtype,
                                              **extra_kwargs)
    elif not rope_scaling:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get('mixed_b', None)
            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                                   max_position, base,
                                                   is_neox_style,
                                                   scaling_factor, dtype,
                                                   mixed_b)
        elif scaling_type == "dynamic":
            if "alpha" in rope_scaling:
                scaling_alpha = rope_scaling["alpha"]
                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_alpha, dtype)
            elif "factor" in rope_scaling:
                scaling_factor = rope_scaling["factor"]
                rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor, dtype)
            else:
                raise ValueError("Dynamic rope scaling must contain either "
                                 "'alpha' or 'factor' field")
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb