layers.py 34.5 KB
Newer Older
1
import os
2
import torch
3
import torch.distributed
4
5

from torch import nn
6
from torch.nn import functional as F
7
from typing import List
8
9
from loguru import logger
from functools import lru_cache
10
11
12

HAS_BITS_AND_BYTES = True
try:
13
    import bitsandbytes as bnb
Nicolas Patry's avatar
Nicolas Patry committed
14
    from bitsandbytes.nn import Int8Params, Params4bit
15
except ImportError:
16
17
    HAS_BITS_AND_BYTES = False

18
19
from accelerate import init_empty_weights

20
from text_generation_server.utils.gptq.quant_linear import QuantLinear
OlivierDehaene's avatar
OlivierDehaene committed
21
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
22
23

HAS_AWQ = True
OlivierDehaene's avatar
OlivierDehaene committed
24
try:
25
26
27
28
    from text_generation_server.utils.awq.quantize.qmodule import WQLinear
except ImportError:
    HAS_AWQ = False

29
# Probe the compute capability of the current CUDA device. Any failure
# (for instance no usable CUDA device) falls back to a major version that
# keeps the exllama kernels disabled below.
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1

# HAS_EXLLAMA is False when the kernels are unavailable, otherwise the string
# "1" or "2" naming the exllama version that was successfully imported.
HAS_EXLLAMA = False
# The exllama kernels are only attempted on devices with capability >= 8.
CAN_EXLLAMA = major >= 8
# EXLLAMA_VERSION selects the kernel generation; v2 is the default.
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
    logger.warning(
        "Disabling exllama v2 and using v1 instead because there are issues when sharding"
    )
    V2 = False

if os.getenv("DISABLE_EXLLAMA") == "True":
    # Explicit opt-out through the environment.
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        if V2:
            from text_generation_server.utils.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "2"
        else:
            from text_generation_server.utils.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "1"

    except ImportError:
        # Kernels are optional: leave HAS_EXLLAMA as False when they are missing.
        pass
66

67
68
69
HAS_EETQ = False
try:
    from EETQ import quant_weights, w8_a16_gemm
OlivierDehaene's avatar
OlivierDehaene committed
70

71
72
73
74
    HAS_EETQ = True
except ImportError:
    pass

75

76
77
78
79
80
81
82
83
84
85
86
87
88
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Build a ``cls`` layer norm from the checkpoint tensors stored under ``prefix``.

    Reads ``{prefix}.weight`` and ``{prefix}.bias`` through ``weights.get_tensor``,
    instantiates ``cls(weight.shape, eps=eps)`` inside ``init_empty_weights()``
    and then installs the loaded tensors as the module parameters.
    """
    gamma = weights.get_tensor(f"{prefix}.weight")
    beta = weights.get_tensor(f"{prefix}.bias")

    # The freshly constructed parameters are discarded right after, so the
    # module is created under init_empty_weights() (presumably to avoid
    # allocating real storage - confirm against the accelerate docs).
    with init_empty_weights():
        layer = cls(gamma.shape, eps=eps)

    layer.weight = nn.Parameter(gamma)
    layer.bias = nn.Parameter(beta)
    return layer


89
90
91
92
93
94
95
96
97
98
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Build a bias-free ``cls`` layer norm from the ``{prefix}.weight`` tensor.

    Same construction as the biased loader, except that no bias tensor is
    read and the resulting module has ``bias`` set to ``None``.
    """
    gamma = weights.get_tensor(f"{prefix}.weight")

    # Constructed under init_empty_weights(); the real weight is attached below.
    with init_empty_weights():
        layer = cls(gamma.shape, eps=eps)

    layer.weight = nn.Parameter(gamma)
    layer.bias = None
    return layer

OlivierDehaene's avatar
OlivierDehaene committed
99

