import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.utils.log import log_once


@dataclass
class _GPTQParams:
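    """Quantization parameters (bits, group size, act-order, scheme, symmetry) as resolved by `_get_gptq_params`."""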
    bits: int
    groupsize: int
    desc_act: bool
    quant_method: str
    sym: bool


class Weights:
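    """Routes tensor names to the safetensors files that contain them and loads
    tensors, optionally sharded across `process_group`, onto the target device and dtype."""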
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename):
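        # Cache one open safetensors handle per file so repeated reads do not reopen it.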
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
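        """Resolve `tensor_name` (trying the optional prefix and any aliases) to its file and stored name."""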
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
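        """Load a full tensor, casting to `self.dtype` unless it is an integer tensor used by quantized formats."""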
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. Exl2 uses int16
        # as well.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
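        """Load this rank's contiguous slice of the tensor along `dim` (the last rank may receive a smaller block)."""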
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32. exl2 uses int16.
        if tensor.dtype not in (torch.int16, torch.int32):
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
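        """Load this rank's slice along `dim`, requiring the dimension to divide evenly across all shards."""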
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        weights = []
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            weights.append(slice_[:, block_offset + start : block_offset + stop])
            block_offset += block_size

        weight = torch.cat(weights, dim=1)
        weight = weight.to(device=self.device)
        return weight

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        quantize: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
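        """Load a packed [Q | K | V] projection, sharding each block proportionally to the given head counts."""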
        return self.get_weights_col_packed(
            prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
        return self.get_weights_col_packed(prefix, quantize, 2)

    def get_weights_col_packed(
        self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
    ):
        """
        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
        already alternating Q,K,V within the main tensor.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        if quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                )

            gptq_params = self._get_gptq_params()

            qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
            scales = self._get_qweight(f"{prefix}.scales", block_sizes)
            scales = scales.to(dtype=self.dtype)

            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                g_idx = self.get_tensor(f"{prefix}.g_idx")
            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                g_idx = (
                    torch.arange(
                        qweight.shape[0] * (32 // gptq_params.bits),
                        device=qweight.device,
                    )
                    // gptq_params.groupsize
                ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=False,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import (
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            quant_method = getattr(self, "quant_method", "marlin")
            if quant_method == "gptq":
                gptq_params = self._get_gptq_params()
                try:
                    qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                scales = self._get_qweight(f"{prefix}.scales", block_sizes)
                g_idx = self.get_tensor(f"{prefix}.g_idx")
                weight = repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=False,
                )

            else:
                B = self._get_qweight(f"{prefix}.B", block_sizes)
                s = self._get_qweight(f"{prefix}.s", block_sizes)
                weight = MarlinWeight(B=B, s=s)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
            block_sizes = _blocks_to_block_sizes(
                total_size=total_size, blocks=block_sizes
            )

            world_size = self.process_group.size()
            rank = self.process_group.rank()

            tensors = []
            block_offset = 0
            for block_size in block_sizes:
                assert (
                    block_size % world_size == 0
                ), f"Prepacked weights cannot be sharded across {world_size} shards"
                shard_block_size = block_size // world_size
                start = rank * shard_block_size
                stop = (rank + 1) * shard_block_size
                tensor = slice_[block_offset + start : block_offset + stop]
                tensors.append(tensor)
                block_offset += block_size
            weight = torch.cat(tensors, dim=0)
            weight = weight.to(device=self.device)
            weight = weight.to(dtype=self.dtype)
        return weight

    def get_weights_col(self, prefix: str, quantize: str):
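        """Load a column-sharded weight for a single prefix; exl2 weights are loaded unsharded."""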
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        return self.get_multi_weights_col([prefix], quantize, 0)

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
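        """Load and concatenate column-sharded weights for several prefixes (e.g. fused q/k/v or gate/up projections)."""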
        if quantize == "exl2":
            raise ValueError("get_multi_weights_col is not supported for exl2")
        elif quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )

            gptq_params = self._get_gptq_params()

            from text_generation_server.layers.gptq import HAS_EXLLAMA

            use_exllama = (
                gptq_params.bits == 4
                and HAS_EXLLAMA
                and quantize == "gptq"
                and not gptq_params.desc_act
            )

            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]
            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // gptq_params.bits),
                            device=qweight.device,
                        )
                        // gptq_params.groupsize
                    ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            quant_method = getattr(self, "quant_method", "marlin")
            if quant_method == "gptq":
                gptq_params = self._get_gptq_params()
                try:
                    qweight = torch.cat(
                        [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
                        dim=1,
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                scales = torch.cat(
                    [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
                )
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]

                weight = repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=False,
                )
            else:
                try:
                    B = torch.cat(
                        [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                    )
                s = torch.cat(
                    [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
                )

                weight = MarlinWeight(B=B, s=s)

        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)

        return weight

    def get_tensor_shard(self, var, dim):
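        """Shard an already-materialized tensor `var` along `dim` for this rank and move it to the target dtype/device."""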
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
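        """Load a row-sharded (input-dimension) weight for `prefix`, dispatching on the quantization scheme."""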
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        elif quantize == "gptq":
            use_exllama = True
            gptq_params = self._get_gptq_params()

            if gptq_params.bits != 4:
                use_exllama = False

            if gptq_params.desc_act:
                log_once(logger.warning, "Disabling exllama because desc_act=True")
                use_exllama = False

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            if gptq_params.quant_method == "gptq":
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
            elif gptq_params.quant_method == "awq":
                g_idx = None

            if self.process_group.size() > 1:
                if g_idx is not None:
                    if (
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
                                [
                                    i // gptq_params.groupsize
                                    for i in range(g_idx.shape[0])
                                ],
                                dtype=torch.int32,
                            ),
                        )
                        and not (g_idx == 0).all()
                    ):
                        # Exllama implementation does not support row tensor parallelism with act-order, as
                        # it would require to reorder input activations that are split unto several GPUs
                        use_exllama = False

            from text_generation_server.layers.gptq import (
                HAS_EXLLAMA,
                CAN_EXLLAMA,
                GPTQWeight,
            )

            if use_exllama:
                if not HAS_EXLLAMA:
                    if CAN_EXLLAMA:
                        log_once(
                            logger.warning,
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
                        )
                    use_exllama = False
                else:
                    log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

            if use_exllama and gptq_params.groupsize != -1:
                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                scales = self.get_sharded(f"{prefix}.scales", dim=0)
            else:
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")

            if use_exllama and g_idx is not None:
                g_idx = g_idx - g_idx[0]

            if gptq_params.quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // gptq_params.bits),
                            device=qweight.device,
                        )
                        // gptq_params.groupsize
                    ).to(dtype=torch.int32)

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "awq":
            from text_generation_server.layers.gptq import GPTQWeight

            gptq_params = self._get_gptq_params()

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `awq` weight, make sure the model is already quantized"
                )

            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
            scales = self.get_sharded(f"{prefix}.scales", dim=0)
            g_idx = None
            use_exllama = False

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            quant_method = getattr(self, "quant_method", "marlin")
            if quant_method == "gptq":
                log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
                gptq_params = self._get_gptq_params()

                try:
                    qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
                if gptq_params.desc_act or gptq_params.groupsize == -1:
                    scales = self.get_tensor(f"{prefix}.scales")
                else:
                    scales = self.get_sharded(f"{prefix}.scales", dim=0)

                sharded_in_features = self.process_group.size() > 1

                weight = repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=sharded_in_features,
                )
            else:
                try:
                    B = self.get_sharded(f"{prefix}.B", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                    )

                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
                if num_groups == 1:
                    # The number of groups is 1 when groupsize == -1. share
                    # scales between all shards in this case.
                    s = self.get_tensor(f"{prefix}.s")
                else:
                    s = self.get_sharded(f"{prefix}.s", dim=0)
                weight = MarlinWeight(B=B, s=s)

        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_params(self) -> _GPTQParams:
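        """Return quantization parameters, preferring values stored in the safetensors checkpoint over attributes set by `_set_gptq_params`."""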
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            desc_act = False
            sym = True
            quant_method = "gptq"
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
                desc_act = getattr(self, "gptq_desc_act", False)
                quant_method = getattr(self, "quant_method", "gptq")
                sym = getattr(self, "gptq_sym", True)
            except Exception:
                raise e

        return _GPTQParams(
            bits=bits,
            desc_act=desc_act,
            groupsize=groupsize,
            quant_method=quant_method,
            sym=sym,
        )

    def _set_gptq_params(self, model_id, revision):
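        """Populate quantization attributes from config.json, quantize_config.json or quant_config.json, locally or from the Hub."""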
        filename = "config.json"
        try:
            if os.path.exists(os.path.join(model_id, filename)):
                filename = os.path.join(model_id, filename)
            else:
                filename = hf_hub_download(
                    model_id, filename=filename, revision=revision
                )
            with open(filename, "r") as f:
                data = json.load(f)
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
            # Order is important here, desc_act is missing on some real models
            self.quant_method = data["quantization_config"]["quant_method"]
            self.gptq_sym = data["quantization_config"]["sym"]
            self.gptq_desc_act = data["quantization_config"]["desc_act"]
        except Exception:
            filename = "quantize_config.json"
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    filename = os.path.join(model_id, filename)
                else:
                    filename = hf_hub_download(
                        model_id, filename=filename, revision=revision
                    )
                with open(filename, "r") as f:
                    data = json.load(f)
                self.gptq_bits = data["bits"]
                self.gptq_groupsize = data["group_size"]
                self.gptq_sym = data["sym"]
                self.gptq_desc_act = data["desc_act"]
                if "version" in data and data["version"] == "GEMM":
                    self.quant_method = "awq"
            except Exception:
                filename = "quant_config.json"
                try:
                    if os.path.exists(os.path.join(model_id, filename)):
                        filename = os.path.join(model_id, filename)
                    else:
                        filename = hf_hub_download(
                            model_id, filename=filename, revision=revision
                        )
                    with open(filename, "r") as f:
                        data = json.load(f)
                    self.gptq_bits = data["w_bit"]
                    self.gptq_groupsize = data["q_group_size"]
                    self.gptq_desc_act = data["desc_act"]
                    if "version" in data and data["version"] == "GEMM":
                        self.quant_method = "awq"
                except Exception:
                    pass


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the latter case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks
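

# Illustrative sketch only (not part of the module): how a model implementation might
# pull a fused, tensor-parallel QKV projection through `Weights`. The prefix, head
# counts and `quantize=None` below are assumptions for the example, not real call sites.
#
#     weights = Weights(filenames, device, dtype, process_group=process_group)
#     qkv = weights.get_weights_col_packed_qkv(
#         "model.layers.0.self_attn.qkv_proj",
#         quantize=None,
#         num_heads=num_heads,
#         num_key_value_heads=num_key_value_heads,
#     )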