import os
import warnings
from typing import List

import torch
import torch.distributed
from torch import nn
from torch.nn import functional as F

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params, Params4bit
except ImportError:
    HAS_BITS_AND_BYTES = False

from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear

21
# Probe the CUDA compute capability once at import time. Any exception from the
# probe (e.g. when no CUDA device is usable) is treated as major version 1, which
# keeps the exllama kernels switched off below.
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1

# exllama is only attempted on compute capability >= 8, and can be disabled
# explicitly with DISABLE_EXLLAMA=True. HAS_EXLLAMA is True only when the
# import of Ex4bitLinear actually succeeded.
CAN_EXLLAMA = major >= 8
HAS_EXLLAMA = False
if CAN_EXLLAMA and os.getenv("DISABLE_EXLLAMA") != "True":
    try:
        from text_generation_server.utils.gptq.exllama import Ex4bitLinear
    except ImportError:
        pass
    else:
        HAS_EXLLAMA = True
35

36
from typing import Optional
37
38
39
40
41
42
43
44
45
46
47
48
49
50

# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Build a ``cls`` layer norm whose affine parameters come from ``weights``.

    The module is instantiated under accelerate's ``init_empty_weights`` and its
    ``weight``/``bias`` are then replaced by the ``{prefix}.weight`` and
    ``{prefix}.bias`` tensors. Installed below as ``torch.nn.LayerNorm.load``.
    """
    gamma = weights.get_tensor(f"{prefix}.weight")
    beta = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        layer = cls(gamma.shape, eps=eps)

    layer.weight, layer.bias = nn.Parameter(gamma), nn.Parameter(beta)
    return layer


51
52
53
54
55
56
57
58
59
60
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Build a bias-free ``cls`` layer norm from ``{prefix}.weight``.

    Same as ``load_layer_norm`` except that ``bias`` is set to None. Installed
    below as ``torch.nn.LayerNorm.load_no_bias``.
    """
    gamma = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        layer = cls(gamma.shape, eps=eps)

    layer.weight, layer.bias = nn.Parameter(gamma), None
    return layer

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Build a ``cls`` convolution from ``{prefix}.weight`` and ``{prefix}.bias``.

    Only the channel counts, kernel size and stride are passed to the
    constructor; every other ``cls`` option keeps its default. Installed below
    as ``torch.nn.Conv2d.load``.
    """
    kernel = weights.get_tensor(f"{prefix}.weight")
    offset = weights.get_tensor(f"{prefix}.bias")
    ctor_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )
    with init_empty_weights():
        conv2d = cls(**ctor_kwargs)

    conv2d.weight = nn.Parameter(kernel)
    conv2d.bias = nn.Parameter(offset)
    return conv2d


@classmethod
def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Build a bias-free ``cls`` convolution from ``{prefix}.weight``.

    Same constructor arguments as ``load_conv2d``; ``bias`` is set to None.
    Installed below as ``torch.nn.Conv2d.load_no_bias``.
    """
    kernel = weights.get_tensor(f"{prefix}.weight")
    ctor_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )
    with init_empty_weights():
        conv2d = cls(**ctor_kwargs)

    conv2d.weight, conv2d.bias = nn.Parameter(kernel), None
    return conv2d

83

84
85
# Attach the loaders above to the torch classes as alternate constructors
# (``Conv2d.load``, ``Conv2d.load_no_bias``, ``LayerNorm.load``,
# ``LayerNorm.load_no_bias``).
for _target, _patches in (
    (torch.nn.Conv2d, (("load", load_conv2d), ("load_no_bias", load_conv2d_no_bias))),
    (torch.nn.LayerNorm, (("load", load_layer_norm), ("load_no_bias", load_layer_norm_no_bias))),
):
    for _attr, _loader in _patches:
        setattr(_target, _attr, _loader)
del _target, _patches, _attr, _loader
88

89
90

class FastLinear(nn.Module):
    """Unquantized dense layer computing ``F.linear(input, weight, bias)``.

    Args:
        weight: weight tensor, wrapped in ``nn.Parameter``.
        bias: optional bias tensor; wrapped in ``nn.Parameter`` when not None.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.bias = None if bias is None else nn.Parameter(bias)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Read ``{prefix}.weight`` (and ``{prefix}.bias`` when ``bias`` is truthy)."""
        weight = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)
