layers.py 13.1 KB
Newer Older
1
import torch
2
import torch.distributed
3
4

from torch import nn
5
from torch.nn import functional as F
6
from typing import List
7
8
9

HAS_BITS_AND_BYTES = True
try:
10
11
12
13
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params

except ImportError:
14
15
    HAS_BITS_AND_BYTES = False

16
17
from accelerate import init_empty_weights

18
19
from text_generation_server.utils.gptq.quant_linear import QuantLinear

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

# Monkey patching: give torch.nn.LayerNorm (and its subclasses) a `load`
# alternate constructor that builds the module from checkpoint tensors.
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Create a `cls` LayerNorm from `{prefix}.weight` and `{prefix}.bias`.

    The module is instantiated inside accelerate's `init_empty_weights()`
    context, and the loaded tensors are then attached as its parameters.

    Args:
        prefix: checkpoint key prefix of the layer norm.
        weights: loader exposing `get_tensor(name)`.
        eps: epsilon forwarded to the LayerNorm constructor.

    Returns:
        The constructed LayerNorm instance.
    """
    gamma = weights.get_tensor(f"{prefix}.weight")
    beta = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        layer_norm = cls(gamma.shape, eps=eps)

    layer_norm.weight, layer_norm.bias = nn.Parameter(gamma), nn.Parameter(beta)
    return layer_norm


# A module-level @classmethod object binds `cls` once attached to the class.
torch.nn.LayerNorm.load = load_layer_norm
35

36
37

class FastLinear(nn.Module):
    """Dense layer computing ``F.linear(input, weight, bias)``.

    Unlike ``nn.Linear`` it is built directly from already-loaded tensors
    instead of allocating and initialising its own parameters.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        # `bias` may be None, in which case the layer has no additive term.
        self.bias = None if bias is None else nn.Parameter(bias)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load `{prefix}.weight`, plus `{prefix}.bias` when `bias` is truthy.

        `config` is accepted for signature parity with the other loaders in
        this module but is not used.
        """
        weight = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)
61
62


63
class Linear8bitLt(nn.Module):
    """Int8 linear layer backed by bitsandbytes' ``bnb.matmul``.

    The incoming weight is wrapped in ``Int8Params`` and moved with
    ``.cuda(weight.device)``. The int8 buffers (``CB``/``SCB``) found on the
    weight are handed to ``self.state`` on the first forward pass.
    """

    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Configure the matmul state (necessary for stacked layers).
        state = self.state
        state.threshold = threshold
        state.has_fp16_weights = has_fp16_weights
        state.memory_efficient_backward = memory_efficient_backward
        # use_pool is only switched on for a positive threshold with int8 weights.
        if threshold > 0.0 and not has_fp16_weights:
            state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        # NOTE(review): the return value of .cuda() is discarded; this relies on
        # bitsandbytes converting the Int8Params in place (has_fp16_weights=False,
        # as used by get_linear). Confirm behavior for has_fp16_weights=True.
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        """Move the CB/SCB buffers from the weight onto the matmul state."""
        state, weight = self.state, self.weight
        state.CB, state.SCB = weight.CB, weight.SCB
        weight.CB = weight.SCB = None

    def forward(self, x: torch.Tensor):
        state = self.state
        state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # Int8Params takes care of the weight; the bias must be cast manually.
        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            bias.data = bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=state)

        # Once the first inference pass has produced the turing/ampere layout
        # (CxB), the row-major int8 weight (CB) is no longer needed.
        skip_release = (
            state.has_fp16_weights or state.CB is None or state.CxB is None
        )
        if not skip_release:
            del state.CB
            self.weight.data = state.CxB
        return out
119
120


121
122
123
124
125
126
127
128
129
130
131
132
133
def get_linear(weight, bias, quantize):
    """Build the linear module matching the requested quantization scheme.

    Args:
        weight: a dense weight tensor for ``None`` / ``"bitsandbytes"``; for
            ``"gptq"`` a 6-item sequence
            ``(qweight, qzeros, scales, g_idx, bits, groupsize)``.
        bias: optional bias tensor (or None).
        quantize: ``None``, ``"bitsandbytes"`` or ``"gptq"``.

    Returns:
        A ``FastLinear``, ``Linear8bitLt`` or ``QuantLinear`` module.

    Raises:
        ImportError: ``"bitsandbytes"`` was requested but the package could
            not be imported.
        NotImplementedError: unknown ``quantize`` value, or a ``"gptq"``
            weight that cannot be unpacked into the six expected fields
            (the unpacking error is chained as ``__cause__``).
    """
    if quantize is None:
        return FastLinear(weight, bias)

    if quantize == "bitsandbytes":
        if not HAS_BITS_AND_BYTES:
            # Without this check Linear8bitLt fails with an opaque NameError on `bnb`.
            raise ImportError(
                "`bitsandbytes` quantization requested but the `bitsandbytes` package could not be imported."
            )
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        # Linear8bitLt stores the bias as given; register it as a Parameter.
        if bias is not None:
            linear.bias = nn.Parameter(bias)
        return linear

    if quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize = weight
        except (TypeError, ValueError) as err:
            # TypeError: not iterable; ValueError: wrong number of fields.
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            ) from err

        return QuantLinear(
            qweight,
            qzeros,
            scales,
            g_idx,
            bias,
            bits,
            groupsize,
        )

    raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")


class SuperLayer(nn.Module):
    """Thin wrapper delegating ``forward`` to an inner ``linear`` module."""

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        # The inner module's ``forward`` is called directly, not via ``__call__``.
        inner = self.linear
        return inner.forward(x)


class TensorParallelHead(SuperLayer):
    """Output head whose weight is sharded along dim 0 across `process_group`.

    Each rank computes its own slice of the output features; the slices are
    all-gathered so every rank returns the full output.
    """

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights):
        """Load the dim-0 shard of `{prefix}.weight` (no bias) and wrap it.

        A classmethod returning ``cls(...)``, consistent with the other
        loaders in this module.
        """
        weight = weights.get_sharded(f"{prefix}.weight", dim=0)

        # GPTQ doesn't quantize heads (nor embeddings)
        quantize = None if config.quantize == "gptq" else config.quantize
        return cls(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        world_size = self.process_group.size()
        if world_size == 1:
            return super().forward(input)

        # Fast path (2-D input, dense FastLinear): the local matmul is written
        # straight into the buffer that is gathered.
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # Gather contiguous (out_dim, batch) blocks stacked along dim 0;
                # `local_out` is a transposed view of the gather buffer.
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            # (out_dim * world_size, batch) -> (batch, out_dim * world_size)
            return world_out.T

        output = super().forward(input)
        world_output = [torch.empty_like(output) for _ in range(world_size)]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        return torch.cat(world_output, dim=-1)


class TensorParallelColumnLinear(SuperLayer):
    """Linear layer whose output features are sharded across ranks.

    ``forward`` is inherited from ``SuperLayer``: no collective is issued, so
    each rank returns only its own slice of the output.
    """

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """Load several column-sharded projections and fuse them into one layer.

        Args:
            config: model config; only ``config.quantize`` is read.
            prefixes: checkpoint prefixes whose weights are fused.
            weights: loader providing ``get_multi_weights_col`` / ``get_sharded``.
            bias: whether to load and fuse ``{prefix}.bias`` for every prefix.
            dim: concatenation dim forwarded to ``get_multi_weights_col``.
        """
        weight = weights.get_multi_weights_col(
            prefixes, quantize=config.quantize, dim=dim
        )

        if bias:
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            # Bias shards are assumed to be 1-D (one value per output feature),
            # so they are always fused along dim 0; `torch.cat(b, dim=dim)`
            # raised for 1-D tensors whenever dim != 0.
            bias = torch.cat(b, dim=0)
        else:
            bias = None

        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)
237

238
239
240
241

class TensorParallelRowLinear(SuperLayer):
    """Linear layer whose input features are sharded across ranks.

    Each rank computes a partial result; ``forward`` sums the partials with an
    all-reduce when the process group has more than one rank.
    """

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        group = weights.process_group
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        # The bias is loaded on rank 0 only: `forward` all-reduces (sums) the
        # per-rank outputs, so the bias must be contributed exactly once.
        load_bias = bias and group.rank() == 0
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if load_bias else None
        return cls(
            get_linear(weight, bias_tensor, config.quantize),
            process_group=group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() <= 1:
            return out
        torch.distributed.all_reduce(out, group=self.process_group)
        return out
263
264


265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
class TensorParallelEmbedding(nn.Module):
    """Embedding table sharded by rows across `weights.process_group`.

    Rank ``r`` serves ids in ``[r * block_size, (r + 1) * block_size)``. Any
    other id is mapped to an extra all-zero row. With ``reduce=True`` and more
    than one rank, outputs are summed with an all-reduce, so (assuming the
    vocabulary divides evenly) each id ends up with its owner's embedding.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        shard = weights.get_sharded(f"{prefix}.weight", dim=0)
        vocab_size = weights.get_shape(f"{prefix}.weight")[0]

        group = weights.process_group
        world_size, rank = group.size(), group.rank()

        block_size = vocab_size // world_size
        self.min_id = rank * block_size
        self.max_id = min(vocab_size, (rank + 1) * block_size)
        self.null_idx = block_size
        self.process_group = group
        self.reduce = reduce

        # Append one all-zero row (index `null_idx`) used to mask ids this
        # rank does not own.
        self.weight = nn.Parameter(F.pad(shard, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Ids outside [min_id, max_id) go to `null_idx` (the zero row); the
        # rest are shifted to local rows in [0, max_id - min_id).
        outside = (input < self.min_id) | (input >= self.max_id)
        local_ids = torch.where(outside, self.null_idx, input - self.min_id)
        out = F.embedding(local_ids, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


# Optional fused LayerNorm: FastLayerNorm is only defined when the
# `dropout_layer_norm` extension can be imported.
try:
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm that also folds in an optional residual add.

        ``forward(hidden_states, residual=None)`` returns
        ``(normed_hidden_states, residual)``. On the eager path the returned
        ``residual`` is the pre-normalization sum ``hidden_states + residual``
        (or ``hidden_states`` itself when no residual is given); the fused
        path is assumed to return the same quantity — TODO confirm.
        """

        def forward(self, hidden_states, residual=None):
            # NOTE(review): 8192 is presumably the widest hidden size the fused
            # kernel supports; wider inputs use the eager nn.LayerNorm path.
            if hidden_states.shape[-1] > 8192:
                if residual is not None:
                    # In-place add: the caller's `hidden_states` tensor is modified.
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                # Positional arguments presumably follow flash-attn's
                # `dropout_add_ln_fwd(x0, residual, gamma, beta, rowscale,
                # colscale, x0_subset, z_subset, dropout_p, epsilon,
                # rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)`:
                # dropout disabled (0.0), no row/column scaling, plain LayerNorm.
                # TODO confirm against the installed extension version.
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                # Assumed: the kernel's second output is None when no residual
                # was passed; fall back to the input itself in that case.
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual

except ImportError:
    # Extension unavailable: FastLayerNorm is simply not defined.
    pass


# Optional rotary embedding: PositionRotaryEmbedding is only defined when
# flash-attn's rotary layer and the `rotary_emb` extension can be imported.
try:
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb

    class PositionRotaryEmbedding(nn.Module):
        """Rotary position embedding applied in place with ``rotary_emb``.

        cos/sin tables are cached and rebuilt when a longer sequence, or a
        different device/dtype, is requested.
        """

        def __init__(self, inv_freq):
            super().__init__()

            self.register_buffer("inv_freq", inv_freq)
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None

        @classmethod
        def static(cls, dim, base, device):
            """Build from frequencies ``1 / base ** (arange(0, dim, 2) / dim)`` in float32."""
            inv_freq = 1.0 / (
                base
                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
            )
            return cls(inv_freq)

        @classmethod
        def load(cls, prefix, weights):
            """Load `{prefix}.inv_freq`, always in float32.

            ``weights.dtype`` is switched to float32 for the read and restored
            afterwards, even if the read raises.
            """
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            try:
                inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            finally:
                weights.dtype = dtype
            return cls(inv_freq)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild the tables when nothing is cached yet, the sequence grew,
            # or we're on a new device / dtype (possibly due to tracing).
            # The explicit None check avoids `None.device` when the first call
            # has seqlen == 0.
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Use torch.outer rather than einsum: einsum converts fp32 to fp16.
                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """
            Return cos and sin for the asked position ids
            """

            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            return cos.unsqueeze(1), sin.unsqueeze(1)

        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            """Rotate the first ``2 * rotary_dim`` features of ``x`` in place and return ``x``."""
            rotary_dim = cos.shape[-1]
            x1 = x[..., :rotary_dim]
            x2 = x[..., rotary_dim : 2 * rotary_dim]

            # x1/x2 are views into x and are also passed as the outputs, so x
            # is updated in place.
            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
            return x

except ImportError:
    pass