layers.py 42.8 KB
Newer Older
1
import os
2
import torch
3
import torch.distributed
4
5

from torch import nn
6
from torch.nn import functional as F
7
from typing import List, Tuple, Optional
8
9
from loguru import logger
from functools import lru_cache
10
11
12

# --- Optional quantization backends ---------------------------------------
# Each probe records whether a backend could be imported in a module-level
# flag instead of failing at import time.
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params, Params4bit
except ImportError:
    HAS_BITS_AND_BYTES = False
else:
    HAS_BITS_AND_BYTES = True

from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

try:
    from text_generation_server.utils.awq.quantize.qmodule import WQLinear
except ImportError:
    HAS_AWQ = False
else:
    HAS_AWQ = True

# Major CUDA compute capability; any failure to query it (e.g. no CUDA
# device) is treated as capability 1.
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1

# HAS_EXLLAMA ends up as False, "1" or "2" depending on which kernel imported.
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"

exllama_disabled = os.getenv("DISABLE_EXLLAMA") == "True"
if not exllama_disabled and CAN_EXLLAMA:
    try:
        if not V2:
            from text_generation_server.utils.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "1"
        else:
            from text_generation_server.utils.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "2"
    except ImportError:
        # Kernels not built/installed: keep HAS_EXLLAMA = False.
        pass

try:
    from EETQ import quant_weights, w8_a16_gemm
except ImportError:
    HAS_EETQ = False
else:
    HAS_EETQ = True

70

71
72
73
74
75
76
77
78
79
80
81
82
83
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Alternate constructor for layer-norm classes that carry a bias.

    Reads ``{prefix}.weight`` and ``{prefix}.bias`` from ``weights``. The module
    is instantiated under ``init_empty_weights`` and its parameters are then
    replaced by the loaded tensors.
    """
    weight, bias = (
        weights.get_tensor(f"{prefix}.{part}") for part in ("weight", "bias")
    )
    with init_empty_weights():
        layer = cls(weight.shape, eps=eps)

    layer.weight, layer.bias = nn.Parameter(weight), nn.Parameter(bias)
    return layer


84
85
86
87
88
89
90
91
92
93
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Alternate constructor for layer-norm classes without a bias term.

    Only ``{prefix}.weight`` is read from ``weights``. The module is built under
    ``init_empty_weights`` and its bias attribute is set to ``None``.
    """
    loaded_weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        layer = cls(loaded_weight.shape, eps=eps)

    layer.weight = nn.Parameter(loaded_weight)
    # Clear any bias the constructor may have created.
    layer.bias = None
    return layer

OlivierDehaene's avatar
OlivierDehaene committed
94

95
96
97
98
99
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Alternate constructor for conv classes with weight and bias.

    Reads ``{prefix}.weight`` and ``{prefix}.bias`` from ``weights`` and installs
    them on a module built under ``init_empty_weights``.
    """
    loaded_weight = weights.get_tensor(f"{prefix}.weight")
    loaded_bias = weights.get_tensor(f"{prefix}.bias")
    conv_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )
    with init_empty_weights():
        conv2d = cls(**conv_kwargs)

    conv2d.weight = nn.Parameter(loaded_weight)
    conv2d.bias = nn.Parameter(loaded_bias)
    return conv2d


@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    """Alternate constructor for conv classes without a bias term.

    Only ``{prefix}.weight`` is read; the module's bias attribute is set to
    ``None``.
    """
    loaded_weight = weights.get_tensor(f"{prefix}.weight")
    conv_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )
    with init_empty_weights():
        conv2d = cls(**conv_kwargs)

    conv2d.weight = nn.Parameter(loaded_weight)
    # Clear any bias the constructor may have created.
    conv2d.bias = None
    return conv2d

129

130
131
# Attach the loaders above as alternate constructors on the torch classes.
for _patched_cls, _attr_name, _loader in (
    (torch.nn.Conv2d, "load", load_conv2d),
    (torch.nn.Conv2d, "load_no_bias", load_conv2d_no_bias),
    (torch.nn.LayerNorm, "load", load_layer_norm),
    (torch.nn.LayerNorm, "load_no_bias", load_layer_norm_no_bias),
):
    setattr(_patched_cls, _attr_name, _loader)
del _patched_cls, _attr_name, _loader
134

135
136

class FastLinear(nn.Module):
    """Unquantized linear layer computing ``F.linear(input, weight, bias)``."""

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        # A missing bias is stored as a plain ``None`` attribute.
        self.bias = None if bias is None else nn.Parameter(bias)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Read ``{prefix}.weight`` (and ``{prefix}.bias`` when ``bias``) from ``weights``."""
        weight = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)
