weights.py 32 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
import os
2
from pathlib import Path
xuxzh1's avatar
last  
xuxzh1 committed
3
from typing import Dict, List, Optional, Union
4
from safetensors import safe_open, SafetensorError
5
import torch
6
from loguru import logger
7
8
from huggingface_hub import hf_hub_download
import json
xuxzh1's avatar
last  
xuxzh1 committed
9
from text_generation_server.layers.gptq import GPTQParams
10
from text_generation_server.utils.log import log_once
11
12
13


class Weights:
14
15
16
17
18
19
20
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
OlivierDehaene's avatar
OlivierDehaene committed
21
        prefix: Optional[str] = None,
22
    ):
23
24
25
26
27
28
29
30
31
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
32
33
34
        if aliases is None:
            aliases = {}
        self.aliases = aliases
35
36
37
38
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
Nicolas Patry's avatar
Nicolas Patry committed
39
        self.prefix = prefix
40
41
42
43
44
45
46
47
48
        self._handles = {}

    def _get_handle(self, filename):
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

49
    def get_filename(self, tensor_name: str) -> (str, str):
Nicolas Patry's avatar
Nicolas Patry committed
50
51
52
53
54
55
56
57
58
59
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
60
61
62
63
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
Nicolas Patry's avatar
Nicolas Patry committed
64
        raise RuntimeError(f"weight {tensor_name} does not exist")
65
66

    def _get_slice(self, tensor_name: str):
67
        filename, tensor_name = self.get_filename(tensor_name)
68
69
70
71
72
73
74
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

OlivierDehaene's avatar
OlivierDehaene committed
75
    def get_tensor(self, tensor_name: str, to_device=True):
76
        filename, tensor_name = self.get_filename(tensor_name)
77
78
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
79
        # Special case for gptq which shouldn't convert
xuxzh1's avatar
last  
xuxzh1 committed
80
81
82
        # u4 which are disguised as int32. Exl2 uses int16
        # as well.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
83
            tensor = tensor.to(dtype=self.dtype)
xiaobin's avatar
xiaobin committed
84
85
        if to_device:
            tensor = tensor.to(device=self.device)
86
87
        return tensor

88
    def get_partial_sharded(self, tensor_name: str, dim: int):
89
        filename, tensor_name = self.get_filename(tensor_name)
xiaobin's avatar
xiaobin committed
90
91
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
92
93
94
95
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
96
        block_size = (size + world_size - 1) // world_size
97
98
99
100
101
102
103
104
105
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
106
        # Special case for gptq which shouldn't convert
xuxzh1's avatar
last  
xuxzh1 committed
107
108
        # u4 which are disguised as int32. exl2 uses int16.
        if tensor.dtype not in (torch.int16, torch.int32):
109
            tensor = tensor.to(dtype=self.dtype)
110
111
        tensor = tensor.to(device=self.device)
        return tensor
112

113
114
115
116
117
118
119
120
121
122
123
    def get_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

xuxzh1's avatar
last  
xuxzh1 committed
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
    def get_packed_sharded(
        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
    ) -> torch.Tensor:
        """
        Get this rank's shard of a tensor that packs several tensors.

        When a tensor packs multiple tensors (such as QKV or an up
        projection + gate projection), sharding with `get_sharded` is not
        safe since it would not split each packed tensor across shards.
        This method instead splits every packed block across the ranks and
        concatenates this rank's piece of each block along `dim`.

        `block_sizes` is either an `int`, in which case the dimension is
        split into that many equally sized blocks, or a list of proportions.
        For instance `[2, 1, 1]` divides an input with dimensionality `1024`
        into `[512, 256, 256]`. This is convenient for e.g. splitting QKV
        without knowing the storage details of quantized weights.

        The result is moved to `self.device` and cast to `self.dtype` unless
        its dtype is int16, int32 or int64. Only `dim` 0 and 1 are supported.
        """
        packed = self._get_slice(tensor_name)
        block_sizes = _blocks_to_block_sizes(
            total_size=packed.get_shape()[dim], blocks=block_sizes
        )

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        shards = []
        offset = 0  # start of the current packed block along `dim`
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
            per_rank = block_size // world_size
            lo = offset + rank * per_rank
            hi = offset + (rank + 1) * per_rank
            if dim == 0:
                shard = packed[lo:hi]
            elif dim == 1:
                shard = packed[:, lo:hi]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            shards.append(shard)
            offset += block_size

        result = torch.cat(shards, dim=dim).to(device=self.device)

        # Avoid casting quantizer dtypes.
        if result.dtype in [torch.int16, torch.int32, torch.int64]:
            return result
        return result.to(dtype=self.dtype)
