rotary.py 17.2 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
import os
import torch
from torch import nn
4
from loguru import logger
Nicolas Patry's avatar
Nicolas Patry committed
5

Nicolas Patry's avatar
Nicolas Patry committed
6
from text_generation_server.utils.import_utils import SYSTEM
Nicolas Patry's avatar
Nicolas Patry committed
7
8
9
10
11

if SYSTEM == "cuda":
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb
elif SYSTEM == "rocm":
fxmarty's avatar
fxmarty committed
12
    from vllm._C import ops
Nicolas Patry's avatar
Nicolas Patry committed
13
elif SYSTEM == "ipex":
Wang, Yi's avatar
Wang, Yi committed
14
    import intel_extension_for_pytorch as ipex
Nicolas Patry's avatar
Nicolas Patry committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71


def _create_inv_freq(dim, base, device):
    """Return the inverse rotary frequencies as a float32 tensor.

    Element ``i`` of the result is ``1 / base ** (2 * i / dim)`` for every
    even index ``2 * i`` in ``range(0, dim, 2)``, so the result holds one
    entry per pair of rotated dimensions.

    Args:
        dim: Upper bound of the even-index range (the rotary dimension).
        base: Base of the geometric progression of frequencies.
        device: Device on which the tensor is created.
    """
    exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
    return 1.0 / (base**exponents)


def _get_rope_config(config):
    """Return the RoPE scaling settings, letting the environment override the config.

    When the ``ROPE_SCALING`` environment variable is set, a dict with the keys
    ``"type"`` (its value) and ``"factor"`` (``ROPE_FACTOR`` parsed as float) is
    returned; a missing ``ROPE_FACTOR`` then raises ``KeyError``. Otherwise the
    ``rope_scaling`` attribute of ``config`` is returned, or ``None`` when the
    attribute does not exist.
    """
    scaling_type = os.getenv("ROPE_SCALING", None)
    if scaling_type is None:
        return getattr(config, "rope_scaling", None)
    return {
        "type": scaling_type,
        "factor": float(os.environ["ROPE_FACTOR"]),
    }


