layers.py 44.7 KB
Newer Older
1
import os
2
import torch
3
import torch.distributed
4
5

from torch import nn
6
from torch.nn import functional as F
7
from typing import List, Tuple, Optional
8
9
from loguru import logger
from functools import lru_cache
10
11
12

# bitsandbytes is optional: it backs the "bitsandbytes*" quantization modes.
HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params, Params4bit
except ImportError:
    HAS_BITS_AND_BYTES = False

18
19
from accelerate import init_empty_weights

20
from text_generation_server.utils.gptq.quant_linear import QuantLinear
21
22
23
24
25
26
27
28
from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)

if IS_XPU_SYSTEM:
    import intel_extension_for_pytorch as ipex
29
30

# The AWQ GEMM kernels are optional; get_linear reports a helpful error when
# AWQ is requested without them.
HAS_AWQ = True
try:
    from text_generation_server.utils.awq.quantize.qmodule import WQLinear
except ImportError:
    HAS_AWQ = False

36
# Major CUDA compute capability of the current device. Falls back to 1 when
# the query fails (e.g. no CUDA device), which disables exllama below.
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1
Nicolas Patry's avatar
Nicolas Patry committed
40

41
# HAS_EXLLAMA stays False when the exllama kernels are unusable; otherwise it
# is the string "1" or "2" naming the kernel version that was imported.
HAS_EXLLAMA = False
# Only attempt exllama on compute capability >= 8 (CUDA) or on ROCm.
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
# EXLLAMA_VERSION selects the kernel generation; defaults to v2.
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"

if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        if V2:
            from text_generation_server.utils.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "2"
        else:
            from text_generation_server.utils.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "1"
    except ImportError:
        # Kernels not installed: leave HAS_EXLLAMA = False.
        pass
68

69
70
71
# The EETQ int8 kernels are optional; get_linear raises ImportError when the
# "eetq" mode is requested without them.
HAS_EETQ = False
try:
    from EETQ import quant_weights, w8_a16_gemm

    HAS_EETQ = True
except ImportError:
    pass

77

78
79
80
81
82
83
84
85
86
87
88
89
90
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Build a ``cls`` layer norm from ``{prefix}.weight`` and ``{prefix}.bias``.

    The module is instantiated under accelerate's ``init_empty_weights`` so its
    default parameters are not materialised; the checkpoint tensors are then
    installed as the real parameters.
    """
    gamma = weights.get_tensor(f"{prefix}.weight")
    beta = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        layer = cls(gamma.shape, eps=eps)
    layer.weight, layer.bias = nn.Parameter(gamma), nn.Parameter(beta)
    return layer


91
92
93
94
95
96
97
98
99
100
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Build a ``cls`` layer norm from ``{prefix}.weight`` only (no bias)."""
    gamma = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        layer = cls(gamma.shape, eps=eps)
    layer.weight = nn.Parameter(gamma)
    # This variant loads no bias, so the module's default bias is removed.
    layer.bias = None
    return layer

OlivierDehaene's avatar
OlivierDehaene committed
101

102
103
104
105
106
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Build a ``cls`` convolution from ``{prefix}.weight`` and ``{prefix}.bias``.

    The module is created under ``init_empty_weights`` and its parameters are
    then replaced by the checkpoint tensors.
    """
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = nn.Parameter(weight)
    conv2d.bias = nn.Parameter(bias)
    return conv2d


@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    """Build a bias-free ``cls`` convolution from ``{prefix}.weight``.

    The module is created under ``init_empty_weights``; its weight is replaced
    by the checkpoint tensor and its bias is removed.
    """
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = nn.Parameter(weight)
    conv2d.bias = None
    return conv2d

136

137
138
# Attach the checkpoint loader functions of this module as alternate
# constructors on the torch classes (e.g. ``torch.nn.LayerNorm.load(...)``).
torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
141

142
143

class FastLinear(nn.Module):
    """Unquantized linear layer built from already-loaded tensors.

    ``forward`` computes ``F.linear(input, weight, bias)``.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        # Keep the attribute present (as None) when there is no bias.
        self.bias = nn.Parameter(bias) if bias is not None else None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load ``{prefix}.weight`` and, when ``bias`` is true, ``{prefix}.bias``.

        ``config`` is accepted for signature parity with other loaders and is
        not used here.
        """
        weight = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)
167
168


169
170
class EETQLinear(nn.Module):
    """Int8 weight-only linear layer backed by the EETQ kernels."""

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        device = weight.device
        if weight.dtype != torch.float16:
            weight = weight.to(dtype=torch.float16)

        # quant_weights is given a transposed, contiguous, CPU-resident fp16 copy.
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)

        # Move the quantized weight, its scales and the bias back to the
        # device the original weight lived on.
        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = w8_a16_gemm(input, self.weight, self.scale)
        if self.bias is not None:
            output = output + self.bias
        return output


Nicolas Patry's avatar
Nicolas Patry committed
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    device = weight.device
    # weight, scale = quant_weights(weight, torch.int8, False)
    finfo = torch.finfo(qdtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # scale and clamp the tensor to bring it to
    # the representative range of float8 data type
    # (as default cast is unsaturated)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both float8 data and the inverse scale (as float),
    # as both required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()
    return qweight, scale


class Fp8Linear(nn.Module):
    """Linear layer with FP8 weights and dynamically FP8-quantized inputs.

    The weight is quantized once at construction. Each forward pass quantizes
    the activations with a per-tensor scale and calls ``torch._scaled_mm``.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        # The matmul output uses the dtype of the original (unquantized) weight.
        self.dtype = weight.dtype
        self.qweight, self.scale = fp8_quantize(weight)
        self.bias = bias

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qinput, scale = fp8_quantize(input)
        # NOTE(review): unpacking two results assumes a torch release whose
        # _scaled_mm returns (output, amax); confirm against the pinned version.
        output, _ = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=scale,
            scale_b=self.scale,
            bias=self.bias,
        )
        return output