xiaobin's avatar
xiaobin committed
175

xuxzh1's avatar
last  
xuxzh1 committed
176
177
178
179
180
181
182
183
184
185
    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        quantize: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        return self.get_weights_col_packed(
            prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
        )
Nicolas Patry's avatar
Nicolas Patry committed
186
187
188
189

    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
        return self.get_weights_col_packed(prefix, quantize, 2)

xuxzh1's avatar
last  
xuxzh1 committed
190
191
192
    def get_weights_col_packed(
        self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
    ):
        """
        Load a column-sharded weight whose tensor packs several tensors.

        Highly specific to the case where the underlying tensor is a simple
        concatenation of Q, K, V rather than already alternating Q, K, V
        within the main tensor.

        The columns are split in equally sized blocks when `block_sizes` is
        an `int`, or in blocks proportional to the given sizes. For instance
        `[2, 1, 1]` will divide an input with dimensionality `1024` in
        `[512, 256, 256]`. This is convenient for e.g. splitting QKV without
        knowing the storage details of quantized weights.

        Depending on `quantize` the result is a GPTQ weight (or its Marlin
        repacking), a Marlin weight, or a plain tensor.

        Raises:
            RuntimeError: if the `qweight` tensor of a `gptq`/`awq` model
                cannot be loaded.
        """

        def load_packed(suffix: str, dim: int = 1):
            # All packed tensors of this weight share the same block layout.
            return self.get_packed_sharded(
                f"{prefix}.{suffix}", dim=dim, block_sizes=block_sizes
            )

        if quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                can_use_gptq_marlin,
                repack_gptq_for_marlin,
            )

            try:
                qweight = load_packed("qweight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                )
            scales = load_packed("scales").to(dtype=self.dtype)

            params = self._get_gptq_params()
            if can_use_gptq_marlin(params, quantize):
                # Marlin path: repack right away, qzeros are not loaded.
                g_idx = self.get_tensor(f"{prefix}.g_idx")
                return repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=params.bits,
                    desc_act=params.desc_act,
                    groupsize=params.groupsize,
                    sym=params.sym,
                    sharded_infeatures=False,
                )

            qzeros = load_packed("qzeros")

            g_idx = None
            if quantize == "gptq":
                if params.quant_method == "gptq":
                    g_idx = self.get_tensor(f"{prefix}.g_idx")
                elif params.quant_method == "awq":
                    log_once(
                        logger.info,
                        "Converting AWQ model to Exllama/GPTQ packing format.",
                    )
                    from text_generation_server.layers.awq.conversion_utils import (
                        fast_awq_to_gptq,
                    )

                    qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                    # Build a group index from the row position and groupsize.
                    positions = torch.arange(
                        qweight.shape[0] * (32 // params.bits),
                        device=qweight.device,
                    )
                    g_idx = (positions // params.groupsize).to(dtype=torch.int32)

            return GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=params.bits,
                groupsize=params.groupsize,
                use_exllama=False,
            )

        if quantize == "marlin":
            from text_generation_server.layers.marlin import (
                GPTQMarlin24Weight,
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
            if checkpoint_format == "marlin_24":
                B = load_packed("B_24")
                B_meta = load_packed("B_meta")
                s = load_packed("s")

                params = self._get_gptq_params()
                return GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=params.bits)

            B = load_packed("B")
            s = load_packed("s")
            return MarlinWeight(B=B, s=s)

        # Unquantized: the packed blocks are laid out along dim 0.
        return load_packed("weight", dim=0)

xuxzh1's avatar
last  
xuxzh1 committed
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
    def get_weights_col(self, prefix: str, quantize: str):
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        return self.get_multi_weights_col([prefix], quantize, 0)

333
    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
xuxzh1's avatar
last  
xuxzh1 committed
334
335
336
337
338
339
340
341
342
        if quantize == "exl2":
            raise ValueError("get_multi_weights_col is not supported for exl2")
        elif quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                can_use_gptq_marlin,
                repack_gptq_for_marlin,
            )

343
            try:
344
345
346
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
347
            except RuntimeError:
348
                raise RuntimeError(
349
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
350
351
352
353
354
                )

            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )
355