100
101
102
103
104
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Build a ``cls`` 2D convolution and fill it with checkpoint tensors.

    Reads ``{prefix}.weight`` and ``{prefix}.bias`` through ``weights.get_tensor``,
    constructs ``cls`` with the given geometry inside ``init_empty_weights()``
    and replaces its parameters with the loaded tensors.
    """
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    # The constructor's own parameters are thrown away, so build the module
    # under init_empty_weights() and attach the real tensors afterwards.
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = nn.Parameter(weight)
    conv2d.bias = nn.Parameter(bias)
    return conv2d


@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    """Build a bias-free ``cls`` 2D convolution from the ``{prefix}.weight`` tensor.

    The module is constructed with the given geometry inside
    ``init_empty_weights()``; the loaded weight is then installed and
    ``bias`` is set to ``None``.
    """
    weight = weights.get_tensor(f"{prefix}.weight")
    # The constructor's own parameters are thrown away, so build the module
    # under init_empty_weights() and attach the real tensor afterwards.
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = nn.Parameter(weight)
    conv2d.bias = None
    return conv2d

134

135
136
# Monkey patch the stock torch modules with checkpoint loaders so that model
# code can call `torch.nn.LayerNorm.load(...)` / `torch.nn.Conv2d.load(...)`
# (and the `load_no_bias` variants) as alternate constructors.
torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
139

140
141

class FastLinear(nn.Module):
    """Unquantized linear layer wrapping pre-loaded weight and bias tensors."""

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        """Wrap ``weight`` (and ``bias`` unless it is ``None``) as parameters."""
        super().__init__()
        self.weight = nn.Parameter(weight)
        if bias is not None:
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Alternate constructor reading ``{prefix}.weight`` / ``{prefix}.bias``.

        ``bias`` is a flag: the bias tensor is only fetched when it is truthy.
        ``config`` is accepted for signature parity with the other loaders and
        is not used here.
        """
        weight = weights.get_tensor(f"{prefix}.weight")
        # Keep the boolean flag and the loaded tensor in separate names.
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with the stored weight and optional bias."""
        return F.linear(input, self.weight, self.bias)
165
166


167
168
class EETQLinear(nn.Module):
    """Linear layer whose weight is quantized to int8 with the EETQ kernels."""

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        """Quantize ``weight`` and keep the result on the weight's original device.

        The weight is transposed, made contiguous and moved to CPU before being
        passed to ``quant_weights``; the quantized weight, its scale and the
        optional bias are then moved back with ``.cuda(device)``.
        """
        super().__init__()
        device = weight.device
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)

        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the ``w8_a16_gemm`` kernel and add the bias when one is set."""
        output = w8_a16_gemm(input, self.weight, self.scale)
        output = output + self.bias if self.bias is not None else output
        return output


188
class Linear8bitLt(nn.Module):
    """bitsandbytes 8-bit linear layer built from pre-loaded tensors."""

    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        """Wrap ``weight`` in ``Int8Params`` and set up the matmul state.

        ``memory_efficient_backward`` must be falsy; the argument is deprecated
        and only kept for signature compatibility.
        """
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        """Move the ``CB`` / ``SCB`` buffers from the weight onto the matmul state."""
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        """Run ``bnb.matmul`` on ``x``, initializing the 8-bit state on first use."""
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
244
245


Nicolas Patry's avatar
Nicolas Patry committed
246
247
248
249
class Linear4bit(nn.Module):
    """bitsandbytes 4-bit linear layer built from pre-loaded tensors."""

    def __init__(self, weight, bias, quant_type):
        """Wrap ``weight`` in ``Params4bit`` using the given ``quant_type``."""
        super().__init__()
        self.weight = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        # Never assigned anywhere else in this class, so inputs keep their dtype.
        self.compute_dtype = None
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        """Run ``bnb.matmul_4bit`` on ``x`` and return the result in ``x``'s dtype."""
        # weights are handled by Params4bit, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            print(
                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
            )
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )

        out = out.to(inp_dtype)

        return out


282
283
@lru_cache(1)
def warn_deprecate_bnb():
    """Log the bitsandbytes 8-bit deprecation warning.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    warning is only emitted on the first call.
    """
    logger.warning(
        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance"
    )

288