114
115


116
class Linear8bitLt(nn.Module):
    """Int8 linear layer evaluated through ``bnb.matmul`` with a ``bnb.MatmulLtState``.

    Args:
        weight: dense weight; its ``.data`` is wrapped in ``Int8Params`` and the
            result is moved with ``.cuda(weight.device)``.
        bias: optional bias, stored as-is and cast to the activation dtype
            lazily in ``forward``.
        has_fp16_weights: forwarded to ``Int8Params`` and to the matmul state;
            also used as the weight's ``requires_grad``.
        memory_efficient_backward: deprecated, must be False (asserted).
        threshold: stored on ``state.threshold``; when > 0 and
            ``has_fp16_weights`` is False, ``state.use_pool`` is enabled too.
        index: stored on ``self.index``; not used inside this class.
    """

    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"

        state = bnb.MatmulLtState()
        # Necessary for stacked layers
        state.threshold = threshold
        state.has_fp16_weights = has_fp16_weights
        state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            state.use_pool = True
        self.state = state
        self.index = index

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        """Move ``CB``/``SCB`` from the weight onto the matmul state and clear them on the weight."""
        weight, state = self.weight, self.state
        state.CB, state.SCB = weight.CB, weight.SCB
        weight.CB = weight.SCB = None

    def forward(self, x: torch.Tensor):
        state = self.state
        state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # Only the bias is cast here; the weight is handed to bnb.matmul as-is.
        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            bias.data = bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=bias, state=state)

        # After the first inference pass the state may hold both the row-major
        # (CB) and the converted (CxB) weights; keep only CxB.
        drop_row_major = (
            not state.has_fp16_weights
            and state.CB is not None
            and state.CxB is not None
        )
        if drop_row_major:
            del state.CB
            self.weight.data = state.CxB
        return out
class Linear4bit(nn.Module):
    """4-bit linear layer backed by bitsandbytes ``Params4bit`` and ``bnb.matmul_4bit``.

    Args:
        weight: dense weight; its ``.data`` is wrapped in ``Params4bit`` and the
            result is moved with ``.cuda(weight.device)``.
        bias: optional bias tensor, stored as-is.
        quant_type: forwarded to ``Params4bit`` (``get_linear`` passes ``"fp4"``
            or ``"nf4"``).
    """

    def __init__(self, weight, bias, quant_type):
        super().__init__()
        self.weight = Params4bit(
            weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
        )
        # Never changed in this class; while None, forward() keeps the
        # activations in their input dtype.
        self.compute_dtype = None
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # Only the bias is cast here; the packed 4-bit weight is passed to
        # bnb.matmul_4bit as-is.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            # warnings.warn rather than print: reported on stderr and, with the
            # default filter, once per call site instead of on every forward pass.
            warnings.warn(
                "4-bit quantization state not initialized. Please call .cuda() or "
                ".to(device) on the Linear4bit layer first."
            )
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )

        return out.to(inp_dtype)