class PositionRotaryEmbedding(nn.Module):
    """Rotary position embedding backed by a lazily built cos/sin lookup table.

    The table is indexed by position and is rebuilt by
    ``_update_cos_sin_cache`` whenever a longer sequence, another device or
    another dtype is requested. Subclasses override ``_update_cos_sin_cache``
    to implement other RoPE scaling schemes.
    """

    def __init__(self, inv_freq, scaling_factor):
        """Store the inverse frequencies and start with an empty cache.

        Args:
            inv_freq: Tensor of inverse frequencies (see ``_create_inv_freq``).
            scaling_factor: Divisor applied to the positions when the table is
                built, or ``None`` for no scaling.
        """
        super().__init__()
        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.scaling_factor = scaling_factor
        self.dynamic_args = None

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        """Apply the rotary embedding to ``query`` and ``key`` with the backend kernel.

        The kernel is selected by the module-level ``SYSTEM`` value. Nothing is
        returned; the ROCm kernel is documented below as in-place, and the
        other kernels are presumably in-place as well (the CUDA call passes
        the input slices as outputs) -- confirm against the backend docs.

        Raises:
            ValueError: If ``SYSTEM`` is none of "cuda", "rocm" or "ipex".
        """
        # Such controlflows may add some overhead.
        if SYSTEM == "cuda":
            rotary_dim = cos.shape[-1]
            q1 = query[..., :rotary_dim]
            q2 = query[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)

            k1 = key[..., :rotary_dim]
            k2 = key[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
        elif SYSTEM == "rocm":
            # NOTE: On RoCm systems, we use a ROPE implementation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
            # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

            head_size = query.shape[-1]

            # Inplace operation, updating query and key.
            ops.rotary_embedding(query, key, head_size, cos, sin, True)
        elif SYSTEM == "ipex":
            # Note the argument order here: sin comes before cos.
            ipex.llm.functional.rotary_embedding(
                query, key, sin, cos, query.size(-1), True
            )
        else:
            raise ValueError(
                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )

    @classmethod
    def static(cls, config, dim, base, device):
        """Build an embedding from ``dim``/``base`` instead of stored weights.

        The RoPE scaling settings come from ``_get_rope_config(config)`` and
        select the returned class: "linear" (or no scaling) yields ``cls``,
        "dynamic" a ``DynamicPositionRotaryEmbedding``, "yarn" a
        ``YarnPositionRotaryEmbedding`` and "su"/"longrope" a
        ``SuRotaryEmbedding``.

        Raises:
            NotImplementedError: For any other scaling type.
        """
        inv_freq = _create_inv_freq(dim, base, device)
        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            rope_type = rope_scaling["type"]
            if rope_type == "linear":
                # Linear scaling is applied in `_update_cos_sin_cache` by
                # dividing the positions by the factor. The factor therefore
                # has to be forwarded here, exactly as `load` does; leaving it
                # as None silently disabled linear scaling.
                scaling_factor = rope_scaling["factor"]
            elif rope_type == "dynamic":
                scaling_factor = rope_scaling["factor"]
                return DynamicPositionRotaryEmbedding(
                    dim=dim,
                    max_position_embeddings=config.max_position_embeddings,
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_type == "yarn":
                scaling_factor = rope_scaling["factor"]
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            elif rope_type in ["su", "longrope"]:
                # Unscaled frequencies base ** (2i / dim), shared by the short
                # and the long variant; each is rescaled by its own per-index
                # factor list.
                base_freqs = base ** (
                    torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
                )
                short_factor = torch.tensor(
                    rope_scaling["short_factor"], dtype=torch.float32, device=device
                )
                short_inv_freq = 1.0 / (short_factor * base_freqs)
                long_factor = torch.tensor(
                    rope_scaling["long_factor"], dtype=torch.float32, device=device
                )
                long_inv_freq = 1.0 / (long_factor * base_freqs)

                original_max_position_embeddings = (
                    config.original_max_position_embeddings
                )
                max_position_embeddings = config.max_position_embeddings
                if max_position_embeddings <= original_max_position_embeddings:
                    scaling_factor = 1.0
                else:
                    scale = max_position_embeddings / original_max_position_embeddings
                    scaling_factor = math.sqrt(
                        1 + math.log(scale) / math.log(original_max_position_embeddings)
                    )

                return SuRotaryEmbedding(
                    short_inv_freq=short_inv_freq,
                    long_inv_freq=long_inv_freq,
                    scaling_factor=scaling_factor,
                    original_max_position_embeddings=original_max_position_embeddings,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor)

    @classmethod
    def load(cls, config, prefix, weights):
        """Build an embedding from the ``{prefix}.inv_freq`` tensor in ``weights``.

        The scaled variants ("dynamic", "yarn") rebuild their frequencies from
        a base hard-coded to 10000.0; only the length of the loaded tensor is
        used for them.

        Raises:
            NotImplementedError: For scaling types other than "linear",
                "dynamic" and "yarn".
        """
        # XXX: Always load this in float32 !
        dtype = weights.dtype
        weights.dtype = torch.float32
        try:
            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
        finally:
            # Restore the caller's dtype even when the tensor cannot be read.
            weights.dtype = dtype

        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            scaling_factor = rope_scaling["factor"]
            if rope_scaling["type"] == "linear":
                pass
            elif rope_scaling["type"] == "dynamic":
                return DynamicPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=config.max_position_embeddings,
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_scaling["type"] == "yarn":
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """Rebuild the cos/sin tables when the cached ones cannot serve the request.

        The tables get ``seqlen`` rows and one column per entry of
        ``inv_freq``; positions are divided by ``scaling_factor`` when it is
        not ``None``.
        """
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance).
        # The explicit None check covers a first call with seqlen == 0, which
        # would otherwise read `.device` on the not-yet-built cache.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            if self.scaling_factor is not None:
                t /= self.scaling_factor
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        """
        Return cos and sin for the asked position ids

        The cache is first grown to at least ``max_s`` rows, then the rows
        listed in ``position_ids`` are selected and a singleton dimension is
        inserted at index 1 of both results.
        """
        if SYSTEM == "rocm":
            # For RoCm, we always use float cos/sin to avoid a cast.
            # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
            # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
            dtype = torch.float32

        self._update_cos_sin_cache(dtype, position_ids.device, max_s)

        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)

        # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
        return cos.unsqueeze(1), sin.unsqueeze(1)


class SuRotaryEmbedding(PositionRotaryEmbedding):
    """Rotary embedding using separate inverse frequencies for short and long positions.

    Positions below ``original_max_position_embeddings`` use
    ``short_inv_freq``, the remaining positions use ``long_inv_freq``; both
    cos and sin tables are multiplied by ``scaling_factor``.
    """

    def __init__(
        self,
        short_inv_freq,
        long_inv_freq,
        scaling_factor,
        original_max_position_embeddings,
    ):
        # Deliberately start the MRO lookup *after* PositionRotaryEmbedding:
        # its __init__ expects a single `inv_freq`, which this class does not
        # have, so the cache attributes are initialised by hand below.
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """Rebuild the cos/sin tables when the cached ones cannot serve the request."""
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance).
        # The explicit None check covers a first call with seqlen == 0, which
        # would otherwise read `.device` on the not-yet-built cache.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen

            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
            # Rows [0, original_max) come from the short frequencies ...
            short_freqs = torch.outer(
                t[: self.original_max_position_embeddings],
                self.short_inv_freq.to(device=t.device),
            )
            # ... and rows [original_max, seqlen) from the long ones (this
            # slice is empty while seqlen <= original_max).
            long_freqs = torch.outer(
                t[self.original_max_position_embeddings :],
                self.long_inv_freq.to(device=t.device),
            )

            freqs = torch.cat([short_freqs, long_freqs])

            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
Nicolas Patry's avatar
Nicolas Patry committed
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357