289
290
291
def get_linear(weight, bias, quantize):
    """Build the linear layer implementation matching ``quantize``.

    Args:
        weight: the weight tensor, or for ``"gptq"`` / ``"awq"`` a 7-tuple
            ``(qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)``.
        bias: optional bias tensor (``None`` for no bias).
        quantize: ``None`` or one of ``"eetq"``, ``"bitsandbytes"``,
            ``"bitsandbytes-fp4"``, ``"bitsandbytes-nf4"``, ``"gptq"``, ``"awq"``.

    Raises:
        ImportError: the optional backend required by ``quantize`` is missing.
        NotImplementedError: ``quantize`` is unknown, or ``weight`` does not
            have the tuple layout expected by ``"gptq"`` / ``"awq"``.
    """
    if quantize is None:
        linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        if HAS_EETQ:
            linear = EETQLinear(weight, bias)
        else:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "bitsandbytes":
        # Fail with an actionable error instead of a NameError on `bnb`.
        if not HAS_BITS_AND_BYTES:
            raise ImportError(
                "bitsandbytes is required for `bitsandbytes` quantization but could not be imported."
            )
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            linear.bias = nn.Parameter(bias)
    elif quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        if not HAS_BITS_AND_BYTES:
            raise ImportError(
                f"bitsandbytes is required for `{quantize}` quantization but could not be imported."
            )
        # The suffix after "bitsandbytes-" is the bnb quant type ("fp4" / "nf4").
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4" if quantize == "bitsandbytes-fp4" else "nf4",
        )
    elif quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        except Exception:
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            )

        if use_exllama:
            # ExllamaQuantLinear is only bound when the kernels were imported.
            if not HAS_EXLLAMA:
                raise ImportError(
                    "Exllama kernels were requested for this `gptq` weight but could not be imported."
                )
            linear = ExllamaQuantLinear(
                qweight, qzeros, scales, g_idx, bias, bits, groupsize
            )
        else:
            linear = QuantLinear(
                qweight,
                qzeros,
                scales,
                g_idx,
                bias,
                bits,
                groupsize,
            )
    elif quantize == "awq":
        try:
            qweight, qzeros, scales, _, bits, groupsize, _ = weight
        except Exception:
            raise NotImplementedError(
                "The passed weight is not `awq` compatible, loader needs to be updated."
            )
        # WQLinear is only bound when the AWQ module was imported.
        if not HAS_AWQ:
            raise ImportError(
                "AWQ kernels are required for `awq` quantization but could not be imported."
            )
        linear = WQLinear(
            w_bit=bits,
            group_size=groupsize,
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            # WQLinear receives a flag here, not the bias tensor itself.
            bias=bias is not None,
        )
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear


class SuperLayer(nn.Module):
    """Thin module wrapper that delegates its forward pass to a wrapped layer."""

    def __init__(self, linear):
        """Store ``linear`` as the wrapped layer."""
        super().__init__()
        self.linear = linear

    def forward(self, x):
        """Return the result of calling the wrapped layer's ``forward`` on ``x``."""
        wrapped = self.linear
        # Calls `.forward` directly rather than the module's `__call__`.
        return wrapped.forward(x)


class TensorParallelHead(SuperLayer):
    """Output head whose weight may be sharded along dim 0 across a process group.

    When ``should_gather`` is true, each rank computes its slice of the output
    and the slices are gathered from every rank in ``forward``.
    """

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        """Load ``{prefix}.weight`` (sharded when possible) and build the head."""
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
        if config.quantize in ["gptq", "awq", "eetq"]:
            quantize = None
        else:
            quantize = config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the head and, when sharded, gather the output of every rank."""
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            # Fast path: write the local matmul straight into the buffer that
            # is handed to all_gather_into_tensor, avoiding a final torch.cat.
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # Work on the transposed output so that the per-rank slices
                # stack along the first dimension of `world_out`.
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        # Generic path: gather one output tensor per rank and concatenate
        # them along the last dimension.
        output = super().forward(input)
        world_output = [torch.empty_like(output) for _ in range(world_size)]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


class TensorParallelColumnLinear(SuperLayer):
    """Linear layer built from column-sharded weights; no collective is run in forward."""

    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load a single projection; shorthand for ``load_multi`` with one prefix."""
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """Load several projections and fuse them into one layer along ``dim``.

        When ``bias`` is truthy, the bias of every prefix is loaded sharded
        along dim 0 and the shards are concatenated along ``dim``.
        """
        weight = weights.get_multi_weights_col(
            prefixes, quantize=config.quantize, dim=dim
        )

        if bias:
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            bias = torch.cat(b, dim=dim)
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)
469

470
471
472
473