160
161


162
163
class EETQLinear(nn.Module):
    """Int8 weight / fp16 activation linear layer backed by the EETQ kernels."""

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        target_device = weight.device
        fp16_weight = (
            weight
            if weight.dtype == torch.float16
            else weight.to(dtype=torch.float16)
        )
        # Quantization is performed on a transposed, contiguous host copy.
        host_weight = torch.t(fp16_weight).contiguous().cpu()
        qweight, qscale = quant_weights(host_weight, torch.int8, False)

        self.weight = qweight.cuda(target_device)
        self.scale = qscale.cuda(target_device)
        self.bias = None if bias is None else bias.cuda(target_device)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = w8_a16_gemm(input, self.weight, self.scale)
        if self.bias is not None:
            output = output + self.bias
        return output


Nicolas Patry's avatar
Nicolas Patry committed
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    device = weight.device
    # weight, scale = quant_weights(weight, torch.int8, False)
    finfo = torch.finfo(qdtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # scale and clamp the tensor to bring it to
    # the representative range of float8 data type
    # (as default cast is unsaturated)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both float8 data and the inverse scale (as float),
    # as both required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()
    return qweight, scale


class Fp8Linear(nn.Module):
    """Linear layer with a per-tensor scaled FP8 weight.

    The activation is quantized to FP8 on every forward pass and the product is
    computed with ``torch._scaled_mm`` in the original weight dtype.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        # Output dtype of the matmul: the dtype of the unquantized weight.
        self.dtype = weight.dtype
        self.qweight, self.scale = fp8_quantize(weight)
        self.bias = bias

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # NOTE(review): torch._scaled_mm takes 2-D operands; this assumes callers
        # pass a flattened (tokens, features) input.
        qinput, scale = fp8_quantize(input)
        result = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=scale,
            scale_b=self.scale,
            bias=self.bias,
        )
        # Older torch releases return ``(output, amax)``; newer ones return only
        # the output tensor. Accept both.
        if isinstance(result, tuple):
            result = result[0]
        return result


227
class Linear8bitLt(nn.Module):
    """LLM.int8() linear layer built on ``bitsandbytes`` (``bnb.matmul``)."""

    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        state = self.state
        state.threshold = threshold
        state.has_fp16_weights = has_fp16_weights
        state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        # NOTE: the return value of .cuda() is discarded; this relies on
        # Int8Params converting itself in place when moved to the device.
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        """Move the weight's CB/SCB buffers onto the matmul state."""
        self.state.CB, self.state.SCB = self.weight.CB, self.weight.SCB
        self.weight.CB = self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        state = self.state
        state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            bias.data = bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=bias, state=state)

        if (
            not state.has_fp16_weights
            and state.CB is not None
            and state.CxB is not None
        ):
            # we converted 8-bit row major to turing/ampere format in the first inference pass
            # we no longer need the row-major weight
            del state.CB
            self.weight.data = state.CxB
        return out
283
284


Nicolas Patry's avatar
Nicolas Patry committed
285
286
287
288
class Linear4bit(nn.Module):
    """4-bit (``fp4`` / ``nf4``) linear layer built on ``bitsandbytes``."""

    def __init__(self, weight, bias, quant_type):
        super().__init__()
        self.weight = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        # Optional dtype to run the matmul in. It is never set in this file, so
        # by default the input dtype is used.
        self.compute_dtype = None
        # NOTE: the return value of .cuda() is discarded; this relies on
        # Params4bit quantizing itself in place when moved to the device.
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # The 4-bit weight is handled by bitsandbytes, but the bias has to be
        # cast to the activation dtype manually.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            logger.warning(
                "4-bit quantization state not initialized. Please call .cuda() or .to(device) on the Linear4bit layer first."
            )

        inp_dtype = x.dtype
        bias = self.bias
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)
            if bias is not None:
                bias = bias.to(self.compute_dtype)

        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )
        # Return in the caller's dtype even if compute_dtype differs.
        return out.to(inp_dtype)


321
322
@lru_cache(1)
def warn_deprecate_bnb():
    """Log the bitsandbytes 8-bit deprecation warning.

    ``lru_cache`` on this argument-less function means the warning is emitted
    only on the first call.
    """
    logger.warning(
        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance"
    )

327