class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    """Rotary embedding that enlarges its base once sequences exceed the trained length.

    While ``seqlen <= max_position_embeddings`` the frequencies are left
    untouched; beyond that, ``inv_freq`` is recomputed from a larger base
    derived from ``scaling_factor`` and the requested length.
    """

    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
        inv_freq = _create_inv_freq(dim, base, device)
        super().__init__(inv_freq, scaling_factor)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """Rebuild the cos/sin tables when the cached ones cannot serve the request."""
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance).
        # The explicit None check covers a first call with seqlen == 0, which
        # would otherwise read `.device` on the not-yet-built cache.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings:
                # Grow the base with the ratio seqlen / max_position_embeddings.
                # NOTE(review): the exponent divides by (dim - 2), so dim == 2
                # would raise ZeroDivisionError here.
                newbase = self.base * (
                    (self.scaling_factor * seqlen / self.max_position_embeddings)
                    - (self.scaling_factor - 1)
                ) ** (self.dim / (self.dim - 2))
                self.inv_freq = _create_inv_freq(
                    self.dim, newbase, self.inv_freq.device
                )
            self._seq_len_cached = seqlen
            # Unlike the parent class, positions are NOT divided by
            # scaling_factor here; the factor only enters through `newbase`.
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


# Inverse dim formula to find dim based on number of rotations
import math


def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Return the (fractional) dimension index matching a number of rotations.

    Computes ``dim * ln(max_position_embeddings / (2 * pi * num_rotations))
    / (2 * ln(base))``.
    """
    log_ratio = math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    scaled = dim * log_ratio
    return scaled / (2 * math.log(base))


def find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Find dim range bounds based on rotations.

    Returns a ``(low, high)`` tuple: the floor of the correction dim for
    ``low_rot`` and the ceiling of the one for ``high_rot``, clamped to
    ``[0, dim - 1]``.
    """
    lower_bound = math.floor(
        find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    upper_bound = math.ceil(
        find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    # Clamp values just in case
    clamped_low = max(lower_bound, 0)
    clamped_high = min(upper_bound, dim - 1)
    return clamped_low, clamped_high


def linear_ramp_mask(min, max, dim):
    """Return a float32 ramp of length ``dim`` rising linearly from 0 to 1.

    Index ``i`` maps to ``(i - min) / (max - min)`` clamped to ``[0, 1]``.
    When ``min == max`` the upper bound is widened by 0.001 so the division
    is defined.
    """
    # Prevent singularity
    upper = max + 0.001 if min == max else max

    positions = torch.arange(dim, dtype=torch.float32)
    ramp = (positions - min) / (upper - min)
    return ramp.clamp(0, 1)


358
def get_mscale(scale: float = 1.0, mscale: float = 1.0):
    """Return the magnitude-scaling term ``0.1 * mscale * ln(scale) + 1``.

    For ``scale <= 1`` no correction is applied and 1.0 is returned.
    """
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
Nicolas Patry's avatar
Nicolas Patry committed
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376


class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
    """Rotary embedding blending interpolated and extrapolated frequencies (YaRN).

    ``inv_freq`` is a per-index mix of the original frequencies and the
    frequencies divided by ``scaling_factor``; the cos/sin tables are
    additionally multiplied by ``self.mscale``.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings,
        base,
        device,
        scaling_factor,
        *,
        extrapolation_factor,
        attn_factor,
        beta_fast,
        beta_slow,
        mscale: float,
        mscale_all_dim: float,
    ):
        inv_freq = _create_inv_freq(dim, base, device)
        super().__init__(inv_freq, scaling_factor)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale_all_dim = mscale_all_dim
        self.scaling_factor = scaling_factor
        self.mscale = float(
            get_mscale(self.scaling_factor, mscale)
            / get_mscale(self.scaling_factor, mscale_all_dim)
            * self.attn_factor
        )  # Get n-d magnitude scaling corrected for interpolation

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """Rebuild ``inv_freq`` and the cos/sin tables when the cache cannot serve the request."""
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance).
        # The explicit None check covers a first call with seqlen == 0, which
        # would otherwise read `.device` on the not-yet-built cache.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            # The frequencies are recomputed on every rebuild. The previous
            # guard `seqlen > self.max_position_embeddings or True` was always
            # true, so it is spelled out here as unconditional code.
            inv_freq_extrapolation = _create_inv_freq(
                self.dim, self.base, self.inv_freq.device
            )
            freqs = 1.0 / inv_freq_extrapolation
            inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
            low, high = find_correction_range(
                self.beta_fast,
                self.beta_slow,
                self.dim,
                self.base,
                self.max_position_embeddings,
            )

            inv_freq_mask = (
                1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
            ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
            # Mask == 1 keeps the original frequency, mask == 0 uses the
            # interpolated (scaled-down) one.
            inv_freq = (
                inv_freq_interpolation * (1 - inv_freq_mask)
                + inv_freq_extrapolation * inv_freq_mask
            )

            self.inv_freq = inv_freq

            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)