class TensorParallelRowLinear(SuperLayer):
    """Linear layer built from row-sharded weights; partial outputs are all-reduced."""

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load the row-sharded weight for ``prefix`` and build the layer."""
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        if bias and weights.process_group.rank() == 0:
            # The bias is only loaded on the first rank: the per-rank outputs
            # are combined by all_reduce in `forward`, so the bias must be
            # contributed by a single rank.
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias, config.quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        """Apply the local linear layer, then all-reduce across the group.

        The reduction is skipped when ``reduce`` is false or the group has a
        single member; callers may then reduce later themselves.
        """
        out = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
495
496


497
498
499
class TensorParallelEmbedding(nn.Module):
    """Embedding table sharded along the vocabulary dimension.

    Each rank holds one contiguous block of rows plus one extra all-zero row.
    Ids that fall outside the local block are redirected to that zero row, so
    the per-rank results can be combined with an all-reduce.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        # NOTE(review): assumes get_partial_sharded slices dim 0 with the same
        # floor-division block size - confirm against the weights loader.
        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        # Index of the zero row appended below. Derived from the shard that was
        # actually loaded so it always addresses the padding row, even if the
        # shard does not hold exactly `block_size` rows.
        self.null_idx = weight.shape[0]
        self.process_group = weights.process_group
        self.reduce = reduce

        # Additional 0 entry used for masking
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Look up ``input`` ids in the local shard and optionally all-reduce."""
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
        # translate for [0, self.max_id - self.min_id[
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


try:
fxmarty's avatar
fxmarty committed
533
534
    if IS_CUDA_SYSTEM:
        import dropout_layer_norm
OlivierDehaene's avatar
OlivierDehaene committed
535
536
    elif IS_ROCM_SYSTEM:
        from vllm import layernorm_ops
fxmarty's avatar
fxmarty committed
537
538
    else:
        dropout_layer_norm = None
539
540
541

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm with a fused residual-add, returning ``(normed, residual)``."""

        def forward(self, hidden_states, residual=None):
            """Add ``residual`` (if given) to ``hidden_states`` and normalize.

            Returns a tuple ``(normed_hidden_states, residual)`` where
            ``residual`` is the pre-normalization sum.
            """
            if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
                # Plain torch path, taken for hidden sizes above 8192 and
                # always on ROCm. Note that the add mutates `hidden_states`.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                # Fused kernel path.
                # NOTE(review): `dropout_layer_norm` is None on systems that are
                # neither CUDA nor ROCm, in which case this call fails.
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                # The kernel hands back no residual when none was passed in.
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual
OlivierDehaene's avatar
OlivierDehaene committed
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605

    class FastRMSNorm(nn.Module):
        """RMSNorm with a fused residual-add, returning ``(normed, residual)``."""

        def __init__(self, weight: torch.Tensor, eps: float):
            super().__init__()

            self.weight = nn.Parameter(weight)
            self.variance_epsilon = eps

        @classmethod
        def load(cls, prefix, weights, eps=1e-6):
            """Alternate constructor reading the ``{prefix}.weight`` tensor."""
            weight = weights.get_tensor(f"{prefix}.weight")
            return cls(weight, eps)

        def forward(self, hidden_states, residual=None):
            """Add ``residual`` (if given) to ``hidden_states`` and RMS-normalize.

            Returns a tuple ``(normed_hidden_states, residual)`` where
            ``residual`` is the pre-normalization sum.

            Raises:
                ValueError: the hidden size is at most 8192 and the system is
                    neither CUDA nor ROCm.
            """
            if hidden_states.shape[-1] > 8192:
                # Plain torch path for hidden sizes above 8192.
                # Note that the add mutates `hidden_states`.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                # The statistics are computed in float32.
                hidden_states = hidden_states.to(torch.float32)
                variance = hidden_states.pow(2).mean(-1, keepdim=True)
                hidden_states = hidden_states * torch.rsqrt(
                    variance + self.variance_epsilon
                )

                # convert into half-precision if necessary
                if self.weight.dtype in [torch.float16, torch.bfloat16]:
                    hidden_states = hidden_states.to(self.weight.dtype)

                return self.weight * hidden_states, residual
            elif IS_CUDA_SYSTEM:
                # faster post attention rms norm
                (
                    normed_hidden_states,
                    res,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    None,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.variance_epsilon,
                    1.0,
                    0,
                    None,
                    False,
                    True,  # Activate RMSNorm
                )
                # The kernel hands back no residual when none was passed in.
                if res is None:
                    res = hidden_states

                return normed_hidden_states, res
            elif IS_ROCM_SYSTEM:
                # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                out = torch.empty_like(hidden_states)
                layernorm_ops.rms_norm(
                    out,
                    hidden_states,
                    self.weight.data,
                    self.variance_epsilon,
                )
                return out, residual
            else:
                raise ValueError(
                    "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
                )
OlivierDehaene's avatar
OlivierDehaene committed
649