328
329
330
def get_linear(weight, bias, quantize):
    """Build the linear layer implementing quantization mode ``quantize``.

    Args:
        weight: a dense weight tensor for ``None``/``"eetq"``/``"fp8"`` and the
            bitsandbytes modes; for ``"gptq"`` and ``"awq"`` a 7-tuple
            ``(qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)``.
        bias: optional bias tensor.
        quantize: ``None`` or one of ``"eetq"``, ``"fp8"``, ``"bitsandbytes"``,
            ``"bitsandbytes-fp4"``, ``"bitsandbytes-nf4"``, ``"gptq"``, ``"awq"``.

    Returns:
        An ``nn.Module`` implementing the linear projection.

    Raises:
        ImportError: the backend required by ``quantize`` is not available.
        NotImplementedError: unknown mode, a weight tuple of the wrong shape, or
            a backend that cannot run on this system.
    """
    # Fail with a clear message instead of a NameError on `bnb`/`Int8Params`.
    if (
        quantize in ("bitsandbytes", "bitsandbytes-fp4", "bitsandbytes-nf4")
        and not HAS_BITS_AND_BYTES
    ):
        raise ImportError(
            f"Please install bitsandbytes to use `--quantize {quantize}`"
        )

    if quantize is None:
        linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        if not HAS_EETQ:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
        linear = EETQLinear(weight, bias)
    elif quantize == "fp8":
        linear = Fp8Linear(weight, bias)
    elif quantize == "bitsandbytes":
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            linear.bias = nn.Parameter(bias)
    elif quantize == "bitsandbytes-fp4":
        linear = Linear4bit(weight, bias, quant_type="fp4")
    elif quantize == "bitsandbytes-nf4":
        linear = Linear4bit(weight, bias, quant_type="nf4")
    elif quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        except Exception as err:
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            ) from err

        if use_exllama:
            # ExllamaQuantLinear only exists when the exllama import succeeded.
            if not HAS_EXLLAMA:
                raise ImportError(
                    "Exllama kernels were requested but are not available "
                    "(not installed, unsupported GPU, or DISABLE_EXLLAMA=True)."
                )
            linear = ExllamaQuantLinear(
                qweight, qzeros, scales, g_idx, bias, bits, groupsize
            )
        else:
            linear = QuantLinear(
                qweight,
                qzeros,
                scales,
                g_idx,
                bias,
                bits,
                groupsize,
            )
    elif quantize == "awq":
        try:
            qweight, qzeros, scales, _, bits, groupsize, _ = weight
        except Exception as err:
            raise NotImplementedError(
                "The passed weight is not `awq` compatible, loader needs to be updated."
            ) from err
        if IS_ROCM_SYSTEM:
            raise NotImplementedError(
                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
                "to use Exllama/GPTQ kernels for AWQ inference."
            )
        if not HAS_AWQ:
            raise NotImplementedError(
                "You do not seem to have awq installed, either install it (cd server && make install-awq), "
                "or try using GPTQ (`--quantize gptq`); a conversion AWQ->GPTQ will happen on the fly"
            )
        linear = WQLinear(
            w_bit=bits,
            group_size=groupsize,
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            bias=bias is not None,
        )
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear


class SuperLayer(nn.Module):
    """Owns a linear module and delegates ``forward`` to it.

    The tensor-parallel subclasses wrap this call with their communication.
    """

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        # Invokes ``forward`` directly rather than ``__call__``, so hooks
        # registered on the wrapped module are not run here.
        wrapped = self.linear
        return wrapped.forward(x)


