layers.py 28.7 KB
Newer Older
1
import os
2
import torch
3
import torch.distributed
4
5

from torch import nn
6
from torch.nn import functional as F
7
from typing import List
8
9
from loguru import logger
from functools import lru_cache
10
11
12

# Optional bitsandbytes support (8-bit / 4-bit linear layers). The flag records
# whether the import succeeded, so code can check it before touching `bnb`.
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params, Params4bit
except ImportError:
    HAS_BITS_AND_BYTES = False
else:
    HAS_BITS_AND_BYTES = True

from accelerate import init_empty_weights

21
from text_generation_server.utils.gptq.quant_linear import QuantLinear
22

23
24

# Optional AWQ kernels; `HAS_AWQ` tells whether `WQLinear` could be imported.
try:
    from text_generation_server.utils.awq.quantize.qmodule import WQLinear
except ImportError:
    HAS_AWQ = False
else:
    HAS_AWQ = True

# Major CUDA compute capability of the current device. If it cannot be queried
# (for instance when no CUDA device is usable) fall back to 1.
try:
    _capability = torch.cuda.get_device_capability()
    major, _minor = _capability
except Exception:
    major = 1

# Exllama GPTQ kernels. `HAS_EXLLAMA` stays False when they are unavailable or
# disabled; otherwise it is the string "1" or "2" naming the imported version.
# Both versions are imported under the same names (`ExllamaQuantLinear`,
# `create_exllama_buffers`, `set_device`).
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8

# Version 2 is the default; any other EXLLAMA_VERSION value selects version 1.
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
    logger.warning(
        "Disabling exllama v2 and using v1 instead because there are issues when sharding"
    )
    V2 = False

if os.getenv("DISABLE_EXLLAMA") != "True" and CAN_EXLLAMA:
    try:
        if V2:
            from text_generation_server.utils.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )
        else:
            from text_generation_server.utils.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )
    except ImportError:
        # Kernels not installed: keep HAS_EXLLAMA = False.
        pass
    else:
        HAS_EXLLAMA = "2" if V2 else "1"

from typing import Optional
63

64
65
66
# Optional EETQ int8 weight-only kernels; `HAS_EETQ` records import success.
try:
    from EETQ import quant_weights, w8_a16_gemm
except ImportError:
    HAS_EETQ = False
else:
    HAS_EETQ = True

# Monkey patching: the loaders below are attached as classmethods to torch.nn
# layer classes so layers can be built directly from checkpoint tensors.
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Build a ``cls`` layer norm from ``{prefix}.weight`` and ``{prefix}.bias``.

    The module is constructed inside ``init_empty_weights`` and the checkpoint
    tensors are then installed as its parameters.
    """
    gamma = weights.get_tensor(f"{prefix}.weight")
    beta = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        layer = cls(gamma.shape, eps=eps)

    layer.weight = nn.Parameter(gamma)
    layer.bias = nn.Parameter(beta)
    return layer

@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Build a bias-free ``cls`` layer norm from ``{prefix}.weight`` only."""
    gamma = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        layer = cls(gamma.shape, eps=eps)

    layer.weight = nn.Parameter(gamma)
    # This variant never reads a bias tensor, so the bias is disabled.
    layer.bias = None
    return layer

@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Build a ``cls`` convolution from ``{prefix}.weight`` and ``{prefix}.bias``.

    The geometry arguments are forwarded to the constructor, which runs inside
    ``init_empty_weights``; the checkpoint tensors then become the parameters.
    """
    kernel = weights.get_tensor(f"{prefix}.weight")
    offset = weights.get_tensor(f"{prefix}.bias")
    geometry = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )
    with init_empty_weights():
        conv = cls(**geometry)

    conv.weight = nn.Parameter(kernel)
    conv.bias = nn.Parameter(offset)
    return conv

@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    """Build a bias-free ``cls`` convolution from ``{prefix}.weight`` only."""
    kernel = weights.get_tensor(f"{prefix}.weight")
    geometry = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )
    with init_empty_weights():
        conv = cls(**geometry)

    conv.weight = nn.Parameter(kernel)
    # No bias tensor is read for this variant.
    conv.bias = None
    return conv

# Attach the checkpoint loaders defined above to the torch.nn layer classes.
for _target, _attr, _loader in (
    (torch.nn.Conv2d, "load", load_conv2d),
    (torch.nn.Conv2d, "load_no_bias", load_conv2d_no_bias),
    (torch.nn.LayerNorm, "load", load_layer_norm),
    (torch.nn.LayerNorm, "load_no_bias", load_layer_norm_no_bias),
):
    setattr(_target, _attr, _loader)
del _target, _attr, _loader

class FastLinear(nn.Module):
    """Unquantized linear layer that wraps existing weight / bias tensors."""

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.bias = None if bias is None else nn.Parameter(bias)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load ``{prefix}.weight`` and, when ``bias`` is true, ``{prefix}.bias``."""
        weight = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)