650
651
652
653
except ImportError:
    pass

try:
fxmarty's avatar
fxmarty committed
654
655
656
657
658
    if IS_CUDA_SYSTEM:
        from flash_attn.layers.rotary import RotaryEmbedding
        import rotary_emb
    elif IS_ROCM_SYSTEM:
        from vllm import pos_encoding_ops
659

Nicolas Patry's avatar
Nicolas Patry committed
660
661
    def _create_inv_freq(dim, base, device):
        """Return the rotary inverse frequencies ``1 / base ** (i / dim)``.

        ``i`` runs over ``0, 2, 4, ...`` below ``dim``; the result is a float32
        tensor created on ``device``.
        """
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        return inv_freq

    def _get_rope_config(config):
        """Return the rope scaling settings, letting the environment override ``config``.

        When the ``ROPE_SCALING`` environment variable is set, a dict with keys
        ``"type"`` (its value) and ``"factor"`` (``ROPE_FACTOR`` parsed as float)
        is returned; ``ROPE_FACTOR`` must then be set as well. Otherwise
        ``config.rope_scaling`` is returned, or ``None`` when ``config`` has no
        such attribute.
        """
        if os.getenv("ROPE_SCALING", None) is not None:
            rope_scaling = {
                "type": os.environ["ROPE_SCALING"],
                "factor": float(os.environ["ROPE_FACTOR"]),
            }
            return rope_scaling
        return getattr(config, "rope_scaling", None)

675
    class PositionRotaryEmbedding(nn.Module):
Nicolas Patry's avatar
Nicolas Patry committed
676
        def __init__(self, inv_freq, scaling_factor):
677
            super().__init__()
678
            self.inv_freq = inv_freq
679
680
681
682
683
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
Nicolas Patry's avatar
Nicolas Patry committed
684
685
            self.scaling_factor = scaling_factor
            self.dynamic_args = None
686

OlivierDehaene's avatar
OlivierDehaene committed
687
688
689
690
691
692
693
        def forward(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            cos: torch.Tensor,
            sin: torch.Tensor,
        ):
fxmarty's avatar
fxmarty committed
694
695
696
697
            # Such controlflows may add some overhead.
            if IS_CUDA_SYSTEM:
                rotary_dim = cos.shape[-1]
                q1 = query[..., :rotary_dim]
OlivierDehaene's avatar
OlivierDehaene committed
698
                q2 = query[..., rotary_dim : 2 * rotary_dim]
fxmarty's avatar
fxmarty committed
699
700
701
702

                rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)

                k1 = key[..., :rotary_dim]
OlivierDehaene's avatar
OlivierDehaene committed
703
                k2 = key[..., rotary_dim : 2 * rotary_dim]
fxmarty's avatar
fxmarty committed
704
705
706
707
708
709
710
711
712

                rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
            elif IS_ROCM_SYSTEM:
                # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
                # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

                head_size = query.shape[-1]

                # Inplace operation, updating query and key.
OlivierDehaene's avatar
OlivierDehaene committed
713
                pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
fxmarty's avatar
fxmarty committed
714
            else:
OlivierDehaene's avatar
OlivierDehaene committed
715
                raise ValueError(
OlivierDehaene's avatar
OlivierDehaene committed
716
717
                    "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
                )
fxmarty's avatar
fxmarty committed
718

719
        @classmethod
Nicolas Patry's avatar
Nicolas Patry committed
720
721
722
723
724
725
726
727
728
        def static(cls, config, dim, base, device):
            inv_freq = _create_inv_freq(dim, base, device)
            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
OlivierDehaene's avatar
OlivierDehaene committed
729
730
731
732
733
734
735
                    return DynamicPositionRotaryEmbedding(
                        dim=dim,
                        max_position_embeddings=config.max_position_embeddings,
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
Nicolas Patry's avatar
Nicolas Patry committed
736
737
738
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
OlivierDehaene's avatar
OlivierDehaene committed
739
740
741
                        max_position_embeddings=rope_scaling[
                            "original_max_position_embeddings"
                        ],
Nicolas Patry's avatar
Nicolas Patry committed
742
743
744
745
746
747
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
OlivierDehaene's avatar
OlivierDehaene committed
748
                        beta_slow=1,
Nicolas Patry's avatar
Nicolas Patry committed
749
                    )
Nicolas Patry's avatar
Nicolas Patry committed
750
                else:
OlivierDehaene's avatar
OlivierDehaene committed
751
752
753
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
Nicolas Patry's avatar
Nicolas Patry committed
754
            return cls(inv_freq, scaling_factor)
755
756

        @classmethod
        def load(cls, config, prefix, weights):
            """
            Build a rotary embedding from the `inv_freq` tensor of a checkpoint.

            The tensor named `{prefix}.inv_freq` is read with `weights.dtype`
            temporarily forced to float32; the previous dtype is restored
            afterwards, including when the lookup raises.

            When `_get_rope_config(config)` returns a non-None mapping, its
            "type" entry selects the implementation:
              - "linear": an instance of `cls` using the configured "factor",
              - "dynamic": a `DynamicPositionRotaryEmbedding`,
              - "yarn": a `YarnPositionRotaryEmbedding`,
              - anything else raises `NotImplementedError`.
            Without a rope scaling mapping, `cls(inv_freq, None)` is returned.
            """
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            try:
                inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            finally:
                # Restore the caller's dtype even if the lookup fails, so a
                # failed load cannot leave `weights` stuck in float32.
                weights.dtype = dtype

            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    # Linear scaling is handled by this class through
                    # `scaling_factor`; fall through to the final return.
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=config.max_position_embeddings,
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling[
                            "original_max_position_embeddings"
                        ],
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)
797