xuxzh1's avatar
last  
xuxzh1 committed
356
357
358
359
360
361
            gptq_params = self._get_gptq_params()
            if can_use_gptq_marlin(gptq_params, quantize):
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
362

xuxzh1's avatar
last  
xuxzh1 committed
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
                return repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=False,
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )

            from text_generation_server.layers.gptq import HAS_EXLLAMA
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
379
380

            use_exllama = (
xuxzh1's avatar
last  
xuxzh1 committed
381
382
383
384
                gptq_params.bits == 4
                and HAS_EXLLAMA
                and quantize == "gptq"
                and not gptq_params.desc_act
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
385
386
            )

xuxzh1's avatar
last  
xuxzh1 committed
387
            if quantize == "gptq" and gptq_params.quant_method == "gptq":
388
389
390
391
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]
xuxzh1's avatar
last  
xuxzh1 committed
392
            elif quantize == "gptq" and gptq_params.quant_method == "awq":
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
393
394
395
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
xuxzh1's avatar
last  
xuxzh1 committed
396
                from text_generation_server.layers.awq.conversion_utils import (
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
397
398
399
400
401
402
403
404
405
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
xuxzh1's avatar
last  
xuxzh1 committed
406
407
                            qweight.shape[0] * (32 // gptq_params.bits),
                            device=qweight.device,
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
408
                        )
xuxzh1's avatar
last  
xuxzh1 committed
409
                        // gptq_params.groupsize
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
410
                    ).to(dtype=torch.int32)
411
412
            else:
                g_idx = None
413

