import os
import torch
import torch.distributed

from torch import nn
from torch.nn import functional as F
from typing import List, Optional

from loguru import logger
from functools import lru_cache

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params, Params4bit

except ImportError:
    HAS_BITS_AND_BYTES = False

from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear

HAS_AWQ = True
try:
    from text_generation_server.utils.awq.quantize.qmodule import WQLinear
except ImportError:
    HAS_AWQ = False

try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        from text_generation_server.utils.gptq.exllama import Ex4bitLinear

        HAS_EXLLAMA = True
    except ImportError:
        pass

HAS_EETQ = False
try:
    from EETQ import quant_weights, w8_a16_gemm

    HAS_EETQ = True
except ImportError:
    pass

# Monkey patching
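# These loaders are attached to the stock torch.nn modules below
# (torch.nn.LayerNorm.load, torch.nn.Conv2d.load, ...) so layers can be
# created on the meta device and filled directly from checkpoint tensors.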
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = None
    return ln

@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = nn.Parameter(weight)
    conv2d.bias = nn.Parameter(bias)
    return conv2d


@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = nn.Parameter(weight)
    conv2d.bias = None
    return conv2d

torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias


class FastLinear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        if bias is not None:
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


class EETQLinear(nn.Module):
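    """Weight-only int8 linear layer backed by EETQ's w8_a16_gemm kernel.

    The weight is quantized to int8 (plus scales) once at load time; the
    matmul then runs with 16-bit activations against the int8 weight.
    """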
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        device = weight.device
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)
        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = w8_a16_gemm(input, self.weight, self.scale)
        output = output + self.bias if self.bias is not None else output
        return output


class Linear8bitLt(nn.Module):
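    """8-bit linear layer using bitsandbytes' LLM.int8() matmul (Int8Params + MatmulLtState)."""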
    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out


class Linear4bit(nn.Module):
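    """4-bit (fp4 or nf4) linear layer backed by bitsandbytes' Params4bit and matmul_4bit."""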
    def __init__(self, weight, bias, quant_type):
        super().__init__()
        self.weight = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        self.compute_dtype = None
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Params4bit, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            print(
                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
            )
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )

        out = out.to(inp_dtype)

        return out


@lru_cache(1)
def warn_deprecate_bnb():
    logger.warning(
        "Bitsandbytes 8bit is deprecated; `eetq` is a drop-in replacement with much better performance"
    )


def get_linear(weight, bias, quantize):
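    """Build the linear implementation that matches the requested `quantize` mode.

    A minimal usage sketch (the shapes and names below are illustrative only):

        weight = torch.randn(out_features, in_features, dtype=torch.float16)
        layer = get_linear(weight, bias=None, quantize=None)  # plain FastLinear
    """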
    if quantize is None:
        linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        if HAS_EETQ:
            linear = EETQLinear(weight, bias)
        else:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "bitsandbytes":
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            linear.bias = nn.Parameter(bias)
    elif quantize == "bitsandbytes-fp4":
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4",
        )
    elif quantize == "bitsandbytes-nf4":
        linear = Linear4bit(
            weight,
            bias,
            quant_type="nf4",
        )
    elif quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        except Exception:
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            )

        if use_exllama:
            linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
        else:
            linear = QuantLinear(
                qweight,
                qzeros,
                scales,
                g_idx,
                bias,
                bits,
                groupsize,
            )
    elif quantize == "awq":
        try:
            qweight, qzeros, scales, _, bits, groupsize, _ = weight
        except Exception:
            raise NotImplementedError(
                "The passed weight is not `awq` compatible, loader needs to be updated."
            )
        linear = WQLinear(
            w_bit=bits,
            group_size=groupsize,
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            bias=bias is not None,
        )
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear


class SuperLayer(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear.forward(x)


class TensorParallelHead(SuperLayer):
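    """Sharded language-model head: each rank computes logits for its slice of
    the vocabulary and the slices are all-gathered, unless the vocabulary could
    not be sharded evenly (then every rank holds the full head)."""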
    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        # GPTQ, AWQ, and EETQ don't quantize heads (nor embeddings)
        if config.quantize in ["gptq", "awq", "eetq"]:
            quantize = None
        else:
            quantize = config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


class TensorParallelColumnLinear(SuperLayer):
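    """Column-parallel linear: every rank holds a slice of the output features,
    so the forward pass needs no communication."""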
    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        weight = weights.get_multi_weights_col(
            prefixes, quantize=config.quantize, dim=dim
        )

        if bias:
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            bias = torch.cat(b, dim=dim)
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)

class TensorParallelRowLinear(SuperLayer):
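    """Row-parallel linear: every rank multiplies against its shard of the input
    features and the partial results are summed with an all-reduce."""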
    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        if bias and weights.process_group.rank() == 0:
            # Load the bias on rank 0 only so it is added exactly once;
            # the other shards contribute no bias before the all-reduce.
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias, config.quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


class TensorParallelEmbedding(nn.Module):
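    """Embedding sharded along the vocabulary dimension.

    Each rank stores a contiguous block of rows; ids that fall outside the
    local block hit a padded zero row, and the partial lookups are summed
    with an all-reduce unless `reduce=False`.
    """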
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        self.null_idx = block_size
        self.process_group = weights.process_group
        self.reduce = reduce

        """Additional 0 entry used for masking"""
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
        # translate for [0, self.max_id - self.min_id[
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


try:
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
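        """LayerNorm fused with the residual addition via `dropout_layer_norm`.

        Returns the normalized hidden states together with the updated residual;
        hidden sizes above 8192 fall back to the stock LayerNorm kernel.
        """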
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual

except ImportError:
    pass


try:
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb

    def _create_inv_freq(dim, base, device):
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        return inv_freq

    def _get_rope_config(config):
        if os.getenv("ROPE_SCALING", None) is not None:
            rope_scaling = {
                "type": os.environ["ROPE_SCALING"],
                "factor": float(os.environ["ROPE_FACTOR"]),
            }
            return rope_scaling
        return getattr(config, "rope_scaling", None)

    class PositionRotaryEmbedding(nn.Module):
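        """Rotary position embeddings with cached cos/sin tables.

        `get_cos_sin` grows the cache on demand and `forward` applies the
        rotation in place through the fused `rotary_emb.apply_rotary` kernel.
        """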
        def __init__(self, inv_freq, scaling_factor):
            super().__init__()
            self.inv_freq = inv_freq
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
            self.scaling_factor = scaling_factor
            self.dynamic_args = None

        @classmethod
        def static(cls, config, dim, base, device):
            inv_freq = _create_inv_freq(dim, base, device)
            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=dim,
                        max_position_embeddings=config.max_position_embeddings,
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        @classmethod
        def load(cls, config, prefix, weights):
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            weights.dtype = dtype

            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=config.max_position_embeddings,
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                if self.scaling_factor is not None:
                    t /= self.scaling_factor
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """
            Return cos and sin for the asked position ids
            """

            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            return cos.unsqueeze(1), sin.unsqueeze(1)

        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            rotary_dim = cos.shape[-1]
            x1 = x[..., :rotary_dim]
            x2 = x[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
            return x

    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
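        """Dynamic NTK-style rotary scaling: once the requested sequence length
        exceeds `max_position_embeddings`, the rotary base is rescaled and
        `inv_freq` recomputed instead of interpolating positions."""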
        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    newbase = self.base * (
                        (self.scaling_factor * seqlen / self.max_position_embeddings)
                        - (self.scaling_factor - 1)
                    ) ** (self.dim / (self.dim - 2))
                    self.inv_freq = _create_inv_freq(
                        self.dim, newbase, self.inv_freq.device
                    )
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

except ImportError:
    pass