234
class Linear8bitLt(nn.Module):
    """bitsandbytes 8-bit linear layer driven by ``bnb.matmul``.

    Args:
        weight: dense weight tensor, wrapped in ``Int8Params`` and moved to its
            own device (presumably where bitsandbytes quantizes it; ``forward``
            reads the resulting ``CB``/``SCB`` attributes).
        bias: optional bias tensor, cast to the input dtype on use.
        has_fp16_weights: forwarded to ``Int8Params`` and the matmul state.
        memory_efficient_backward: deprecated; must be False.
        threshold: outlier threshold stored on the matmul state.
        index: stored as ``self.index``; not otherwise used here.
    """

    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        """Move the quantized weight (CB) and scales (SCB) into the matmul state."""
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        # First call after construction: hand CB/SCB over to the state.
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
290
291


Nicolas Patry's avatar
Nicolas Patry committed
292
293
294
295
class Linear4bit(nn.Module):
    """bitsandbytes 4-bit linear layer (``quant_type`` is ``"fp4"`` or ``"nf4"``)."""

    def __init__(self, weight, bias, quant_type):
        super().__init__()
        self.weight = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        # Optional dtype for the matmul inputs; None keeps the input dtype.
        self.compute_dtype = None
        # Moving to the device is what populates weight.quant_state (see the
        # check in forward).
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # The 4-bit weight is handled by bitsandbytes, but the bias is a plain
        # tensor and has to be cast to the activation dtype manually.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            logger.warning(
                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
            )
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        # Only cast the bias when a compute dtype was actually configured.
        bias = self.bias
        if bias is not None and self.compute_dtype is not None:
            bias = bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )

        # Return in the caller's dtype regardless of compute_dtype.
        return out.to(inp_dtype)


328
329
@lru_cache(1)
def warn_deprecate_bnb():
    """Log the bitsandbytes-8bit deprecation warning.

    ``lru_cache`` on this zero-argument function means the body runs only on
    the first call, so the warning is emitted once per process.
    """
    logger.warning(
        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance"
    )

334