422
423
424
425
426
427
428
429
430
431
432
433
434
class ResBlock(torch.nn.Module):
    """Residual block computing ``x + SiLU(linear(x))``; weights at ``{prefix}.linear``."""

    def __init__(self, config, prefix, weights):
        super().__init__()
        self.linear = FastLinear.load(
            config, prefix=f"{prefix}.linear", weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        activated = self.act(self.linear(x))
        return x + activated


class MedusaModel(torch.nn.Module):
    """Collection of ``medusa_num_heads`` Medusa heads, head ``i`` loaded from prefix ``"i"``."""

    def __init__(self, config, medusa_config, weights):
        super().__init__()
        num_heads = medusa_config["medusa_num_heads"]
        heads = [
            MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
            for i in range(num_heads)
        ]
        self.heads = torch.nn.ModuleList(heads)

    def forward(self, x):
        # One output per head, stacked along a new dimension 1.
        per_head = [head(x) for head in self.heads]
        return torch.stack(per_head, dim=1)


class MedusaHead(torch.nn.Module):
    """A Medusa head: ``medusa_num_layers`` residual blocks, then a bias-free projection.

    Block ``i`` is loaded from ``{prefix}.{i}`` and the projection from
    ``{prefix}.{n}`` where ``n`` is the number of blocks.
    """

    def __init__(self, config, medusa_config, prefix, weights):
        super().__init__()
        num_layers = medusa_config["medusa_num_layers"]
        self.blocks = torch.nn.ModuleList(
            [
                ResBlock(config, prefix=f"{prefix}.{layer_idx}", weights=weights)
                for layer_idx in range(num_layers)
            ]
        )
        out_prefix = f"{prefix}.{len(self.blocks)}"
        self.out = FastLinear.load(
            config, prefix=out_prefix, weights=weights, bias=False
        )

    def forward(self, x):
        hidden = x
        for block in self.blocks:
            hidden = block(hidden)
        return self.out(hidden)


OlivierDehaene's avatar
OlivierDehaene committed
470
class MedusaHeadV1(nn.Module):
    """LM head paired with a separately loaded ``MedusaModel``.

    ``forward`` returns ``(logits, speculative_logits)``.
    """

    def __init__(self, lm_head, medusa):
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Load Medusa heads from the directory ``config.use_medusa``.

        Reads ``config.json`` and registers every tensor of
        ``medusa_lm_head.safetensors`` in ``weights.routing``.

        Raises:
            RuntimeError: a Medusa tensor name is already routed to another file.
        """
        from pathlib import Path
        from safetensors import safe_open
        import json

        medusa_dir = Path(config.use_medusa)
        config_path = str(medusa_dir / "config.json")
        filename = str(medusa_dir / "medusa_lm_head.safetensors")

        # JSON is UTF-8; don't depend on the process locale.
        with open(config_path, "r", encoding="utf-8") as f:
            medusa_config = json.load(f)
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        medusa = MedusaModel(config, medusa_config, weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MedusaHeadV1(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        speculative_logits = self.medusa(input)
        return logits, speculative_logits


class MedusaHeadV2(nn.Module):
    """Medusa head with all single-layer heads fused into one column-parallel linear.

    ``forward`` returns ``(logits, speculative_logits)``.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        from pathlib import Path
        from safetensors import safe_open
        import json

        medusa_dir = Path(config.use_medusa)
        config_path = str(medusa_dir / "config.json")
        filename = str(medusa_dir / "medusa_lm_head.safetensors")

        # JSON is UTF-8; don't depend on the process locale.
        with open(config_path, "r", encoding="utf-8") as f:
            medusa_config = json.load(f)
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        self.n_medusa_heads = medusa_config["medusa_num_heads"]

        # Only the first residual layer of each head is fused below.
        assert (
            medusa_config["medusa_num_layers"] == 1
        ), "MedusaHeadV2 only supports medusa_num_layers == 1"
        self.linear = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.process_group = weights.process_group
        self.world_size = self.process_group.size()
        self.rank = self.process_group.rank()

        self.act = torch.nn.SiLU()

        self.lm_head = TensorParallelHead.load(config, prefix, weights)

    def forward(self, x):
        # This rank's slice of the hidden dimension (x is indexed as 2-D here).
        size = x.shape[-1]
        block_size = (size + self.world_size - 1) // self.world_size
        start = self.rank * block_size
        stop = (self.rank + 1) * block_size
        x_block = x[:, start:stop]

        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
        medusa_res = self.act(self.linear(x)).reshape(
            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
        )

        # Apply all residual medusa heads
        output = x_block.unsqueeze(-2) + medusa_res

        # Gather medusa heads
        world_output = [torch.empty_like(output) for _ in range(self.world_size)]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)

        # Stack x and medusa residual x
        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)

        # Compute lm head on x + medusa residual x
        logits = self.lm_head(stacked_x)

        # Finally, split logits from speculative logits
        logits, speculative_logits = torch.split(
            logits, [1, self.n_medusa_heads], dim=-2
        )
        # Squeeze added dimension
        logits = logits.squeeze(-2)

        return logits, speculative_logits