class EETQLinear(nn.Module):
    """Int8 weight-only linear layer backed by the EETQ kernels.

    The weight is transposed, made contiguous and moved to the CPU, quantized
    there with ``quant_weights``; the int8 weight, its scale and the bias are
    then placed on the original weight's device.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        target = weight.device
        transposed = torch.t(weight).contiguous().cpu()
        qweight, qscale = quant_weights(transposed, torch.int8, False)

        self.weight = qweight.cuda(target)
        self.scale = qscale.cuda(target)
        self.bias = None if bias is None else bias.cuda(target)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        result = w8_a16_gemm(input, self.weight, self.scale)
        if self.bias is not None:
            result = result + self.bias
        return result

class Linear8bitLt(nn.Module):
    """LLM.int8() linear layer built on bitsandbytes' ``MatmulLtState``.

    The weight is wrapped in ``Int8Params`` and moved to ``weight.device``. On
    the first ``forward`` call the ``CB``/``SCB`` buffers present on the weight
    are handed over to ``self.state`` (see ``init_8bit_state``).
    """

    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        matmul_state = self.state
        matmul_state.threshold = threshold
        matmul_state.has_fp16_weights = has_fp16_weights
        matmul_state.memory_efficient_backward = memory_efficient_backward
        if not has_fp16_weights and threshold > 0.0:
            matmul_state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        # NOTE(review): the return value of `.cuda()` is discarded, so this
        # relies on Int8Params moving/quantizing itself in place.
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        """Move the quantized ``CB``/``SCB`` buffers from the weight to the state."""
        matmul_state, int8_weight = self.state, self.weight
        matmul_state.CB, matmul_state.SCB = int8_weight.CB, int8_weight.SCB
        int8_weight.CB = int8_weight.SCB = None

    def forward(self, x: torch.Tensor):
        matmul_state = self.state
        matmul_state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        current_bias = self.bias
        if current_bias is not None and current_bias.dtype != x.dtype:
            current_bias.data = current_bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=matmul_state)

        # With frozen int8 weights, once the first inference pass has produced
        # the turing/ampere layout (CxB), the row-major copy (CB) is redundant.
        if (
            not matmul_state.has_fp16_weights
            and matmul_state.CB is not None
            and matmul_state.CxB is not None
        ):
            del matmul_state.CB
            self.weight.data = matmul_state.CxB
        return out

class Linear4bit(nn.Module):
    """bitsandbytes 4-bit linear layer (``quant_type`` is ``"fp4"`` or ``"nf4"``).

    The weight is wrapped in ``Params4bit`` (no grad, compressed statistics)
    and moved to ``weight.device``; ``forward`` expects the resulting
    ``quant_state`` to be populated by that move.
    """

    def __init__(self, weight, bias, quant_type):
        super().__init__()
        self.weight = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        # Optional dtype for the matmul inputs. It is never set here, so by
        # default the input (and bias) are used in their own dtype.
        self.compute_dtype = None
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # Params4bit only handles the weight; the bias is cast to the input
        # dtype manually (and persistently, by rewriting `.data`).
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            logger.warning(
                "4bit quantization state not initialized. Please call .cuda() or .to(device) on the Linear4bit layer first."
            )

        inp_dtype = x.dtype
        bias = self.bias
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)
            if bias is not None:
                bias = bias.to(self.compute_dtype)

        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )
        # Return in the caller's dtype regardless of compute_dtype.
        return out.to(inp_dtype)

@lru_cache(1)
def warn_deprecate_bnb():
    """Log the bitsandbytes-8bit deprecation warning.

    ``lru_cache`` on this zero-argument function means the body runs only on
    the first call, so the warning is emitted once per process.
    """
    logger.warning(
        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance"
    )

def get_linear(weight, bias, quantize):
    """Build the linear-layer implementation selected by ``quantize``.

    Args:
        weight: the weight as produced by the loader. A dense tensor for
            ``None``, ``"eetq"`` and the bitsandbytes modes; for ``"gptq"`` and
            ``"awq"`` a 7-tuple ``(qweight, qzeros, scales, g_idx, bits,
            groupsize, use_exllama)`` (AWQ ignores ``g_idx`` and ``use_exllama``).
        bias: optional bias tensor.
        quantize: ``None``, ``"eetq"``, ``"bitsandbytes"``,
            ``"bitsandbytes-fp4"``, ``"bitsandbytes-nf4"``, ``"gptq"`` or
            ``"awq"``.

    Returns:
        A module computing the (possibly quantized) linear map.

    Raises:
        ImportError: the kernels required by ``quantize`` are not installed.
        NotImplementedError: ``weight`` has the wrong layout for
            ``"gptq"``/``"awq"``, or ``quantize`` is unknown.
    """
    if quantize is None:
        return FastLinear(weight, bias)

    if quantize == "eetq":
        if not HAS_EETQ:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
        return EETQLinear(weight, bias)

    if quantize in ("bitsandbytes", "bitsandbytes-fp4", "bitsandbytes-nf4"):
        # Fail with a clear message instead of a NameError on `bnb`.
        if not HAS_BITS_AND_BYTES:
            raise ImportError(
                f"Please install bitsandbytes to use `{quantize}` quantization"
            )
        if quantize == "bitsandbytes":
            warn_deprecate_bnb()
            linear = Linear8bitLt(
                weight,
                bias,
                has_fp16_weights=False,
                threshold=6.0,
            )
            if bias is not None:
                linear.bias = nn.Parameter(bias)
            return linear
        quant_type = "fp4" if quantize == "bitsandbytes-fp4" else "nf4"
        return Linear4bit(weight, bias, quant_type=quant_type)

    if quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        except (TypeError, ValueError) as err:
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            ) from err
        if use_exllama:
            return ExllamaQuantLinear(
                qweight, qzeros, scales, g_idx, bias, bits, groupsize
            )
        return QuantLinear(
            qweight,
            qzeros,
            scales,
            g_idx,
            bias,
            bits,
            groupsize,
        )

    if quantize == "awq":
        try:
            qweight, qzeros, scales, _, bits, groupsize, _ = weight
        except (TypeError, ValueError) as err:
            raise NotImplementedError(
                "The passed weight is not `awq` compatible, loader needs to be updated."
            ) from err
        # Mirror the EETQ guard: fail clearly instead of a NameError on WQLinear.
        if not HAS_AWQ:
            raise ImportError(
                "AWQ quantization requires `WQLinear` (text_generation_server.utils.awq), which could not be imported"
            )
        # NOTE(review): only a boolean is passed for `bias`, so the loaded bias
        # tensor itself never reaches WQLinear; confirm against its constructor.
        return WQLinear(
            w_bit=bits,
            group_size=groupsize,
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            bias=bias is not None,
        )

    raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")

class SuperLayer(nn.Module):
    """Thin module that delegates ``forward`` to the wrapped ``self.linear``."""

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        # Calls `forward` directly rather than `self.linear(x)`; note this
        # bypasses any hooks registered on the wrapped module.
        inner = self.linear
        return inner.forward(x)

class TensorParallelHead(SuperLayer):
    """Output head whose weight rows may be sharded across the process group.

    When ``should_gather`` is true each rank holds a dim-0 slice of the weight
    and ``forward`` gathers the per-rank outputs along the last dimension.
    """

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        key = f"{prefix}.weight"
        should_gather = False
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(key, dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(key)
        else:
            weight = weights.get_tensor(key)

        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
        unquantized = config.quantize in ["gptq", "awq", "eetq"]
        quantize = None if unquantized else config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if not (input.dim() == 2 and isinstance(self.linear, FastLinear)):
            # Generic path: gather full per-rank outputs and concatenate them.
            local = super().forward(input)
            gathered = [torch.empty_like(local) for _ in range(world_size)]
            torch.distributed.all_gather(gathered, local, group=self.process_group)
            return torch.cat(gathered, dim=-1)

        num_rows = input.shape[0]
        out_dim = self.linear.weight.shape[0]
        if num_rows == 1:
            # A single row: the per-rank (1, out_dim) buffers land side by side.
            world_out = input.new_empty(1, out_dim * world_size)
            local_out = input.new_empty(1, out_dim)
            gather_input = local_out
        else:
            # Compute the local result transposed, (out_dim, num_rows), so the
            # gathered buffer stacks ranks along dim 0; transpose back at the end.
            world_out = input.new_empty(out_dim * world_size, num_rows)
            gather_input = input.new_empty(out_dim, num_rows)
            local_out = gather_input.T

        torch.mm(input, self.linear.weight.T, out=local_out)
        torch.distributed.all_gather_into_tensor(
            world_out, gather_input, group=self.process_group
        )
        return world_out if num_rows == 1 else world_out.T

class TensorParallelColumnLinear(SuperLayer):
    """Column-parallel linear layer.

    Weights come from ``weights.get_multi_weights_col`` and bias shards from
    ``weights.get_sharded(..., dim=0)``; ``forward`` (inherited) performs no
    collective communication.
    """

    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        return cls(get_linear(weight, None, config.quantize))

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load a single column-parallel linear (``load_multi`` with one prefix)."""
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """Load several column-parallel linears and fuse them into one layer.

        Args:
            prefixes: checkpoint prefixes of the layers to fuse, in order.
            bias: whether to load ``{prefix}.bias`` for every prefix.
            dim: concatenation axis forwarded to ``get_multi_weights_col``.
        """
        weight = weights.get_multi_weights_col(
            prefixes, quantize=config.quantize, dim=dim
        )

        if bias:
            shards = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            # Bias shards are 1-D (sharded along dim 0 above), so they are always
            # concatenated along dim 0, independently of the weight axis `dim`.
            bias = torch.cat(shards, dim=0)
        else:
            bias = None
        return cls(get_linear(weight, bias, config.quantize))