335
336
337
def get_linear(weight, bias, quantize):
    """Build the linear module matching the ``quantize`` mode.

    Args:
        weight: a dense weight tensor for ``None``/``"eetq"``/``"fp8"``/
            bitsandbytes modes. For ``"gptq"`` and ``"awq"`` it must unpack to
            ``(qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)``.
        bias: optional bias tensor.
        quantize: ``None`` or one of ``"eetq"``, ``"fp8"``, ``"bitsandbytes"``,
            ``"bitsandbytes-fp4"``, ``"bitsandbytes-nf4"``, ``"gptq"``, ``"awq"``.

    Raises:
        ImportError: ``"eetq"`` requested but EETQ is not installed.
        NotImplementedError: unknown mode, a weight that does not unpack for
            ``gptq``/``awq``, or AWQ on ROCm / without the AWQ kernels.
    """
    if quantize is None:
        linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        if HAS_EETQ:
            linear = EETQLinear(weight, bias)
        else:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "fp8":
        linear = Fp8Linear(weight, bias)
    elif quantize == "bitsandbytes":
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            linear.bias = nn.Parameter(bias)
    elif quantize == "bitsandbytes-fp4":
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4",
        )
    elif quantize == "bitsandbytes-nf4":
        linear = Linear4bit(
            weight,
            bias,
            quant_type="nf4",
        )
    elif quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        except Exception as err:
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            ) from err

        if use_exllama:
            linear = ExllamaQuantLinear(
                qweight, qzeros, scales, g_idx, bias, bits, groupsize
            )
        else:
            linear = QuantLinear(
                qweight,
                qzeros,
                scales,
                g_idx,
                bias,
                bits,
                groupsize,
            )
    elif quantize == "awq":
        try:
            qweight, qzeros, scales, _, bits, groupsize, _ = weight
        except Exception as err:
            raise NotImplementedError(
                "The passed weight is not `awq` compatible, loader needs to be updated."
            ) from err
        if IS_ROCM_SYSTEM:
            raise NotImplementedError(
                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
                "to use Exllama/GPTQ kernels for AWQ inference."
            )
        if not HAS_AWQ:
            raise NotImplementedError(
                "You do not seem to have awq installed, either install it (cd server &&  make install-awq), or try using GPTQ `--quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
            )
        linear = WQLinear(
            w_bit=bits,
            group_size=groupsize,
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            bias=bias is not None,
        )
        if bias is not None:
            # WQLinear is only told *whether* a bias exists, so attach the
            # loaded tensor explicitly or its values are lost. Assigning an
            # nn.Parameter works whether WQLinear stored bias as a plain
            # attribute, a buffer or a parameter.
            linear.bias = nn.Parameter(bias)
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear


class SuperLayer(nn.Module):
    """Wrapper that delegates ``forward`` to an inner ``linear`` module.

    The tensor-parallel subclasses in this module add their communication
    around the wrapped layer.
    """

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        # Call .forward directly rather than __call__, so hooks registered on
        # the inner module are not run here.
        inner = self.linear
        return inner.forward(x)


429
430
431
432
433
434
435
436
437
438
439
440
441
class ResBlock(torch.nn.Module):
    """Residual block computing ``x + SiLU(linear(x))`` (used by Medusa heads).

    The linear layer, with bias, is loaded from ``{prefix}.linear``.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        linear_prefix = f"{prefix}.linear"
        self.linear = FastLinear.load(
            config, prefix=linear_prefix, weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        update = self.act(self.linear(x))
        return x + update


class MedusaModel(torch.nn.Module):
    """``medusa_num_heads`` Medusa heads, all applied to the same input.

    Head ``i`` reads its weights under the prefix ``"{i}"``.
    """

    def __init__(self, config, medusa_config, weights):
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [
                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
                for i in range(medusa_config["medusa_num_heads"])
            ]
        )

    def forward(self, x):
        # One output per head, stacked along a new dim 1.
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
        return speculative_logits


class MedusaHead(torch.nn.Module):
    """One Medusa head: ``medusa_num_layers`` ResBlocks, then an output linear.

    ResBlock ``i`` is loaded from ``{prefix}.{i}``. The bias-free output layer
    is loaded from ``{prefix}.{n}``, where ``n`` is the number of ResBlocks.
    """

    def __init__(self, config, medusa_config, prefix, weights):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [
                ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
                for i in range(medusa_config["medusa_num_layers"])
            ]
        )
        n = len(self.blocks)
        self.out = FastLinear.load(
            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.out(x)
        return x


OlivierDehaene's avatar
OlivierDehaene committed
477
class MedusaHeadV1(nn.Module):
    """LM head plus an independent ``MedusaModel`` producing speculative logits."""

    def __init__(self, lm_head, medusa):
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Load the Medusa heads from ``config.use_medusa`` and the LM head.

        ``config.use_medusa`` is a directory holding ``config.json`` and
        ``medusa_lm_head.safetensors``. Every tensor name in that safetensors
        file is recorded in ``weights.routing`` with the file's path
        (presumably so that later ``weights`` lookups read from it).

        Raises:
            RuntimeError: a tensor name already routes to a different file.
        """
        from pathlib import Path
        from safetensors import safe_open
        import json

        use_medusa = config.use_medusa

        config_path = str(Path(use_medusa) / "config.json")
        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

        with open(config_path, "r") as f:
            medusa_config = json.load(f)
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        medusa = MedusaModel(config, medusa_config, weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MedusaHeadV1(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)``.

        ``speculative_logits`` is None when more than 128 rows are given.
        """
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > 128:
            return logits, None

        speculative_logits = self.medusa(input)
        return logits, speculative_logits


class MedusaHeadV2(nn.Module):
    """Medusa heads fused into a single column-parallel linear.

    Requires ``medusa_num_layers == 1``. All heads' residual layers are
    computed in one matmul, gathered across ranks, and then passed through the
    shared LM head together with the unmodified hidden state.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        from pathlib import Path
        from safetensors import safe_open
        import json

        use_medusa = config.use_medusa

        config_path = str(Path(use_medusa) / "config.json")
        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

        with open(config_path, "r") as f:
            medusa_config = json.load(f)
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        self.n_medusa_heads = medusa_config["medusa_num_heads"]

        assert medusa_config["medusa_num_layers"] == 1
        self.linear = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.process_group = weights.process_group
        self.world_size = self.process_group.size()
        self.rank = self.process_group.rank()

        self.act = torch.nn.SiLU()

        self.lm_head = TensorParallelHead.load(config, prefix, weights)

    def forward(self, x):
        """Return ``(logits, speculative_logits)``; the latter is None for > 128 rows."""
        # If we have too many tokens, we skip speculative logits
        if x.shape[0] > 128:
            logits = self.lm_head(x)
            return logits, None

        # This rank's slice of the last (hidden) dim of x.
        # NOTE(review): assumes the column-sharded linear splits each head's
        # output the same way (ceil division); confirm for uneven sizes.
        size = x.shape[-1]
        block_size = (size + self.world_size - 1) // self.world_size
        start = self.rank * block_size
        stop = (self.rank + 1) * block_size

        x_block = x[:, start:stop]

        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
        medusa_res = self.act(self.linear(x)).reshape(
            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
        )

        # Apply all residual medusa heads
        output = x_block.unsqueeze(-2) + medusa_res

        # Gather medusa heads
        world_output = [torch.empty_like(output) for _ in range(self.world_size)]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)

        # Stack x and medusa residual x
        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)

        # Compute lm head on x + medusa residual x
        logits = self.lm_head(stacked_x)

        # Finally, split logits from speculative logits
        logits, speculative_logits = torch.split(
            logits, [1, self.n_medusa_heads], dim=-2
        )
        # Squeeze added dimension
        logits = logits.squeeze(-2)

        return logits, speculative_logits