class SpeculativeHead(nn.Module):
    """LM head that optionally delegates to a Medusa speculative head.

    ``forward`` always returns ``(logits, speculative_logits_or_None)``.
    """

    def __init__(self, lm_head, medusa):
        super().__init__()
        self.head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        use_medusa = config.use_medusa
        if use_medusa:
            # The Medusa head owns its own LM head.
            lm_head = None
            try:
                medusa = MedusaHeadV1.load(config, prefix, weights)
            except Exception:
                # V1 loading failed: fall back to the fused V2 layout. Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                medusa = MedusaHeadV2(config, prefix, weights)
        else:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            medusa = None
        return SpeculativeHead(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.medusa is not None:
            return self.medusa(input)

        assert self.head is not None
        logits = self.head(input)
        return logits, None
619
620


621
class TensorParallelHead(SuperLayer):
    """Output projection whose rows may be sharded across ranks.

    When ``should_gather`` is set, each rank computes its slice of the output
    features and the slices are gathered and concatenated along the last
    dimension.
    """

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        tensor_name = f"{prefix}.weight"
        should_gather = False
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(tensor_name, dim=0)
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(tensor_name)
            else:
                should_gather = True
        else:
            weight = weights.get_tensor(tensor_name)

        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
        unquantized_modes = ["gptq", "awq", "eetq"]
        quantize = None if config.quantize in unquantized_modes else config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        group = self.process_group
        world_size = group.size()

        if input.dim() == 2 and isinstance(self.linear, FastLinear):
            # Fast path: matmul straight into the buffer that is all-gathered.
            out_dim = self.linear.weight.shape[0]
            n_rows = input.shape[0]
            single_row = n_rows == 1

            if single_row:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # Write the result through a transposed view so each rank
                # contributes a contiguous (out_dim, n_rows) buffer.
                world_out = input.new_empty(out_dim * world_size, n_rows)
                gather_input = input.new_empty(out_dim, n_rows)
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)
            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=group
            )
            return world_out if single_row else world_out.T

        output = super().forward(input)
        gathered = [torch.empty_like(output) for _ in range(world_size)]
        torch.distributed.all_gather(gathered, output, group=group)
        return torch.cat(gathered, dim=-1)


class TensorParallelColumnLinear(SuperLayer):
    """Linear layer sharded column-wise: each rank holds a slice of the output features."""

    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        linear = get_linear(weight, None, config.quantize)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """Load and concatenate the sharded weights (and biases) of ``prefixes``.

        ``dim`` is forwarded to ``weights.get_multi_weights_col`` and only
        describes how the weights are concatenated.
        """
        weight = weights.get_multi_weights_col(
            prefixes, quantize=config.quantize, dim=dim
        )

        if bias:
            # Each bias shard is 1-D (sharded with dim=0), so the shards are
            # always concatenated along dim 0, whatever the weight's `dim` is.
            # Using `dim` here fails for dim=1.
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            bias = torch.cat(b, dim=0)
        else:
            bias = None

        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)
718

719
720
721
722

class TensorParallelRowLinear(SuperLayer):
    """Linear layer sharded row-wise; partial outputs are summed with all-reduce."""

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
        # Only rank 0 loads the bias, so the all-reduced sum includes it once.
        keep_bias = bias and weights.process_group.rank() == 0
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if keep_bias else None
        linear = get_linear(weight, bias_tensor, config.quantize)
        return cls(linear, process_group=weights.process_group)

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        partial = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            torch.distributed.all_reduce(partial, group=self.process_group)
        return partial
744
745