798
799
800
801
        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if (
OlivierDehaene's avatar
OlivierDehaene committed
802
803
804
                seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
805
806
807
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
Nicolas Patry's avatar
Nicolas Patry committed
808
809
                if self.scaling_factor is not None:
                    t /= self.scaling_factor
810
811
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
Nicolas Patry's avatar
Nicolas Patry committed
812

813
814
815
816
817
                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """
            Return cos and sin for the asked position ids

            The cache is first refreshed for `max_s` positions on the device of
            `position_ids`; the rows at `position_ids` are then gathered from
            the cos and sin tables and returned with an extra axis inserted at
            index 1.
            """
            # For RoCm, we always use float cos/sin to avoid a cast.
            # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
            # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
            table_dtype = torch.float32 if IS_ROCM_SYSTEM else dtype

            self._update_cos_sin_cache(table_dtype, position_ids.device, max_s)

            # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
            cos, sin = (
                torch.index_select(table, 0, position_ids).unsqueeze(1)
                for table in (self._cos_cached, self._sin_cached)
            )
            return cos, sin

Nicolas Patry's avatar
Nicolas Patry committed
836
837
    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        """
        Rotary embedding whose base is rescaled for long sequences.

        When the cache is rebuilt for a `seqlen` larger than
        `max_position_embeddings`, `inv_freq` is recomputed from a base derived
        from `base`, `scaling_factor`, `seqlen` and `dim`.
        """

        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            super().__init__(_create_inv_freq(dim, base, device), scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            """
            Rebuild the cos/sin tables unless the cached ones already cover
            `seqlen` with the requested device and dtype.
            """
            reusable = (
                seqlen <= self._seq_len_cached
                and self._cos_cached.device == device
                and self._cos_cached.dtype == dtype
            )
            if reusable:
                return

            if seqlen > self.max_position_embeddings:
                # Past the configured maximum: derive a larger base and
                # regenerate the inverse frequencies from it.
                stretch = (
                    self.scaling_factor * seqlen / self.max_position_embeddings
                ) - (self.scaling_factor - 1)
                exponent = self.dim / (self.dim - 2)
                newbase = self.base * stretch**exponent
                self.inv_freq = _create_inv_freq(
                    self.dim, newbase, self.inv_freq.device
                )

            self._seq_len_cached = seqlen
            positions = torch.arange(
                seqlen, device=device, dtype=self.inv_freq.dtype
            )
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            angles = torch.outer(
                positions, self.inv_freq.to(device=positions.device)
            )
            self._cos_cached = angles.cos().to(dtype)
            self._sin_cached = angles.sin().to(dtype)

Nicolas Patry's avatar
Nicolas Patry committed
869
870
    # Inverse dim formula to find dim based on number of rotations
    import math
OlivierDehaene's avatar
OlivierDehaene committed
871

OlivierDehaene's avatar
OlivierDehaene committed
872
873
874
875
876
877
    def find_correction_dim(
        num_rotations, dim, base=10000, max_position_embeddings=2048
    ):
        """
        Inverse dim formula: find the dimension for a given number of rotations.

        Returns `dim * ln(max_position_embeddings / (2 * pi * num_rotations))`
        divided by `2 * ln(base)`, as a float.
        """
        ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
        numerator = dim * math.log(ratio)
        denominator = 2 * math.log(base)
        return numerator / denominator
Nicolas Patry's avatar
Nicolas Patry committed
878
879

    # Find dim range bounds based on rotations
OlivierDehaene's avatar
OlivierDehaene committed
880
881
882
883
884
885
886
887
888
    def find_correction_range(
        low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
    ):
        """
        Find the dimension range bounds for a pair of rotation counts.

        The correction dimension of `low_rot` is rounded down and that of
        `high_rot` is rounded up; the pair is then clamped to `[0, dim - 1]`
        and returned as `(low, high)`.
        """
        raw_low, raw_high = (
            find_correction_dim(rotations, dim, base, max_position_embeddings)
            for rotations in (low_rot, high_rot)
        )
        lower_bound = math.floor(raw_low)
        upper_bound = math.ceil(raw_high)
        # Clamp values just in case
        return max(lower_bound, 0), min(upper_bound, dim - 1)

Nicolas Patry's avatar
Nicolas Patry committed
891
892
893
894
895
896
897
898
899
900
901
902
903
904
    def linear_ramp_mask(min, max, dim):
        """
        Return a float32 tensor of length `dim` ramping linearly from 0 to 1.

        Entry `i` is `(i - min) / (max - min)` clamped to `[0, 1]`. When
        `min == max`, `max` is nudged by 0.001 so the division is defined.
        """
        if min == max:
            max += 0.001  # Prevent singularity

        span = max - min
        indices = torch.arange(dim, dtype=torch.float32)
        ramp = (indices - min) / span
        return ramp.clamp(0, 1)

    def get_mscale(scale=1):
        """
        Return the magnitude scale for a given scaling factor.

        Scales of 1 or less give 1.0; larger scales give
        `0.1 * ln(scale) + 1.0`.
        """
        return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0

    class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
        """
        Rotary embedding that blends interpolated and extrapolated frequencies.

        When the cache is rebuilt for a `seqlen` larger than
        `max_position_embeddings`, `inv_freq` becomes a per-dimension mix of
        the unscaled inverse frequencies and those divided by
        `scaling_factor`. The cached cos/sin tables are always multiplied by
        `mscale`.
        """

        def __init__(
            self,
            dim,
            max_position_embeddings,
            base,
            device,
            scaling_factor,
            *,
            extrapolation_factor,
            attn_factor,
            beta_fast,
            beta_slow,
        ):
            super().__init__(_create_inv_freq(dim, base, device), scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base
            self.extrapolation_factor = extrapolation_factor
            self.attn_factor = attn_factor
            self.beta_fast = beta_fast
            self.beta_slow = beta_slow
            # Get n-d magnitude scaling corrected for interpolation
            self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            """
            Rebuild the cos/sin tables unless the cached ones already cover
            `seqlen` with the requested device and dtype.
            """
            cache_hit = (
                seqlen <= self._seq_len_cached
                and self._cos_cached.device == device
                and self._cos_cached.dtype == dtype
            )
            if cache_hit:
                return

            if seqlen > self.max_position_embeddings:
                extrapolated = _create_inv_freq(
                    self.dim, self.base, self.inv_freq.device
                )
                reciprocal = 1.0 / extrapolated
                interpolated = 1.0 / (self.scaling_factor * reciprocal)

                low, high = find_correction_range(
                    self.beta_fast,
                    self.beta_slow,
                    self.dim,
                    self.base,
                    self.max_position_embeddings,
                )
                ramp = linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                # Get n-d rotational scaling corrected for extrapolation
                blend = (1 - ramp) * self.extrapolation_factor

                self.inv_freq = interpolated * (1 - blend) + extrapolated * blend
                # Get n-d magnitude scaling corrected for interpolation
                self.mscale = float(
                    get_mscale(self.scaling_factor) * self.attn_factor
                )

            self._seq_len_cached = seqlen
            positions = torch.arange(
                seqlen, device=device, dtype=self.inv_freq.dtype
            )
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            angles = torch.outer(
                positions, self.inv_freq.to(device=positions.device)
            )
            self._cos_cached = (angles.cos() * self.mscale).to(dtype)
            self._sin_cached = (angles.sin() * self.mscale).to(dtype)

975
976
except ImportError:
    pass