xuxzh1's avatar
last  
xuxzh1 committed
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                GPTQMarlin24Weight,
                MarlinWeight,
            )

            is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
            if is_marlin_24:
                try:
                    B = torch.cat(
                        [self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                    )

                B_meta = torch.cat(
                    [self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
                )

                s = torch.cat(
                    [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
                )

                gptq_params = self._get_gptq_params()
                weight = GPTQMarlin24Weight(
                    B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
                )
            else:
                try:
                    B = torch.cat(
                        [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                    )
                s = torch.cat(
                    [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
                )

                weight = MarlinWeight(B=B, s=s)

468
469
470
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
xuxzh1's avatar
last  
xuxzh1 committed
471

472
        return weight
OlivierDehaene's avatar
OlivierDehaene committed
473

xiaobin's avatar
xiaobin committed
474
475
476
477
478
479
480
481
482
483
484
485
486
487
    def get_tensor_shard(self, var, dim):
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
OlivierDehaene's avatar
OlivierDehaene committed
488
        return tensor
489
490

    def get_multi_weights_row(self, prefix: str, quantize: str):
xuxzh1's avatar
last  
xuxzh1 committed
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        elif quantize == "gptq":
            from text_generation_server.layers.marlin import (
                can_use_gptq_marlin,
                repack_gptq_for_marlin,
            )
519

xuxzh1's avatar
last  
xuxzh1 committed
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
            gptq_params = self._get_gptq_params()
            if can_use_gptq_marlin(gptq_params, quantize):
                log_once(logger.info, "Using GPTQ-Marlin kernels")
                try:
                    qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
                if gptq_params.desc_act or gptq_params.groupsize == -1:
                    scales = self.get_tensor(f"{prefix}.scales")
                else:
                    scales = self.get_sharded(f"{prefix}.scales", dim=0)

                sharded_in_features = self.process_group.size() > 1

                return repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=sharded_in_features,
                )

            use_exllama = True
            if gptq_params.bits != 4:
551
552
                use_exllama = False

xuxzh1's avatar
last  
xuxzh1 committed
553
            if gptq_params.desc_act:
554
555
556
                log_once(logger.warning, "Disabling exllama because desc_act=True")
                use_exllama = False

Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
557
558
559
560
561
562
563
            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

xuxzh1's avatar
last  
xuxzh1 committed
564
            if gptq_params.quant_method == "gptq":
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
565
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
xuxzh1's avatar
last  
xuxzh1 committed
566
            elif gptq_params.quant_method == "awq":
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
567
568
                g_idx = None

569
570
            if self.process_group.size() > 1:
                if g_idx is not None:
571
572
573
574
                    if (
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
xuxzh1's avatar
last  
xuxzh1 committed
575
576
577
578
                                [
                                    i // gptq_params.groupsize
                                    for i in range(g_idx.shape[0])
                                ],
579
580
581
582
583
                                dtype=torch.int32,
                            ),
                        )
                        and not (g_idx == 0).all()
                    ):
584
585
586
587
                        # Exllama implementation does not support row tensor parallelism with act-order, as
                        # it would require to reorder input activations that are split unto several GPUs
                        use_exllama = False

xuxzh1's avatar
last  
xuxzh1 committed
588
589
590
591
592
            from text_generation_server.layers.gptq import (
                HAS_EXLLAMA,
                CAN_EXLLAMA,
                GPTQWeight,
            )
593

594
            if use_exllama:
595
596
                if not HAS_EXLLAMA:
                    if CAN_EXLLAMA:
597
598
                        log_once(
                            logger.warning,
OlivierDehaene's avatar
v1.3.4  
OlivierDehaene committed
599
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
600
                        )
601
602
                    use_exllama = False
                else:
OlivierDehaene's avatar
v1.3.4  
OlivierDehaene committed
603
                    log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
604

xuxzh1's avatar
last  
xuxzh1 committed
605
            if use_exllama and gptq_params.groupsize != -1:
Nicolas Patry's avatar
Nicolas Patry committed
606
607
                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                scales = self.get_sharded(f"{prefix}.scales", dim=0)
608
609
610
            else:
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")
611

Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
612
            if use_exllama and g_idx is not None:
613
                g_idx = g_idx - g_idx[0]
614

xuxzh1's avatar
last  
xuxzh1 committed
615
            if gptq_params.quant_method == "awq":
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
616
617
618
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
xuxzh1's avatar
last  
xuxzh1 committed
619
                from text_generation_server.layers.awq.conversion_utils import (
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
620
621
622
623
624
625
626
627
628
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
xuxzh1's avatar
last  
xuxzh1 committed
629
630
                            qweight.shape[0] * (32 // gptq_params.bits),
                            device=qweight.device,
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
631
                        )
xuxzh1's avatar
last  
xuxzh1 committed
632
                        // gptq_params.groupsize
Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
633
634
                    ).to(dtype=torch.int32)

xuxzh1's avatar
last  
xuxzh1 committed
635
636
637
638
639
640
641
642
643
            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
644
        elif quantize == "awq":
xuxzh1's avatar
last  
xuxzh1 committed
645
646
647
            from text_generation_server.layers.gptq import GPTQWeight

            gptq_params = self._get_gptq_params()
648
649
650
651
652
653
654
655
656
657
658
659

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `awq` weight, make sure the model is already quantized"
                )

            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
            scales = self.get_sharded(f"{prefix}.scales", dim=0)
            g_idx = None
            use_exllama = False
OlivierDehaene's avatar
OlivierDehaene committed
660

xuxzh1's avatar
last  
xuxzh1 committed
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                GPTQMarlin24Weight,
                MarlinWeight,
            )

            is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
            if is_marlin_24:
                try:
                    B = self.get_sharded(f"{prefix}.B_24", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
                    )

                B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
                if num_groups == 1:
                    # The number of groups is 1 when groupsize == -1. share
                    # scales between all shards in this case.
                    s = self.get_tensor(f"{prefix}.s")
                else:
                    s = self.get_sharded(f"{prefix}.s", dim=0)

                gptq_params = self._get_gptq_params()
                weight = GPTQMarlin24Weight(
                    B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
                )
            else:
                try:
                    B = self.get_sharded(f"{prefix}.B", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        "Cannot load `marlin` weight, make sure the model is already quantized."
                    )

                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
                if num_groups == 1:
                    # The number of groups is 1 when groupsize == -1. share
                    # scales between all shards in this case.
                    s = self.get_tensor(f"{prefix}.s")
                else:
                    s = self.get_sharded(f"{prefix}.s", dim=0)
                weight = MarlinWeight(B=B, s=s)
715
716
717
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight
718