746
747
748
class TensorParallelEmbedding(nn.Module):
    """Embedding table sharded along the vocabulary across a process group.

    Each rank keeps rows ``[min_id, max_id)`` of the table plus one appended
    all-zero row. Token ids outside the local range are looked up in that zero
    row, so summing the per-rank outputs (all-reduce) yields the full lookup.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        local_weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        vocab_size = weights.get_shape(f"{prefix}.weight")[0]

        group = weights.process_group
        world_size = group.size()
        rank = group.rank()

        # Ceil-divide the vocabulary so every id belongs to exactly one rank;
        # the last rank may own fewer rows.
        block_size = (vocab_size + world_size - 1) // world_size
        self.min_id = rank * block_size
        self.max_id = min(vocab_size, (rank + 1) * block_size)

        # Index of the zero row appended below: the number of local rows,
        # usually block_size but possibly fewer when the vocab is uneven.
        self.null_idx = local_weight.shape[0]
        self.process_group = weights.process_group
        self.reduce = reduce

        # Append one all-zero row (pad one row at the end of dim -2), used to
        # mask out ids that other ranks own.
        self.weight = nn.Parameter(F.pad(local_weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Look up ``input`` ids, reducing across ranks when enabled."""
        # Ids owned by this rank become local indices in [0, max_id - min_id);
        # every other id points at the zero padding row.
        in_range = (input >= self.min_id) & (input < self.max_id)
        local_ids = torch.where(in_range, input - self.min_id, self.null_idx)
        out = torch.nn.functional.embedding(local_ids, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


try:
fxmarty's avatar
fxmarty committed
784
785
    if IS_CUDA_SYSTEM:
        import dropout_layer_norm
OlivierDehaene's avatar
OlivierDehaene committed
786
787
    elif IS_ROCM_SYSTEM:
        from vllm import layernorm_ops
fxmarty's avatar
fxmarty committed
788
789
    else:
        dropout_layer_norm = None
790
791
792

    class FastLayerNorm(nn.LayerNorm):
        """``nn.LayerNorm`` with an optional residual add.

        ``forward(hidden_states, residual)`` returns ``(normed, residual)``,
        where ``residual`` is the (residual-added) input to hand to the next
        layer.
        """

        def forward(self, hidden_states, residual=None):
            # The fused `dropout_layer_norm` kernel is only used on CUDA for
            # hidden sizes up to 8192. ROCm, larger hidden sizes, and systems
            # where the extension was set to None use the eager PyTorch path
            # instead of calling a method on None. `IS_ROCM_SYSTEM` is tested
            # before `dropout_layer_norm` because the ROCm import branch never
            # binds that name.
            use_eager = (
                hidden_states.shape[-1] > 8192
                or IS_ROCM_SYSTEM
                or dropout_layer_norm is None
            )
            if use_eager:
                # In-place add: the caller's `hidden_states` tensor is modified.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual

            (
                normed_hidden_states,
                residual,
                *rest,
            ) = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                self.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.eps,
                1.0,
                0,
                None,
                False,
                False,
            )
            # Without an incoming residual, the input itself is passed on as
            # the residual for the next layer.
            if residual is None:
                residual = hidden_states

            return normed_hidden_states, residual

    class FastRMSNorm(nn.Module):
        """RMS normalization with an optional residual add.

        ``forward(hidden_states, residual)`` returns ``(normed, residual)``,
        where ``residual`` is the (residual-added) input for the next layer.
        The implementation is picked per call:
        - plain PyTorch for hidden sizes above 8192;
        - the ``dropout_layer_norm`` kernel on CUDA;
        - vLLM's ``layernorm_ops.rms_norm`` on ROCm.
        """

        def __init__(self, weight: torch.Tensor, eps: float):
            super().__init__()

            self.weight = nn.Parameter(weight)
            self.variance_epsilon = eps

        @classmethod
        def load(cls, prefix, weights, eps=1e-6):
            """Construct the norm from the ``{prefix}.weight`` tensor in ``weights``."""
            return cls(weights.get_tensor(f"{prefix}.weight"), eps)

        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                # Eager PyTorch path. The residual add is in place, so the
                # caller's `hidden_states` tensor is modified.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                as_fp32 = hidden_states.to(torch.float32)
                mean_sq = as_fp32.pow(2).mean(-1, keepdim=True)
                normed = as_fp32 * torch.rsqrt(mean_sq + self.variance_epsilon)

                # Go back to half precision when the weight is fp16/bf16.
                if self.weight.dtype in [torch.float16, torch.bfloat16]:
                    normed = normed.to(self.weight.dtype)

                return self.weight * normed, residual

            if IS_CUDA_SYSTEM:
                # Fused CUDA kernel (faster post-attention norm); the final
                # positional flag switches it to RMSNorm.
                (
                    normed_hidden_states,
                    res,
                    *_unused,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    None,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.variance_epsilon,
                    1.0,
                    0,
                    None,
                    False,
                    True,  # Activate RMSNorm
                )
                # Without an incoming residual, the input itself becomes the
                # residual for the next layer.
                if res is None:
                    res = hidden_states
                return normed_hidden_states, res

            if IS_ROCM_SYSTEM:
                # vLLM's RMSNorm kernel compiles for ROCm, unlike the Flash
                # Attention one. The result is written into the preallocated
                # `out` tensor.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                out = torch.empty_like(hidden_states)
                layernorm_ops.rms_norm(
                    out,
                    hidden_states,
                    self.weight.data,
                    self.variance_epsilon,
                )
                return out, residual

            raise ValueError(
                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )
except ImportError:
    pass

try:
fxmarty's avatar
fxmarty committed
905
906
907
908
909
    if IS_CUDA_SYSTEM:
        from flash_attn.layers.rotary import RotaryEmbedding
        import rotary_emb
    elif IS_ROCM_SYSTEM:
        from vllm import pos_encoding_ops
910