class SpeculativeHead(nn.Module):
    """LM head that can also return Medusa speculative logits."""

    def __init__(self, lm_head, medusa):
        super().__init__()
        self.head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Build the head.

        With ``config.use_medusa`` set, the V1 Medusa layout is tried first,
        falling back to V2. Otherwise a plain ``TensorParallelHead`` is used.
        """
        use_medusa = config.use_medusa
        if use_medusa:
            lm_head = None
            try:
                medusa = MedusaHeadV1.load(config, prefix, weights)
            except Exception:
                # V1 layout could not be loaded: fall back to the V2 layout.
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                medusa = MedusaHeadV2(config, prefix, weights)
        else:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            medusa = None
        return SpeculativeHead(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)``.

        ``speculative_logits`` is always None without Medusa.
        """
        if self.medusa is not None:
            return self.medusa(input)

        assert self.head is not None
        logits = self.head(input)
        return logits, None
635
636


637
class TensorParallelHead(SuperLayer):
    """LM head whose weight may be sharded along its first (output) dimension.

    When ``should_gather`` is set, each rank computes its shard of the logits
    and the shards are all-gathered and concatenated along the last dim.
    """

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
        if config.quantize in ["gptq", "awq", "eetq"]:
            quantize = None
        else:
            quantize = config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            # Fast path for unquantized 2-D input: write the local logits
            # straight into the buffer handed to all_gather_into_tensor.
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # all_gather_into_tensor concatenates along dim 0, so gather
                # (out_dim, rows) blocks; the matmul writes through a
                # transposed view of gather_input, and world_out.T below
                # restores the (rows, out_dim * world_size) layout.
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [torch.empty_like(output) for _ in range(world_size)]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