class TensorParallelRowLinear(SuperLayer):
    """Row-parallel linear layer: per-rank partial outputs are summed with
    ``all_reduce`` when the process group has more than one rank."""

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
        # The bias is only loaded on the first rank: forward sums the partial
        # outputs with all_reduce, so adding it on every rank would count it
        # once per rank.
        load_bias = bias and weights.process_group.rank() == 0
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if load_bias else None
        return cls(
            get_linear(weight, bias_tensor, config.quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        partial = super().forward(input)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(partial, group=self.process_group)
        return partial

class TensorParallelEmbedding(nn.Module):
    """Vocabulary-parallel embedding.

    Each rank holds the rows ``[min_id, max_id)`` of the embedding matrix plus
    one trailing all-zero row. Token ids outside the local range are mapped to
    that zero row, so summing the per-rank outputs with ``all_reduce`` (done
    when ``reduce`` is true and there is more than one rank) yields the full
    embedding.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group
        world_size = process_group.size()
        rank = process_group.rank()

        # NOTE(review): this id range must match the partition applied by
        # `weights.get_partial_sharded`; with floor division the last
        # `num_embeddings % world_size` ids are owned by no rank — confirm.
        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        # Index of the zero row appended below. Use the real local row count
        # instead of `block_size` so it still points at the padding row when
        # this shard does not hold exactly `block_size` rows.
        self.null_idx = weight.shape[0]
        self.process_group = process_group
        self.reduce = reduce

        # Additional 0 entry used for masking
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
        # translate for [0, self.max_id - self.min_id[
        out_of_range = (input < self.min_id) | (input >= self.max_id)
        local_ids = torch.where(out_of_range, self.null_idx, input - self.min_id)
        out = torch.nn.functional.embedding(local_ids, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out

try:
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        """``nn.LayerNorm`` variant that also returns the residual stream.

        ``forward(hidden_states, residual=None)`` returns ``(normed, residual)``.
        For hidden sizes above 8192 the eager path runs: ``residual`` (if given)
        is added into ``hidden_states`` in place, the sum is normalized, and the
        sum is returned as the new residual. Smaller sizes call the fused
        ``dropout_layer_norm.dropout_add_ln_fwd`` kernel (presumably the same
        add-then-normalize; TODO confirm).
        """

        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] <= 8192:
                normed_hidden_states, residual, *_ = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states
                return normed_hidden_states, residual

            # Eager path; note the in-place update of `hidden_states`.
            if residual is not None:
                hidden_states += residual
            return super().forward(hidden_states), hidden_states

except ImportError:
    pass


try:
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb

Nicolas Patry's avatar
Nicolas Patry committed
573
574
    def _create_inv_freq(dim, base, device):
        """Return the rotary inverse frequencies ``1 / base ** (k / dim)`` for
        ``k = 0, 2, 4, ... < dim`` as a float32 tensor on ``device``."""
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
        return 1.0 / (base ** exponents)

    def _get_rope_config(config):
        """Return the rope-scaling settings to use, or ``None``.

        The ``ROPE_SCALING`` environment variable (together with the required
        ``ROPE_FACTOR``) takes precedence over ``config.rope_scaling``.
        """
        scaling_type = os.getenv("ROPE_SCALING", None)
        if scaling_type is None:
            return getattr(config, "rope_scaling", None)
        return {
            "type": scaling_type,
            "factor": float(os.environ["ROPE_FACTOR"]),
        }

    class PositionRotaryEmbedding(nn.Module):
        """Rotary position embedding with cached cos/sin tables.

        ``scaling_factor`` (when not ``None``) divides the positions before the
        angles are computed (linear rope scaling).
        """

        def __init__(self, inv_freq, scaling_factor):
            super().__init__()
            self.inv_freq = inv_freq
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
            self.scaling_factor = scaling_factor
            self.dynamic_args = None

        @classmethod
        def _from_rope_scaling(cls, config, inv_freq, dynamic_dim, base):
            """Pick the rotary implementation from the rope-scaling settings.

            Shared by ``static`` and ``load``. Returns ``cls`` for no/linear
            scaling, ``DynamicPositionRotaryEmbedding`` for ``"dynamic"`` (built
            with ``dynamic_dim``) and ``YarnPositionRotaryEmbedding`` for
            ``"yarn"``; both variants receive ``base``.

            Raises:
                NotImplementedError: unknown rope scaling type.
            """
            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                scaling_type = rope_scaling["type"]
                if scaling_type == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=dynamic_dim,
                        max_position_embeddings=config.max_position_embeddings,
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                if scaling_type == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling[
                            "original_max_position_embeddings"
                        ],
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1,
                    )
                if scaling_type != "linear":
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        @classmethod
        def static(cls, config, dim, base, device):
            """Build from ``dim``/``base`` without reading a checkpoint tensor.

            The caller's ``base`` is used for every scaling variant (the yarn
            variant previously ignored it and assumed 10000.0).
            """
            inv_freq = _create_inv_freq(dim, base, device)
            return cls._from_rope_scaling(config, inv_freq, dynamic_dim=dim, base=base)

        @classmethod
        def load(cls, config, prefix, weights):
            """Build from the checkpoint tensor ``{prefix}.inv_freq`` (read in float32)."""
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            try:
                inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            finally:
                # Restore the loader's dtype even if the tensor lookup fails.
                weights.dtype = dtype
            # NOTE(review): the base is not available from the checkpoint, so
            # the dynamic/yarn variants assume 10000.0 here.
            return cls._from_rope_scaling(
                config, inv_freq, dynamic_dim=2 * inv_freq.shape[0], base=10000.0
            )

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance).
            # The `is None` check covers a first call with seqlen <= 0.
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached is None
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                if self.scaling_factor is not None:
                    t /= self.scaling_factor
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """
            Return cos and sin for the asked position ids
            """

            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            return cos.unsqueeze(1), sin.unsqueeze(1)

        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            """Apply the rotation in place to the first ``2 * rotary_dim``
            channels of ``x`` and return ``x``."""
            rotary_dim = cos.shape[-1]
            x1 = x[..., :rotary_dim]
            x2 = x[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
            return x

    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        """Rotary embedding whose base grows with the sequence length.

        Once a sequence longer than ``max_position_embeddings`` is requested,
        a larger base is derived from that length and ``scaling_factor`` and
        ``inv_freq`` is rebuilt from it before the cos/sin tables are
        recomputed. Unlike the parent class, positions are not divided by
        ``scaling_factor``.
        """

        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            super().__init__(_create_inv_freq(dim, base, device), scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Keep the tables unless the sequence grew or the device/dtype
            # changed (possibly due to tracing for instance).
            cache_is_fresh = (
                seqlen <= self._seq_len_cached
                and self._cos_cached.device == device
                and self._cos_cached.dtype == dtype
            )
            if cache_is_fresh:
                return

            if seqlen > self.max_position_embeddings:
                ratio = self.scaling_factor * seqlen / self.max_position_embeddings
                newbase = self.base * (ratio - (self.scaling_factor - 1)) ** (
                    self.dim / (self.dim - 2)
                )
                self.inv_freq = _create_inv_freq(
                    self.dim, newbase, self.inv_freq.device
                )

            self._seq_len_cached = seqlen
            positions = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            angles = torch.outer(positions, self.inv_freq.to(device=positions.device))
            self._cos_cached = torch.cos(angles).to(dtype)
            self._sin_cached = torch.sin(angles).to(dtype)

    # Inverse dim formula to find dim based on number of rotations
    import math

    def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
        """Return the (fractional) rotary dimension index that performs
        ``num_rotations`` full turns over ``max_position_embeddings`` positions."""
        turns = num_rotations * 2 * math.pi
        numerator = dim * math.log(max_position_embeddings / turns)
        return numerator / (2 * math.log(base))

    # Find dim range bounds based on rotations
    def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
        """Return integer ``(low, high)`` dimension bounds for the two rotation
        counts: ``low`` is floored, ``high`` is ceiled, and both are clamped to
        ``[0, dim - 1]``."""
        low_dim = find_correction_dim(low_rot, dim, base, max_position_embeddings)
        high_dim = find_correction_dim(high_rot, dim, base, max_position_embeddings)
        lower = max(math.floor(low_dim), 0)
        upper = min(math.ceil(high_dim), dim - 1)  # Clamp values just in case
        return lower, upper

    def linear_ramp_mask(min, max, dim):
        """Return a float32 tensor of length ``dim`` that ramps linearly from 0
        at index ``min`` to 1 at index ``max``, clamped to ``[0, 1]``."""
        if min == max:
            max += 0.001  # Prevent singularity
        ramp = torch.arange(dim, dtype=torch.float32).sub(min).div(max - min)
        return ramp.clamp(0, 1)

    def get_mscale(scale=1):
        """Magnitude scaling factor: ``1.0`` for ``scale <= 1``, otherwise
        ``0.1 * ln(scale) + 1``."""
        return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0

    class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
        """YaRN rotary embedding.

        Past ``max_position_embeddings``, ``inv_freq`` becomes a per-dimension
        blend of interpolated (divided by ``scaling_factor``) and extrapolated
        frequencies, using a linear ramp between the bounds from
        ``find_correction_range``; the cos/sin tables are multiplied by
        ``mscale``.
        """

        def __init__(
            self,
            dim,
            max_position_embeddings,
            base,
            device,
            scaling_factor,
            *,
            extrapolation_factor,
            attn_factor,
            beta_fast,
            beta_slow,
        ):
            super().__init__(_create_inv_freq(dim, base, device), scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base
            self.extrapolation_factor = extrapolation_factor
            self.attn_factor = attn_factor
            self.beta_fast = beta_fast
            self.beta_slow = beta_slow
            # Get n-d magnitude scaling corrected for interpolation
            self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Keep the tables unless the sequence grew or the device/dtype
            # changed (possibly due to tracing for instance).
            cache_is_fresh = (
                seqlen <= self._seq_len_cached
                and self._cos_cached.device == device
                and self._cos_cached.dtype == dtype
            )
            if cache_is_fresh:
                return

            if seqlen > self.max_position_embeddings:
                extrapolated = _create_inv_freq(
                    self.dim, self.base, self.inv_freq.device
                )
                interpolated = 1.0 / (self.scaling_factor * (1.0 / extrapolated))
                low, high = find_correction_range(
                    self.beta_fast,
                    self.beta_slow,
                    self.dim,
                    self.base,
                    self.max_position_embeddings,
                )
                # Get n-d rotational scaling corrected for extrapolation
                ramp = linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                extrapolation_mask = (1 - ramp) * self.extrapolation_factor
                self.inv_freq = (
                    interpolated * (1 - extrapolation_mask)
                    + extrapolated * extrapolation_mask
                )
                # Get n-d magnitude scaling corrected for interpolation
                self.mscale = float(
                    get_mscale(self.scaling_factor) * self.attn_factor
                )

            self._seq_len_cached = seqlen
            positions = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            angles = torch.outer(positions, self.inv_freq.to(device=positions.device))
            self._cos_cached = (torch.cos(angles) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(angles) * self.mscale).to(dtype)

except ImportError:
    pass