207
208
209
210
211
212
213
214
215
216
217
218
def get_linear(weight, bias, quantize):
    """Build the linear layer matching the ``quantize`` mode.

    Args:
        weight: a dense tensor, except for ``quantize == "gptq"`` where it must
            be the 7-sequence ``(qweight, qzeros, scales, g_idx, bits,
            groupsize, use_exllama)``.
        bias: optional bias tensor.
        quantize: ``None``, ``"bitsandbytes"``, ``"bitsandbytes-fp4"``,
            ``"bitsandbytes-nf4"`` or ``"gptq"``.

    Returns:
        The module implementing the (possibly quantized) projection.

    Raises:
        ImportError: for a bitsandbytes mode when bitsandbytes failed to import.
        NotImplementedError: for an unknown ``quantize`` value, or when
            ``quantize == "gptq"`` and ``weight`` is not a GPTQ 7-sequence.
    """
    if quantize is None:
        return FastLinear(weight, bias)

    if quantize in ("bitsandbytes", "bitsandbytes-fp4", "bitsandbytes-nf4"):
        # Without this check the missing `bnb` / `Int8Params` / `Params4bit`
        # names would surface later as a confusing NameError.
        if not HAS_BITS_AND_BYTES:
            raise ImportError(
                f"Quantization `{quantize}` requires the `bitsandbytes` package, "
                "which could not be imported."
            )
        if quantize == "bitsandbytes":
            linear = Linear8bitLt(
                weight,
                bias,
                has_fp16_weights=False,
                threshold=6.0,
            )
            if bias is not None:
                # Wrapping in nn.Parameter makes nn.Module register the bias.
                linear.bias = nn.Parameter(bias)
            return linear
        quant_type = "fp4" if quantize == "bitsandbytes-fp4" else "nf4"
        return Linear4bit(weight, bias, quant_type=quant_type)

    if quantize == "gptq":
        not_gptq = "The passed weight is not `gptq` compatible, loader needs to be updated."
        # A bare tensor would be unpacked along its first dimension (and would
        # "succeed" if it happened to have exactly 7 rows), so reject it up front.
        if isinstance(weight, torch.Tensor):
            raise NotImplementedError(not_gptq)
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        except (TypeError, ValueError) as err:
            # TypeError: not iterable; ValueError: wrong number of elements.
            raise NotImplementedError(not_gptq) from err

        if use_exllama:
            # NOTE(review): Ex4bitLinear only exists when the exllama import at
            # module load succeeded (HAS_EXLLAMA); presumably the loader only
            # sets use_exllama in that case — confirm.
            return Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
        return QuantLinear(
            qweight,
            qzeros,
            scales,
            g_idx,
            bias,
            bits,
            groupsize,
        )

    raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")


class SuperLayer(nn.Module):
    """Thin wrapper that owns a ``linear`` module and delegates ``forward`` to it."""

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        inner = self.linear
        return inner.forward(x)


class TensorParallelHead(SuperLayer):
    """Output projection whose weight may be split along dim 0 across ranks.

    When ``should_gather`` is True, each rank computes its slice of the output
    features. ``forward`` then all-gathers the slices into the full output along
    the last dimension. Otherwise the wrapped linear layer is applied as-is.
    """

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        """Load ``{prefix}.weight``, sharded along dim 0 when possible.

        With more than one rank, ``weights.get_sharded`` is tried first. If it
        raises ``AssertionError`` (e.g. the vocab size is not divisible by the
        number of shards), the full tensor is loaded and no gathering is done.
        GPTQ does not quantize heads, so the head is built unquantized in that case.
        """
        tensor_name = f"{prefix}.weight"
        should_gather = False
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(tensor_name, dim=0)
            except AssertionError:
                pass
            else:
                should_gather = True
        if not should_gather:
            weight = weights.get_tensor(tensor_name)

        quantize = None if config.quantize == "gptq" else config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        group = self.process_group
        world_size = group.size()

        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            # Fast path: the local matmul writes straight into the buffer read by
            # all_gather_into_tensor, so no separate concatenation is needed.
            out_dim = self.linear.weight.shape[0]
            n_rows = input.shape[0]
            if n_rows == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # all_gather_into_tensor concatenates along dim 0, so gather the
                # transposed (out_dim, n_rows) results and transpose back at the end.
                world_out = input.new_empty(out_dim * world_size, n_rows)
                gather_input = input.new_empty(out_dim, n_rows)
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)
            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=group
            )
            return world_out if n_rows == 1 else world_out.T

        output = super().forward(input)
        gathered = [torch.empty_like(output) for _ in range(world_size)]
        torch.distributed.all_gather(gathered, output, group=group)
        return torch.cat(gathered, dim=-1)