class TensorParallelColumnLinear(SuperLayer):
    """Column-parallel linear: the weights are loaded via the column-sharded
    loaders of ``weights``.

    ``forward`` is inherited from ``SuperLayer``, so no cross-rank
    communication happens here.
    """

    @classmethod
    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the gate and up projections were joined after the fact.

        Raises:
            NotImplementedError: ``bias`` is requested.
        """
        weight = weights.get_weights_col_packed_gate_up(
            prefix, quantize=config.quantize
        )
        if bias:
            raise NotImplementedError("packed_gate_up only implemented without bias")
        linear = get_linear(weight, None, config.quantize)
        return cls(linear)

    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact.

        Raises:
            NotImplementedError: ``bias`` is requested.
        """
        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        linear = get_linear(weight, None, config.quantize)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load a single projection (``load_multi`` with one prefix, ``dim=0``)."""
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """Load several projections and fuse them into one linear layer.

        ``dim`` is the dimension along which the weights of ``prefixes`` are
        concatenated by ``weights.get_multi_weights_col``.
        """
        weight = weights.get_multi_weights_col(
            prefixes, quantize=config.quantize, dim=dim
        )

        if bias:
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            # Biases are sharded on dim 0 above, so they are concatenated on
            # dim 0 too. Using ``dim`` here would raise for dim != 0.
            bias = torch.cat(b, dim=0)
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)
747

748
749
750
751

class TensorParallelRowLinear(SuperLayer):
    """Row-parallel linear: each rank computes a partial output.

    The partial outputs are summed across ranks with ``all_reduce`` in
    ``forward``.
    """

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        if bias and weights.process_group.rank() == 0:
            # Bias is only loaded on rank 0: outputs are summed by all_reduce,
            # so adding it on every rank would count it world_size times.
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias, config.quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        """Apply the local shard; if ``reduce``, sum the results across ranks in place."""
        out = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
773
774


775
776
777
class TensorParallelEmbedding(nn.Module):
    """Embedding table split across ranks along the vocabulary dimension.

    Each rank owns ids in ``[min_id, max_id)`` and stores its shard plus one
    trailing all-zero row. Ids owned by other ranks are looked up in that zero
    row, so summing the per-rank outputs (all-reduce) yields the full lookup.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        shard = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        vocab_size = weights.get_shape(f"{prefix}.weight")[0]

        group = weights.process_group
        world_size = group.size()
        rank = group.rank()

        # Ceil-divide the vocabulary so every id belongs to exactly one rank.
        rows_per_rank = (vocab_size + world_size - 1) // world_size
        self.min_id = rank * rows_per_rank
        self.max_id = min(vocab_size, (rank + 1) * rows_per_rank)

        # Index of the zero row appended below: usually ``rows_per_rank``,
        # but smaller for the last shard of an uneven vocabulary.
        self.null_idx = shard.shape[0]
        self.process_group = weights.process_group
        self.reduce = reduce

        # One extra all-zero row, used as the lookup target for masked ids.
        self.weight = nn.Parameter(F.pad(shard, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Embed ``input`` ids with the local shard, all-reducing when enabled."""
        # Ids owned by another rank go to the zero row; the others are
        # shifted into local row indices.
        foreign = (input < self.min_id) | (input >= self.max_id)
        local_ids = torch.where(foreign, self.null_idx, input - self.min_id)
        out = F.embedding(local_ids, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


try:
fxmarty's avatar
fxmarty committed
813
814
    if IS_CUDA_SYSTEM:
        import dropout_layer_norm
OlivierDehaene's avatar
OlivierDehaene committed
815
816
    elif IS_ROCM_SYSTEM:
        from vllm import layernorm_ops
fxmarty's avatar
fxmarty committed
817
818
    else:
        dropout_layer_norm = None
819
820
821

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm with an optional residual add, dispatched per platform.

        ``forward(hidden_states, residual=None)`` returns ``(normed, residual)``
        where the second element is the value to pass as ``residual`` to the
        next layer.
        """

        def forward(self, hidden_states, residual=None):
            if IS_XPU_SYSTEM:
                # ipex fuses the add and the norm; when a residual is given it
                # is returned as the new residual (the trailing ``True`` presumably
                # makes ipex write the sum back into it -- TODO confirm).
                res_out = hidden_states
                out = ipex.llm.functional.add_layer_norm(
                    residual, hidden_states, self.weight, self.bias, self.eps, True
                )
                if residual is not None:
                    res_out = residual
                return out, res_out
            elif (
                hidden_states.shape[-1] > 8192
                or IS_ROCM_SYSTEM
                or dropout_layer_norm is None
            ):
                # Pure PyTorch path: hidden sizes above 8192, ROCm, and any
                # system where the fused CUDA kernel was not imported
                # (``dropout_layer_norm`` is None there). ``IS_ROCM_SYSTEM`` is
                # tested first because ``dropout_layer_norm`` is never bound on
                # ROCm.
                if residual is not None:
                    # Note: adds into the caller's tensor in place.
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                # Fused CUDA kernel; the final ``False`` selects LayerNorm
                # (``True`` selects RMSNorm, as used by FastRMSNorm).
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual
OlivierDehaene's avatar
OlivierDehaene committed
862
863
864
865
866
867
868
869
870
871
872
873
874
875

    class FastRMSNorm(nn.Module):
        """RMSNorm with an optional residual add, dispatched per platform.

        ``forward(hidden_states, residual=None)`` returns ``(normed, residual)``
        where the second element is the value to pass as ``residual`` to the
        next layer. XPU, CUDA and ROCm each use their own kernel; hidden sizes
        above 8192 (on non-XPU systems) use a pure-PyTorch path.
        """

        def __init__(self, weight: torch.Tensor, eps: float):
            super().__init__()

            self.weight = nn.Parameter(weight)
            self.variance_epsilon = eps

        @classmethod
        def load(cls, prefix, weights, eps=1e-6):
            """Build the norm from the ``{prefix}.weight`` tensor."""
            return cls(weights.get_tensor(f"{prefix}.weight"), eps)

        def forward(self, hidden_states, residual=None):
            if IS_XPU_SYSTEM:
                new_residual = hidden_states
                normed = ipex.llm.functional.add_rms_norm(
                    residual,
                    hidden_states,
                    self.weight,
                    None,
                    self.variance_epsilon,
                    True,
                )
                if residual is not None:
                    new_residual = residual
                return normed, new_residual

            if hidden_states.shape[-1] > 8192:
                # Pure PyTorch fallback. Note: the residual is added into the
                # caller's ``hidden_states`` tensor in place.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                as_fp32 = hidden_states.to(torch.float32)
                mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
                normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)

                # Cast back to the weight's half-precision dtype when needed.
                if self.weight.dtype in [torch.float16, torch.bfloat16]:
                    normalized = normalized.to(self.weight.dtype)

                return self.weight * normalized, residual

            if IS_CUDA_SYSTEM:
                # Fused post-attention kernel; the final ``True`` selects RMSNorm.
                (
                    normed,
                    new_residual,
                    *_,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    None,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.variance_epsilon,
                    1.0,
                    0,
                    None,
                    False,
                    True,  # Activate RMSNorm
                )
                if new_residual is None:
                    new_residual = hidden_states
                return normed, new_residual

            if IS_ROCM_SYSTEM:
                # vLLM's RMSNorm kernel compiles for ROCm, unlike the
                # flash-attention kernels used on CUDA.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                normed = torch.empty_like(hidden_states)
                layernorm_ops.rms_norm(
                    normed,
                    hidden_states,
                    self.weight.data,
                    self.variance_epsilon,
                )
                return normed, residual

            raise ValueError(
                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )
OlivierDehaene's avatar
OlivierDehaene committed
950

951
952
953
954
except ImportError:
    pass

try:
fxmarty's avatar
fxmarty committed
955
956
957
958
959
    if IS_CUDA_SYSTEM:
        from flash_attn.layers.rotary import RotaryEmbedding
        import rotary_emb
    elif IS_ROCM_SYSTEM:
        from vllm import pos_encoding_ops
960

Nicolas Patry's avatar
Nicolas Patry committed
961
962
    def _create_inv_freq(dim, base, device):
        """Return RoPE inverse frequencies ``base ** (-i / dim)`` for even ``i``.

        One float32 value per even index in ``[0, dim)``, created on ``device``.
        """
        even_indices = torch.arange(0, dim, 2, device=device, dtype=torch.float32)
        exponents = even_indices / dim
        return 1.0 / (base**exponents)

    def _get_rope_config(config):
        """Return the RoPE scaling settings for ``config``.

        When the ``ROPE_SCALING`` environment variable is set, it and
        ``ROPE_FACTOR`` override the model config. Otherwise
        ``config.rope_scaling`` is returned, or ``None`` when absent.
        """
        scaling_type = os.getenv("ROPE_SCALING", None)
        if scaling_type is None:
            return getattr(config, "rope_scaling", None)
        return {
            "type": scaling_type,
            "factor": float(os.environ["ROPE_FACTOR"]),
        }

976
    class PositionRotaryEmbedding(nn.Module):
        """Rotary position embedding (RoPE) with lazily cached cos/sin tables.

        ``inv_freq`` holds the per-pair inverse frequencies. When
        ``scaling_factor`` is not None, positions are divided by it before the
        tables are built (linear RoPE scaling).
        """

        def __init__(self, inv_freq, scaling_factor):
            super().__init__()
            self.inv_freq = inv_freq
            # cos/sin tables are built on demand by ``_update_cos_sin_cache``.
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
            self.scaling_factor = scaling_factor
            self.dynamic_args = None

        def forward(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            cos: torch.Tensor,
            sin: torch.Tensor,
        ):
            """Rotate ``query`` and ``key`` with ``cos``/``sin``.

            The CUDA and ROCm kernels write into ``query``/``key`` directly;
            nothing is returned.
            """
            # Such controlflows may add some overhead.
            if IS_CUDA_SYSTEM:
                rotary_dim = cos.shape[-1]
                q1 = query[..., :rotary_dim]
                q2 = query[..., rotary_dim : 2 * rotary_dim]

                rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)

                k1 = key[..., :rotary_dim]
                k2 = key[..., rotary_dim : 2 * rotary_dim]

                rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
            elif IS_ROCM_SYSTEM:
                # NOTE: On RoCm systems, we use a ROPE implementation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
                # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

                head_size = query.shape[-1]

                # Inplace operation, updating query and key.
                pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
            elif IS_XPU_SYSTEM:
                ipex.llm.functional.rotary_embedding(
                    query, key, sin, cos, query.size(-1), True
                )
            else:
                raise ValueError(
                    "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
                )

        @classmethod
        def static(cls, config, dim, base, device):
            """Build a RoPE from ``dim`` and ``base``.

            The scaling from ``_get_rope_config`` is honored:
            - "linear" keeps this class with ``scaling_factor``;
            - "dynamic" and "yarn" return the matching subclass;
            - any other type raises NotImplementedError.
            """
            inv_freq = _create_inv_freq(dim, base, device)
            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=dim,
                        max_position_embeddings=config.max_position_embeddings,
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling[
                            "original_max_position_embeddings"
                        ],
                        # Honor the requested base, as ``inv_freq`` above does.
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        @classmethod
        def load(cls, config, prefix, weights):
            """Build a RoPE from the stored ``{prefix}.inv_freq`` tensor.

            ``inv_freq`` is always loaded in float32. The base is not available
            here, so the dynamic/yarn variants are rebuilt assuming
            ``base=10000.0``.
            """
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            try:
                inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            finally:
                # Restore the loader's dtype even if the tensor cannot be read.
                weights.dtype = dtype

            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=config.max_position_embeddings,
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling[
                            "original_max_position_embeddings"
                        ],
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild the tables when none exist yet, when a longer sequence is
            # requested, or when the device/dtype changed (possibly due to
            # tracing for instance). The ``None`` check must come first so a
            # first call with ``seqlen == 0`` does not dereference a missing cache.
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                if self.scaling_factor is not None:
                    t /= self.scaling_factor
                # Don't do einsum, it converts fp32 to fp16.
                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """
            Return cos and sin for the asked position ids
            """
            if IS_ROCM_SYSTEM:
                # For RoCm, we always use float cos/sin to avoid a cast.
                # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
                # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
                dtype = torch.float32

            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)

            # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
            return cos.unsqueeze(1), sin.unsqueeze(1)