Nicolas Patry's avatar
Nicolas Patry committed
911
912
    def _create_inv_freq(dim, base, device):
        """Return rotary inverse frequencies ``1 / base ** (2i / dim)``.

        ``i`` runs over ``range(0, dim, 2) / 2``. The result is a float32
        tensor on ``device``.
        """
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
        return 1.0 / (base**exponents)

    def _get_rope_config(config):
        """Return the RoPE scaling settings to use for ``config``.

        If the ``ROPE_SCALING`` environment variable is set, it overrides the
        model config. The override is ``{"type": ROPE_SCALING, "factor":
        float(ROPE_FACTOR)}``, and ``ROPE_FACTOR`` must then be set too.
        Otherwise the config's ``rope_scaling`` attribute is returned, or None.
        """
        if os.getenv("ROPE_SCALING", None) is None:
            return getattr(config, "rope_scaling", None)
        return {
            "type": os.environ["ROPE_SCALING"],
            "factor": float(os.environ["ROPE_FACTOR"]),
        }

    class PositionRotaryEmbedding(nn.Module):
        """Rotary position embedding with lazily built cos/sin tables.

        ``get_cos_sin`` selects table rows for given positions, and ``forward``
        rotates query/key in place with a platform-specific kernel. Build
        instances with ``static`` (from ``dim``/``base``) or ``load`` (from a
        checkpoint's ``inv_freq``); both honor the rope scaling settings.
        """

        def __init__(self, inv_freq, scaling_factor):
            super().__init__()
            self.inv_freq = inv_freq
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
            # This class divides cached positions by scaling_factor (None
            # disables that); subclasses interpret it differently.
            self.scaling_factor = scaling_factor
            self.dynamic_args = None

        def forward(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            cos: torch.Tensor,
            sin: torch.Tensor,
        ):
            """Rotate ``query`` and ``key`` in place using ``cos``/``sin``.

            Nothing is returned; the input tensors are modified.

            Raises:
                ValueError: on a system that is neither CUDA nor ROCm.
            """
            # Such controlflows may add some overhead.
            if IS_CUDA_SYSTEM:
                # The first 2 * rotary_dim channels are split into two halves
                # that the flash-attn kernel rotates in place (outputs alias
                # the inputs).
                rotary_dim = cos.shape[-1]
                q1 = query[..., :rotary_dim]
                q2 = query[..., rotary_dim : 2 * rotary_dim]

                rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)

                k1 = key[..., :rotary_dim]
                k2 = key[..., rotary_dim : 2 * rotary_dim]

                rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
            elif IS_ROCM_SYSTEM:
                # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
                # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

                head_size = query.shape[-1]

                # Inplace operation, updating query and key.
                pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
            else:
                raise ValueError(
                    "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
                )

        @classmethod
        def static(cls, config, dim, base, device):
            """Build a rotary embedding from ``dim`` and ``base``.

            Depending on the rope scaling config, this returns a plain
            (optionally linearly scaled) instance, a
            ``DynamicPositionRotaryEmbedding`` or a
            ``YarnPositionRotaryEmbedding``.

            Raises:
                NotImplementedError: for an unknown rope scaling type.
            """
            inv_freq = _create_inv_freq(dim, base, device)
            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=dim,
                        max_position_embeddings=config.max_position_embeddings,
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling[
                            "original_max_position_embeddings"
                        ],
                        # Use the caller's base (as the dynamic branch does);
                        # a hard-coded 10000.0 gives wrong frequencies for
                        # models configured with a different base.
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        @classmethod
        def load(cls, config, prefix, weights):
            """Build a rotary embedding from the checkpoint tensor ``{prefix}.inv_freq``.

            Raises:
                NotImplementedError: for an unknown rope scaling type.
            """
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            try:
                inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            finally:
                # Restore the loader's dtype even if the tensor lookup fails.
                weights.dtype = dtype

            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    # NOTE(review): the base is not available here, so the
                    # common default 10000.0 is assumed — confirm for
                    # checkpoints trained with a different base.
                    return DynamicPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=config.max_position_embeddings,
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling[
                            "original_max_position_embeddings"
                        ],
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild the tables if they are missing, too short, or on a
            # different device/dtype (possibly due to tracing for instance).
            # The explicit None check keeps seqlen == 0 from dereferencing an
            # empty cache.
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached is None
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                if self.scaling_factor is not None:
                    t /= self.scaling_factor
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """Return the cos and sin table rows for ``position_ids``.

            The cache is (re)built for ``max_s`` positions on
            ``position_ids.device``. Each result has a singleton dimension
            inserted at index 1.
            """
            if IS_ROCM_SYSTEM:
                # For RoCm, we always use float cos/sin to avoid a cast.
                # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
                # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
                dtype = torch.float32

            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
            return cos.unsqueeze(1), sin.unsqueeze(1)

    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        """Rotary embedding whose base grows past ``max_position_embeddings``.

        When the cache is rebuilt for ``seqlen > max_position_embeddings``, the
        base is recomputed as
        ``base * (scaling_factor * seqlen / max_position_embeddings
        - (scaling_factor - 1)) ** (dim / (dim - 2))``.
        The inverse frequencies are then regenerated from that new base.
        Unlike the parent class, positions are not divided by
        ``scaling_factor``.
        """

        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild the tables if they are missing, too short, or on a
            # different device/dtype (possibly due to tracing for instance).
            # The explicit None check keeps seqlen == 0 from dereferencing an
            # empty cache.
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached is None
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    newbase = self.base * (
                        (self.scaling_factor * seqlen / self.max_position_embeddings)
                        - (self.scaling_factor - 1)
                    ) ** (self.dim / (self.dim - 2))
                    self.inv_freq = _create_inv_freq(
                        self.dim, newbase, self.inv_freq.device
                    )
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

    import math

    def find_correction_dim(
        num_rotations, dim, base=10000, max_position_embeddings=2048
    ):
        """Inverse dim formula: find the dim index for a number of rotations.

        Returns the (fractional) rotary dimension index whose wavelength
        completes ``num_rotations`` rotations over ``max_position_embeddings``
        positions:
        ``dim * ln(max_position_embeddings / (2 * pi * num_rotations)) / (2 * ln(base))``.
        """
        log_arg = max_position_embeddings / (num_rotations * 2 * math.pi)
        numerator = dim * math.log(log_arg)
        return numerator / (2 * math.log(base))

    def find_correction_range(
        low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
    ):
        """Find dim range bounds based on rotations.

        Returns ``(low, high)``: the floor of ``find_correction_dim`` for
        ``low_rot`` and the ceil for ``high_rot``, clamped to ``[0, dim - 1]``.
        """
        shared = (dim, base, max_position_embeddings)
        low = math.floor(find_correction_dim(low_rot, *shared))
        high = math.ceil(find_correction_dim(high_rot, *shared))
        # Clamp values just in case.
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_mask(min, max, dim):
        """Return a float32 ramp of length ``dim``.

        Values rise linearly from 0 at index ``min`` to 1 at index ``max`` and
        are clamped to [0, 1] outside that range. (The parameter names shadow
        the builtins and are kept for compatibility.)
        """
        # Widen a zero-length ramp slightly to prevent a division by zero.
        upper = max + 0.001 if min == max else max
        positions = torch.arange(dim, dtype=torch.float32)
        return torch.clamp((positions - min) / (upper - min), 0, 1)

    def get_mscale(scale=1):
        """Return the magnitude scaling factor for an interpolation ``scale``.

        The factor is ``0.1 * ln(scale) + 1.0`` when ``scale > 1``, otherwise
        ``1.0``.
        """
        return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0

    class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
        """Rotary embedding with YaRN-style frequency scaling.

        Once a cache rebuild is requested for ``seqlen > max_position_embeddings``:
        - the inverse frequencies become a per-dimension blend of interpolated
          frequencies (divided by ``scaling_factor``) and the original ones,
          weighted by a linear ramp between the dims returned by
          ``find_correction_range(beta_fast, beta_slow, ...)``;
        - the cos/sin tables are always multiplied by ``mscale``.
        """

        def __init__(
            self,
            dim,
            max_position_embeddings,
            base,
            device,
            scaling_factor,
            *,
            extrapolation_factor,
            attn_factor,
            beta_fast,
            beta_slow,
        ):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base
            self.extrapolation_factor = extrapolation_factor
            self.attn_factor = attn_factor
            self.beta_fast = beta_fast
            self.beta_slow = beta_slow
            # Get n-d magnitude scaling corrected for interpolation.
            self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild the tables if they are missing, too short, or on a
            # different device/dtype (possibly due to tracing for instance).
            # The explicit None check keeps seqlen == 0 from dereferencing an
            # empty cache.
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached is None
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    inv_freq_extrapolation = _create_inv_freq(
                        self.dim, self.base, self.inv_freq.device
                    )
                    freqs = 1.0 / inv_freq_extrapolation
                    inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)

                    low, high = find_correction_range(
                        self.beta_fast,
                        self.beta_slow,
                        self.dim,
                        self.base,
                        self.max_position_embeddings,
                    )
                    # Get n-d rotational scaling corrected for extrapolation.
                    # The mask is placed on the frequencies' device, which may
                    # differ from the requested `device`, so the blend below
                    # never mixes devices.
                    ramp = linear_ramp_mask(low, high, self.dim // 2).float()
                    inv_freq_mask = (
                        1 - ramp.to(inv_freq_extrapolation.device)
                    ) * self.extrapolation_factor
                    inv_freq = (
                        inv_freq_interpolation * (1 - inv_freq_mask)
                        + inv_freq_extrapolation * inv_freq_mask
                    )

                    self.inv_freq = inv_freq
                    # Get n-d magnitude scaling corrected for interpolation.
                    self.mscale = float(
                        get_mscale(self.scaling_factor) * self.attn_factor
                    )

                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)

except ImportError:
    pass