class TensorParallelColumnLinear(SuperLayer):
    """Linear layer built from ``weights.get_multi_weights_col`` (column-parallel).

    The weight slicing itself happens in the ``weights`` loader, which is not
    visible here.
    """

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load a single projection from ``prefix``."""
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """Load several projections and fuse them into one layer.

        Args:
            prefixes: checkpoint prefixes whose weights are fused, in order.
            dim: dimension along which ``get_multi_weights_col`` fuses the weights.
            bias: when truthy, each ``{p}.bias`` is loaded sharded along dim 0
                and the shards are concatenated.
        """
        weight = weights.get_multi_weights_col(
            prefixes, quantize=config.quantize, dim=dim
        )

        if bias:
            shards = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            # The biases are sharded along dim 0, so concatenate along dim 0 too.
            # `dim` describes how the weight matrices are fused; for 1-D biases any
            # other value would make torch.cat raise.
            bias = torch.cat(shards, dim=0)
        else:
            bias = None

        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)
351

352
353
354
355

class TensorParallelRowLinear(SuperLayer):
    """Linear layer built from ``weights.get_multi_weights_row`` (row-parallel).

    Each rank's output is summed across ``process_group`` with an all-reduce.
    The weight slicing happens in the loader (presumably along the input
    dimension; not visible here).
    """

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
        # The bias is loaded on rank 0 only, so the all-reduce in forward()
        # adds it exactly once.
        load_bias = bias and weights.process_group.rank() == 0
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if load_bias else None
        return cls(
            get_linear(weight, bias_tensor, config.quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        partial = super().forward(input)
        group = self.process_group
        if group.size() <= 1:
            return partial
        torch.distributed.all_reduce(partial, group=group)
        return partial
377
378


379
380
381
class TensorParallelEmbedding(nn.Module):
    """Embedding table split along the vocabulary dimension (dim 0) across ranks.

    Each rank looks up ids in ``[min_id, max_id)`` from its local shard. All
    other ids are mapped to an extra all-zero row. With ``reduce=True`` (and
    more than one rank) the per-rank results are summed with an all-reduce.

    Args:
        prefix: checkpoint prefix; ``{prefix}.weight`` is loaded.
        weights: loader providing ``get_partial_sharded``, ``get_shape`` and
            ``process_group``.
        reduce: whether ``forward`` all-reduces the output.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        # NOTE(review): assumes get_partial_sharded splits the table into
        # num_embeddings // world_size row blocks — confirm against the loader.
        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        # The masking row is appended right after the local shard, so its index
        # is the shard's row count. That equals block_size only when the shard
        # has exactly block_size rows; otherwise block_size would point at a real
        # embedding row (or past the end).
        self.null_idx = weight.shape[0]
        self.process_group = weights.process_group
        self.reduce = reduce

        # Additional all-zero row used for masking.
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
        # translate for [0, self.max_id - self.min_id[
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