Nicolas Patry's avatar
Nicolas Patry committed
1142
1143
    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        """RoPE whose base grows once sequences exceed ``max_position_embeddings``.

        When a longer sequence is requested, the base is enlarged using the
        dynamic-NTK formula below and ``inv_freq`` is recomputed. Unlike the
        parent class, positions are not divided by ``scaling_factor``.
        """

        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild the tables when none exist yet, when a longer sequence is
            # requested, or when the device/dtype changed (possibly due to
            # tracing for instance). Check ``None`` first so a first call with
            # ``seqlen == 0`` does not dereference a missing cache.
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    # Enlarge the base for the longer context.
                    newbase = self.base * (
                        (self.scaling_factor * seqlen / self.max_position_embeddings)
                        - (self.scaling_factor - 1)
                    ) ** (self.dim / (self.dim - 2))
                    self.inv_freq = _create_inv_freq(
                        self.dim, newbase, self.inv_freq.device
                    )
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16.
                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

Nicolas Patry's avatar
Nicolas Patry committed
1175
1176
    # Inverse dim formula to find dim based on number of rotations
    import math
OlivierDehaene's avatar
OlivierDehaene committed
1177

OlivierDehaene's avatar
OlivierDehaene committed
1178
1179
1180
1181
1182
1183
    def find_correction_dim(
        num_rotations, dim, base=10000, max_position_embeddings=2048
    ):
        """Return the (fractional) dimension index completing ``num_rotations``
        full rotations over ``max_position_embeddings`` positions.

        This inverts the RoPE wavelength formula; ``find_correction_range``
        uses it to bound the YaRN interpolation ramp.
        """
        wavelength_ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
        return dim * math.log(wavelength_ratio) / (2 * math.log(base))