xuxzh1's avatar
last  
xuxzh1 committed
719
    def _get_gptq_params(self) -> GPTQParams:
        """
        Resolve the quantization parameters for this checkpoint.

        The bit width and group size are first looked up as the tensors
        `gptq_bits` and `gptq_groupsize` through `self.get_tensor`. When that
        lookup fails, the attributes stored on this object (the ones assigned
        by `_set_gptq_params`) are used instead.

        Returns:
            A `GPTQParams` built from the resolved values.

        Raises:
            SafetensorError, RuntimeError: the original tensor lookup error,
                re-raised when `gptq_bits` / `gptq_groupsize` are not
                available as attributes either.
        """
        # Identical in both paths: an optional attribute defaulting to None.
        checkpoint_format = getattr(self, "gptq_checkpoint_format", None)

        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            # Values used when the parameters come from checkpoint tensors.
            desc_act = False
            sym = False
            quant_method = "gptq"
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
            except Exception:
                # No config-derived values either: surface the tensor error.
                raise e

            desc_act = getattr(self, "gptq_desc_act", False)
            quant_method = getattr(self, "quant_method", "gptq")
            # `_set_gptq_params` stores the symmetry flag as `gptq_sym`.
            # Reading only `sym` ignored the configured value and always
            # produced True; `sym` is still honoured as a legacy fallback.
            sym = getattr(self, "gptq_sym", getattr(self, "sym", True))

        return GPTQParams(
            bits=bits,
            checkpoint_format=checkpoint_format,
            desc_act=desc_act,
            groupsize=groupsize,
            quant_method=quant_method,
            sym=sym,
        )
746

OlivierDehaene's avatar
OlivierDehaene committed
747
    def _set_gptq_params(self, model_id, revision):
        """
        Load quantization settings from the model's config files onto `self`.

        Three files are tried in order: `config.json` (its
        `quantization_config` section), `quantize_config.json` and
        `quant_config.json`. Each file is read from the `model_id` directory
        when it exists there and is fetched with `hf_hub_download` otherwise.
        The first file that provides every expected key ends the search.

        This is best-effort: any exception moves on to the next candidate,
        and when all three fail the method returns without raising.
        Attributes assigned before a failure are kept, so a file that only
        partially matches still contributes the values read before the
        missing key.

        Args:
            model_id: Local model directory or repository id handed to
                `hf_hub_download`.
            revision: Revision handed to `hf_hub_download`.
        """

        def load_json(name):
            # Prefer a local copy inside `model_id`, else download the file.
            local_path = os.path.join(model_id, name)
            if os.path.exists(local_path):
                path = local_path
            else:
                path = hf_hub_download(model_id, filename=name, revision=revision)
            # JSON is UTF-8; do not depend on the platform default encoding.
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)

        try:
            data = load_json("config.json")
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
            # Order is important here, desc_act is missing on some real models
            self.quant_method = data["quantization_config"]["quant_method"]
            self.gptq_checkpoint_format = data["quantization_config"].get(
                "checkpoint_format"
            )
            self.gptq_sym = data["quantization_config"]["sym"]
            self.gptq_desc_act = data["quantization_config"]["desc_act"]
            return
        except Exception:
            # Deliberate best-effort: fall through to the next file.
            pass

        try:
            data = load_json("quantize_config.json")
            self.gptq_bits = data["bits"]
            self.gptq_groupsize = data["group_size"]
            self.gptq_sym = data["sym"]
            self.gptq_desc_act = data["desc_act"]
            if data.get("version") == "GEMM":
                self.quant_method = "awq"
            return
        except Exception:
            # Deliberate best-effort: fall through to the next file.
            pass

        try:
            data = load_json("quant_config.json")
            self.gptq_bits = data["w_bit"]
            self.gptq_groupsize = data["q_group_size"]
            self.gptq_desc_act = data["desc_act"]
            if data.get("version") == "GEMM":
                self.quant_method = "awq"
        except Exception:
            # Nothing usable was found; leave whatever was set so far.
            pass
xuxzh1's avatar
last  
xuxzh1 committed
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Turn a block specification into a list of concrete block sizes.

    `blocks` may be given in two forms:

    - An int: the number of blocks. Every block then has the size
      total_size // blocks.
    - A list of ints: relative block sizes. Each entry is scaled by
      total_size // sum(blocks), so the ratios between the entries are
      kept. For example, blocks [2, 1, 1] with a total_size of 1024
      yields [512, 256, 256].

    In both forms total_size must be evenly divisible (by `blocks`, or by
    sum(blocks) respectively); otherwise the assertion fails.
    """
    # Plain block count: equally sized blocks.
    if not isinstance(blocks, list):
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        return [total_size // blocks for _ in range(blocks)]

    # Proportional blocks: scale every share by the size of one unit.
    share_count = sum(blocks)
    assert (
        total_size % share_count == 0
    ), f"Cannot split {total_size} in proportional blocks: {blocks}"
    unit = total_size // share_count
    return [unit * share for share in blocks]