try:
    # Optional fused kernel extension; if it cannot be imported, FastLayerNorm
    # is simply not defined by this module.
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm with an optional residual add, returning ``(normed, residual)``."""

        def forward(self, hidden_states, residual=None):
            """Normalize ``hidden_states`` (plus ``residual`` when given).

            Args:
                hidden_states: input activations.
                residual: optional tensor added to ``hidden_states`` before
                    normalization.

            Returns:
                ``(normalized, residual)``. In the eager path the returned
                residual is the pre-normalization sum (or ``hidden_states`` when
                no residual was passed). In the fused path it is whatever the
                kernel returns, falling back to ``hidden_states`` when that is None.
            """
            # NOTE(review): presumably the fused kernel does not support hidden
            # sizes above 8192, hence the eager nn.LayerNorm fallback — confirm.
            if hidden_states.shape[-1] > 8192:
                if residual is not None:
                    # In-place add: this mutates the caller's hidden_states tensor.
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                # Positional arguments follow dropout_layer_norm.dropout_add_ln_fwd's
                # signature, which is not visible here. 0.0 is presumably the
                # dropout probability and self.eps the epsilon — verify against
                # the extension before changing any slot.
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual

except ImportError:
    pass


try:
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb

    def _create_inv_freq(dim, base, device):
        """Return float32 rotary inverse frequencies ``1 / base ** (k / dim)`` for k = 0, 2, 4, ... < dim."""
        inv_freq = 1.0 / (
            base
            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        return inv_freq

    def _get_rope_config(config):
        """Return the rope scaling dict (keys ``"type"`` and ``"factor"``) or None.

        The ``ROPE_SCALING`` / ``ROPE_FACTOR`` environment variables take
        precedence over ``config.rope_scaling``.

        Raises:
            ValueError: if ``ROPE_SCALING`` is set but ``ROPE_FACTOR`` is missing
                or is not a valid float.
        """
        scaling_type = os.environ.get("ROPE_SCALING")
        if scaling_type is not None:
            factor = os.environ.get("ROPE_FACTOR")
            if factor is None:
                # Previously a bare KeyError('ROPE_FACTOR'); make the misconfiguration explicit.
                raise ValueError("ROPE_FACTOR must be set when ROPE_SCALING is set.")
            return {"type": scaling_type, "factor": float(factor)}
        return getattr(config, "rope_scaling", None)
    class PositionRotaryEmbedding(nn.Module):
        """Rotary position embedding applied in place with the ``rotary_emb`` kernel.

        cos/sin tables are cached. They are rebuilt when a longer sequence, a
        different device or a different dtype is requested. When
        ``scaling_factor`` is not None, positions are divided by it before
        computing the tables (the "linear" rope scaling type).
        """

        def __init__(self, inv_freq, scaling_factor):
            super().__init__()
            self.inv_freq = inv_freq
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
            self.scaling_factor = scaling_factor
            self.dynamic_args = None

        @classmethod
        def _from_rope_config(cls, config, inv_freq, dim, base):
            """Wrap ``inv_freq`` according to the configured rope scaling.

            Returns ``cls(inv_freq, None)`` without scaling, ``cls(inv_freq,
            factor)`` for "linear", and a ``DynamicPositionRotaryEmbedding``
            (built from ``dim``/``base``) for "dynamic".

            Raises:
                NotImplementedError: for any other scaling type.
            """
            rope_scaling = _get_rope_config(config)
            if rope_scaling is None:
                return cls(inv_freq, None)
            scaling_factor = rope_scaling["factor"]
            if rope_scaling["type"] == "linear":
                return cls(inv_freq, scaling_factor)
            if rope_scaling["type"] == "dynamic":
                return DynamicPositionRotaryEmbedding(
                    dim=dim,
                    max_position_embeddings=config.max_position_embeddings,
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            raise NotImplementedError(
                f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
            )

        @classmethod
        def static(cls, config, dim, base, device):
            """Build the embedding from ``dim`` and ``base`` (no checkpoint tensor)."""
            inv_freq = _create_inv_freq(dim, base, device)
            return cls._from_rope_config(config, inv_freq, dim=dim, base=base)

        @classmethod
        def load(cls, config, prefix, weights):
            """Build the embedding from the checkpoint's ``{prefix}.inv_freq``."""
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            try:
                inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            finally:
                # Restore the loader's dtype even if the tensor lookup fails.
                weights.dtype = dtype

            # NOTE(review): the dynamic variant rebuilds inv_freq with base=10000.0;
            # the base that produced the checkpoint's inv_freq is not known here.
            return cls._from_rope_config(
                config, inv_freq, dim=2 * inv_freq.shape[0], base=10000.0
            )

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild the tables when nothing is cached yet, when the sequence
            # grew, or when we're on a new device/dtype (possibly due to tracing).
            # The explicit None check avoids dereferencing the empty cache when
            # the first request has seqlen <= 0.
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                if self.scaling_factor is not None:
                    t /= self.scaling_factor
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """
            Return cos and sin for the asked position ids
            """

            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            return cos.unsqueeze(1), sin.unsqueeze(1)

        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            """Rotate the first ``2 * cos.shape[-1]`` features of ``x`` in place and return ``x``."""
            rotary_dim = cos.shape[-1]
            x1 = x[..., :rotary_dim]
            x2 = x[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
            return x
    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        """Rotary embedding for the "dynamic" rope scaling type.

        Positions are not divided by ``scaling_factor``. Instead, once a sequence
        longer than ``max_position_embeddings`` is requested, ``inv_freq`` is
        recomputed from an enlarged base derived from ``scaling_factor``.
        """

        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild the tables when nothing is cached yet, when the sequence
            # grew, or when we're on a new device/dtype (possibly due to tracing).
            # The explicit None check avoids dereferencing the empty cache when
            # the first request has seqlen <= 0.
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    newbase = self.base * (
                        (self.scaling_factor * seqlen / self.max_position_embeddings)
                        - (self.scaling_factor - 1)
                    ) ** (self.dim / (self.dim - 2))
                    self.inv_freq = _create_inv_freq(
                        self.dim, newbase, self.inv_freq.device
                    )
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
except ImportError:
    pass