Nicolas Patry's avatar
Nicolas Patry committed
1184
1185

    # Find dim range bounds based on rotations
OlivierDehaene's avatar
OlivierDehaene committed
1186
1187
1188
1189
1190
1191
1192
1193
1194
    def find_correction_range(
        low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
    ):
        """Return integer ``(low, high)`` dimension bounds for the YaRN ramp.

        ``low`` is the floor of the correction dim for ``low_rot``, clamped to
        be at least 0. ``high`` is the ceiling of the one for ``high_rot``,
        clamped to be at most ``dim - 1``.
        """
        low_dim = find_correction_dim(low_rot, dim, base, max_position_embeddings)
        high_dim = find_correction_dim(high_rot, dim, base, max_position_embeddings)
        low = max(math.floor(low_dim), 0)
        high = min(math.ceil(high_dim), dim - 1)
        return low, high

Nicolas Patry's avatar
Nicolas Patry committed
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
    def linear_ramp_mask(min, max, dim):
        """Return a float32 ramp of length ``dim`` going from 0 at ``min`` to 1 at ``max``.

        Entry ``i`` is ``(i - min) / (max - min)`` clamped to ``[0, 1]``. When
        ``min == max``, ``max`` is nudged by 0.001 to avoid dividing by zero.
        """
        # NOTE: ``min``/``max`` shadow the builtins; kept for interface stability.
        upper = max + 0.001 if min == max else max
        positions = torch.arange(dim, dtype=torch.float32)
        return ((positions - min) / (upper - min)).clamp(0, 1)

    def get_mscale(scale=1):
        """Return the magnitude scale ``0.1 * ln(scale) + 1`` used by YaRN.

        Scales of 1 or below give exactly 1.0.
        """
        return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0

    class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
        """RoPE with YaRN scaling.

        When a sequence longer than ``max_position_embeddings`` is requested,
        ``inv_freq`` becomes a per-dimension blend of two sets of frequencies:
        the interpolated ones (divided by ``scaling_factor``) and the original
        ones. The blend is chosen by a linear ramp between the
        ``beta_fast``/``beta_slow`` correction bounds. The cos/sin tables are
        always multiplied by ``mscale``.
        """

        def __init__(
            self,
            dim,
            max_position_embeddings,
            base,
            device,
            scaling_factor,
            *,
            extrapolation_factor,
            attn_factor,
            beta_fast,
            beta_slow,
        ):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base
            self.extrapolation_factor = extrapolation_factor
            self.attn_factor = attn_factor
            self.beta_fast = beta_fast
            self.beta_slow = beta_slow
            # Get n-d magnitude scaling corrected for interpolation.
            self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild the tables when none exist yet, when a longer sequence is
            # requested, or when the device/dtype changed (possibly due to
            # tracing for instance). Check ``None`` first so a first call with
            # ``seqlen == 0`` does not dereference a missing cache.
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    inv_freq_extrapolation = _create_inv_freq(
                        self.dim, self.base, self.inv_freq.device
                    )
                    freqs = 1.0 / inv_freq_extrapolation
                    inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)

                    low, high = find_correction_range(
                        self.beta_fast,
                        self.beta_slow,
                        self.dim,
                        self.base,
                        self.max_position_embeddings,
                    )
                    # Build the mask on the same device as the frequencies it
                    # blends. ``device`` is where the cos/sin tables go and may
                    # differ from ``self.inv_freq.device``.
                    ramp = linear_ramp_mask(low, high, self.dim // 2).float()
                    # Get n-d rotational scaling corrected for extrapolation.
                    inv_freq_mask = (
                        1 - ramp.to(inv_freq_extrapolation.device)
                    ) * self.extrapolation_factor
                    self.inv_freq = (
                        inv_freq_interpolation * (1 - inv_freq_mask)
                        + inv_freq_extrapolation * inv_freq_mask
                    )

                    # Get n-d magnitude scaling corrected for interpolation.
                    self.mscale = float(
                        get_mscale(self.scaling_factor) * self.attn_factor
                    )

                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16.
                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)

1281
1282
except ImportError:
    pass