import math
import os
import torch
import torch.distributed

from torch import nn
from torch.nn import functional as F
from typing import List, Optional
from loguru import logger
from functools import lru_cache

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params, Params4bit

except ImportError:
    HAS_BITS_AND_BYTES = False

from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear

HAS_AWQ = True
try:
    from text_generation_server.utils.awq.quantize.qmodule import WQLinear
except ImportError:
    HAS_AWQ = False

try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        from text_generation_server.utils.gptq.exllama import Ex4bitLinear

        HAS_EXLLAMA = True
    except ImportError:
        pass


HAS_EETQ = False
try:
    from EETQ import quant_weights, w8_a16_gemm

    HAS_EETQ = True
except ImportError:
    pass


# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = None
    return ln


@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = nn.Parameter(weight)
    conv2d.bias = nn.Parameter(bias)
    return conv2d


@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = nn.Parameter(weight)
    conv2d.bias = None
    return conv2d


torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
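
# Usage sketch for the monkey-patched loaders above, with a hypothetical
# prefix and a `weights` handle exposing `get_tensor` (as used in this file):
#
#     ln = nn.LayerNorm.load(prefix="model.final_norm", weights=weights, eps=1e-5)
#     conv = nn.Conv2d.load(
#         prefix="vision.patch_embed.proj", weights=weights,
#         in_channels=3, out_channels=768, kernel_size=14, stride=14,
#     )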


class FastLinear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        if bias is not None:
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)
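
# Minimal sketch of direct use (shapes illustrative only):
#
#     w = torch.randn(out_features, in_features, dtype=torch.float16)
#     layer = FastLinear(w, bias=None)
#     y = layer(x)  # same as F.linear(x, w)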


class EETQLinear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        device = weight.device
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)
        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = w8_a16_gemm(input, self.weight, self.scale)
        output = output + self.bias if self.bias is not None else output
        return output
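
# EETQLinear quantizes once at construction: the fp16 weight is transposed,
# quantized to int8 on CPU, and moved back to the original CUDA device; at
# runtime `w8_a16_gemm` performs the int8-weight x fp16-activation matmul.
# A minimal sketch, assuming EETQ is installed and `w` is an fp16 CUDA tensor:
#
#     layer = EETQLinear(w, bias=None)
#     y = layer(x)  # x and y stay fp16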


class Linear8bitLt(nn.Module):
    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
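
# Note on the LLM.int8() state machine above: the first forward pass converts
# the row-major int8 weight (state.CB) into the GPU-specific tiled layout
# (state.CxB); the branch at the end of `forward` then drops CB so only one
# copy of the weight stays resident.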


class Linear4bit(nn.Module):
    def __init__(self, weight, bias, quant_type):
        super().__init__()
        self.weight = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        self.compute_dtype = None
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Params4bit, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            print(
                "4-bit quantization state not initialized. Please call .cuda() or .to(device) on the Linear4bit layer first."
            )
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )

        out = out.to(inp_dtype)

        return out
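
# A minimal construction sketch: `quant_type` picks the 4-bit format, "fp4"
# or "nf4" (the normal-float variant used by QLoRA):
#
#     layer = Linear4bit(w, bias=None, quant_type="nf4")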


@lru_cache(1)
def warn_deprecate_bnb():
    logger.warning(
        "Bitsandbytes 8bit is deprecated; `eetq` is a drop-in replacement with much better performance"
    )


def get_linear(weight, bias, quantize):
    if quantize is None:
        linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        if HAS_EETQ:
            linear = EETQLinear(weight, bias)
        else:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "bitsandbytes":
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            linear.bias = nn.Parameter(bias)
    elif quantize == "bitsandbytes-fp4":
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4",
        )
    elif quantize == "bitsandbytes-nf4":
        linear = Linear4bit(
            weight,
            bias,
            quant_type="nf4",
        )
    elif quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        except Exception:
            raise NotImplementedError(
                f"The passed weight is not `gptq` compatible, loader needs to be updated."
            )

        if use_exllama:
            linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
        else:
            linear = QuantLinear(
                qweight,
                qzeros,
                scales,
                g_idx,
                bias,
                bits,
                groupsize,
            )
    elif quantize == "awq":
        try:
            qweight, qzeros, scales, _, bits, groupsize, _ = weight
        except Exception:
            raise NotImplementedError(
                f"The passed weight is not `awq` compatible, loader needs to be updated."
            )
        linear = WQLinear(
            w_bit=bits,
            group_size=groupsize,
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            bias=bias is not None,
        )
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear
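
# `get_linear` is the single dispatch point from a raw weight (or a quantized
# weight tuple) to a concrete layer. Hypothetical calls for each family:
#
#     get_linear(w, bias, quantize=None)     # -> FastLinear
#     get_linear(w, bias, quantize="eetq")   # -> EETQLinear (needs EETQ)
#     get_linear(t, bias, quantize="gptq")   # -> Ex4bitLinear or QuantLinear,
#                                            #    where t is the 7-tuple above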


class SuperLayer(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear.forward(x)


class TensorParallelHead(SuperLayer):
    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        # GPTQ, AWQ, and EETQ don't quantize heads (nor embeddings)
        if config.quantize in ["gptq", "awq", "eetq"]:
            quantize = None
        else:
            quantize = config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output
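
# Shape walk-through for the fast path in `forward`, as a sketch with
# world_size=2, a (1, hidden) input and a vocab shard of out_dim rows: each
# rank writes its (1, out_dim) logits into `local_out`, and
# all_gather_into_tensor assembles the (1, 2 * out_dim) `world_out` directly,
# skipping the list-based all_gather + cat of the generic path below it.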


class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        weight = weights.get_multi_weights_col(
            prefixes, quantize=config.quantize, dim=dim
        )

        if bias:
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            bias = torch.cat(b, dim=dim)
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)


class TensorParallelRowLinear(SuperLayer):
    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        if bias and weights.process_group.rank() == 0:
            # Bias is only loaded on rank 0 so the all_reduce in forward adds it exactly once
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias, config.quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
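
# Column and row parallel layers are meant to be chained (Megatron-style):
# TensorParallelColumnLinear shards the output dimension, so each rank keeps
# its activations local, and the following TensorParallelRowLinear shards the
# input dimension and finishes with the single all_reduce above.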


class TensorParallelEmbedding(nn.Module):
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        self.null_idx = block_size
        self.process_group = weights.process_group
        self.reduce = reduce

        """Additional 0 entry used for masking"""
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
        # translate for [0, self.max_id - self.min_id[
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
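
# Worked example of the masking above: with num_embeddings=10 and
# world_size=2, rank 0 owns ids [0, 5) and rank 1 owns [5, 10). Looking up
# id 7 on rank 0 remaps it to null_idx=5, the zero row added by F.pad, so
# rank 0 contributes nothing and the all_reduce sums in rank 1's embedding.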


try:
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual

except ImportError:
    pass


try:
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb

    def _create_inv_freq(dim, base, device):
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        return inv_freq

    def _get_rope_config(config):
        if os.getenv("ROPE_SCALING", None) is not None:
            rope_scaling = {
                "type": os.environ["ROPE_SCALING"],
                "factor": float(os.environ["ROPE_FACTOR"]),
            }
            return rope_scaling
        return getattr(config, "rope_scaling", None)
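
    # For example, launching with ROPE_SCALING=linear ROPE_FACTOR=2.0 in the
    # environment overrides whatever `rope_scaling` the model config declares.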

    class PositionRotaryEmbedding(nn.Module):
        def __init__(self, inv_freq, scaling_factor):
            super().__init__()
            self.inv_freq = inv_freq
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
            self.scaling_factor = scaling_factor
            self.dynamic_args = None

        @classmethod
        def static(cls, config, dim, base, device):
            inv_freq = _create_inv_freq(dim, base, device)
            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=dim,
                        max_position_embeddings=config.max_position_embeddings,
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling["original_max_position_embeddings"],
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1

                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        @classmethod
        def load(cls, config, prefix, weights):
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            weights.dtype = dtype

            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=config.max_position_embeddings,
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling["original_max_position_embeddings"],
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1

                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                if self.scaling_factor is not None:
                    t /= self.scaling_factor
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """
            Return cos and sin for the asked position ids
            """

            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            return cos.unsqueeze(1), sin.unsqueeze(1)

        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            rotary_dim = cos.shape[-1]
            x1 = x[..., :rotary_dim]
            x2 = x[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
            return x

    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    newbase = self.base * (
                        (self.scaling_factor * seqlen / self.max_position_embeddings)
                        - (self.scaling_factor - 1)
                    ) ** (self.dim / (self.dim - 2))
                    self.inv_freq = _create_inv_freq(
                        self.dim, newbase, self.inv_freq.device
                    )
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
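
    # Worked example of the dynamic NTK rescaling above: with dim=128,
    # base=10000, max_position_embeddings=4096 and scaling_factor=2.0, a
    # seqlen of 8192 gives
    #     newbase = 10000 * (2.0 * 8192 / 4096 - 1) ** (128 / 126), about 30500
    # i.e. every rotary frequency is lowered instead of hard-truncating context.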


    # Inverse dim formula: find the dimension that a given number of rotations
    # corresponds to
    def find_correction_dim(
        num_rotations, dim, base=10000, max_position_embeddings=2048
    ):
        return (
            dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
        ) / (2 * math.log(base))

    # Find dim range bounds based on rotations
    def find_correction_range(
        low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
    ):
        low = math.floor(
            find_correction_dim(low_rot, dim, base, max_position_embeddings)
        )
        high = math.ceil(
            find_correction_dim(high_rot, dim, base, max_position_embeddings)
        )
        return max(low, 0), min(high, dim - 1)  # Clamp values just in case

    def linear_ramp_mask(min, max, dim):
        if min == max:
            max += 0.001  # Prevent singularity

        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    def get_mscale(scale=1):
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0
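
    # Worked example: get_mscale(16) = 0.1 * math.log(16) + 1.0, about 1.277;
    # the cos/sin caches below are multiplied by this factor whenever YaRN
    # scaling is active.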

    class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
        def __init__(
            self, dim, max_position_embeddings, base, device, scaling_factor,
            *, extrapolation_factor, attn_factor, beta_fast, beta_slow,
        ):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base
            self.extrapolation_factor = extrapolation_factor
            self.attn_factor = attn_factor
            self.beta_fast = beta_fast
            self.beta_slow = beta_slow
            # n-dimensional magnitude scaling corrected for interpolation
            self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    inv_freq_extrapolation = _create_inv_freq(
                        self.dim, self.base, self.inv_freq.device
                    )
                    freqs = 1.0 / inv_freq_extrapolation
                    inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
                    low, high = find_correction_range(
                        self.beta_fast, self.beta_slow, self.dim, self.base,
                        self.max_position_embeddings,
                    )
                    # n-dimensional rotational scaling corrected for extrapolation
                    inv_freq_mask = (
                        1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                    ) * self.extrapolation_factor
                    inv_freq = (
                        inv_freq_interpolation * (1 - inv_freq_mask)
                        + inv_freq_extrapolation * inv_freq_mask
                    )

                    self.inv_freq = inv_freq
                    # n-dimensional magnitude scaling corrected for interpolation
                    self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)

                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)

except ImportError:
    pass