import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.utils.log import log_once


@dataclass
class _GPTQParams:
    bits: int
    groupsize: int
    desc_act: bool
    quant_method: str
    sym: bool


class Weights:
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
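        # Map every tensor name to the file that stores it; a name appearing
        # in more than one file is ambiguous and rejected below.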
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename):
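        # Cache safetensors handles so that repeated tensor lookups do not
        # reopen the same file.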
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for quantized weights (gptq/awq/exl2), which should not
        # be cast: the packed u4 values are disguised as int32, and exl2 uses
        # int16 as well.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
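        # Ceil-divide so that every rank gets a block of the same nominal
        # size; the last rank may receive a smaller remainder slice.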
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for quantized weights (gptq/awq/exl2), which should not
        # be cast: the packed u4 values are disguised as int32, and exl2 uses
        # int16.
        if tensor.dtype not in (torch.int16, torch.int32):
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def get_packed_sharded(
        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
    ) -> torch.Tensor:
        """
        Get a shard from a tensor that packs multiple tensors.

        When a tensor packs multiple tensors (such as QKV or an up
        projection + gate projection), sharding with `get_sharded` is not
        safe since it would not split the packed tensors across shards.

        This method shards a tensor, such that the packed tensors are
        split across shards.

        The columns are split into equally sized blocks when `block_sizes` is
        an `int`, or into blocks proportional to the given sizes. For
        instance, `[2, 1, 1]` will divide an input with dimensionality `1024`
        into `[512, 256, 256]`. This is convenient e.g. for splitting QKV
        without knowing the storage details of quantized weights.
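
        Worked example (hypothetical sizes): with `dim=0`, a total size of
        1024, `block_sizes=[2, 1, 1]` and a world size of 2, the packed
        blocks are `[512, 256, 256]` and rank 0 receives rows `[0:256]`,
        `[512:640]` and `[768:896]`, i.e. half of each packed sub-tensor.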
        """
        slice_ = self._get_slice(tensor_name)
        total_size = slice_.get_shape()[dim]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        tensors = []
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
            if dim == 0:
                tensor = slice_[block_offset + start : block_offset + stop]
            elif dim == 1:
                tensor = slice_[:, block_offset + start : block_offset + stop]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            tensors.append(tensor)
            block_offset += block_size
        tensor = torch.cat(tensors, dim=dim)
        tensor = tensor.to(device=self.device)

        # Avoid casting quantizer dtypes.
        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)

        return tensor

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        quantize: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        return self.get_weights_col_packed(
            prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
        )
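    # A minimal usage sketch (hypothetical prefix and head counts): for a
    # checkpoint that stores QKV as one fused projection,
    #     weights.get_weights_col_packed_qkv(
    #         "model.layers.0.self_attn.qkv_proj",
    #         quantize=None,  # unquantized checkpoint
    #         num_heads=32,
    #         num_key_value_heads=8,
    #     )
    # returns this rank's column shard with the Q, K and V blocks each split
    # evenly across shards.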

    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
        return self.get_weights_col_packed(prefix, quantize, 2)

    def get_weights_col_packed(
        self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
    ):
        """
        Highly specific method for the case where the underlying tensor is a
        simple concatenation of Q, K and V, rather than having Q, K and V
        already interleaved within the main tensor.

        The columns are split into equally sized blocks when `block_sizes` is
        an `int`, or into blocks proportional to the given sizes. For
        instance, `[2, 1, 1]` will divide an input with dimensionality `1024`
        into `[512, 256, 256]`. This is convenient e.g. for splitting QKV
        without knowing the storage details of quantized weights.
        """
        if quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = self.get_packed_sharded(
                    f"{prefix}.qweight", dim=1, block_sizes=block_sizes
                )
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                )

            gptq_params = self._get_gptq_params()

            qzeros = self.get_packed_sharded(
                f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
            )
            scales = self.get_packed_sharded(
                f"{prefix}.scales", dim=1, block_sizes=block_sizes
            )
            scales = scales.to(dtype=self.dtype)

            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                g_idx = self.get_tensor(f"{prefix}.g_idx")
            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
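                # AWQ checkpoints carry no g_idx, so build the trivial group
                # index expected by the GPTQ kernels: unpacked row i belongs
                # to group i // groupsize.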
                g_idx = (
                    torch.arange(
                        qweight.shape[0] * (32 // gptq_params.bits),
                        device=qweight.device,
                    )
                    // gptq_params.groupsize
                ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=False,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import (
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            quant_method = getattr(self, "quant_method", "marlin")
            if quant_method == "gptq":
                gptq_params = self._get_gptq_params()
                try:
                    qweight = self.get_packed_sharded(
                        f"{prefix}.qweight", dim=1, block_sizes=block_sizes
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                scales = self.get_packed_sharded(
                    f"{prefix}.scales", dim=1, block_sizes=block_sizes
                )
                g_idx = self.get_tensor(f"{prefix}.g_idx")
                weight = repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=False,
                )

            else:
                B = self.get_packed_sharded(
                    f"{prefix}.B", dim=1, block_sizes=block_sizes
                )
                s = self.get_packed_sharded(
                    f"{prefix}.s", dim=1, block_sizes=block_sizes
                )
                weight = MarlinWeight(B=B, s=s)
        else:
            weight = self.get_packed_sharded(
                f"{prefix}.weight", dim=0, block_sizes=block_sizes
            )
        return weight

    def get_weights_col(self, prefix: str, quantize: str):
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        return self.get_multi_weights_col([prefix], quantize, 0)

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        if quantize == "exl2":
            raise ValueError("get_multi_weights_col is not supported for exl2")
        elif quantize in ["gptq", "awq"]:
            from text_generation_server.layers.gptq import GPTQWeight

            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )

            gptq_params = self._get_gptq_params()

            from text_generation_server.layers.gptq import HAS_EXLLAMA

            use_exllama = (
                gptq_params.bits == 4
                and HAS_EXLLAMA
                and quantize == "gptq"
                and not gptq_params.desc_act
            )

            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]
            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // gptq_params.bits),
                            device=qweight.device,
                        )
                        // gptq_params.groupsize
                    ).to(dtype=torch.int32)
            else:
                g_idx = None

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            quant_method = getattr(self, "quant_method", "marlin")
            if quant_method == "gptq":
                gptq_params = self._get_gptq_params()
                try:
                    qweight = torch.cat(
                        [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
                        dim=1,
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                scales = torch.cat(
                    [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
                )
                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                for w2 in w[1:]:
                    torch.testing.assert_close(w2, w[0])
                g_idx = w[0]

                weight = repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=False,
                )
            else:
                try:
                    B = torch.cat(
                        [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
                    )
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                    )
                s = torch.cat(
                    [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
                )

                weight = MarlinWeight(B=B, s=s)

        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)

        return weight

    def get_tensor_shard(self, var, dim):
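        # Shard an already-materialized tensor along `dim`; its size on that
        # dimension is assumed to be divisible by the world size.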
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
        if quantize == "exl2":
            from text_generation_server.layers.exl2 import Exl2Weight

            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
                )

            q_scale = self.get_tensor(f"{prefix}.q_scale")
            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
            q_groups = self.get_tensor(f"{prefix}.q_groups")

            return Exl2Weight(
                q_weight=q_weight,
                q_scale=q_scale,
                q_invperm=q_invperm,
                q_scale_max=q_scale_max,
                q_groups=q_groups,
            )

        elif quantize == "gptq":
            use_exllama = True
            gptq_params = self._get_gptq_params()

            if gptq_params.bits != 4:
                use_exllama = False

            if gptq_params.desc_act:
                log_once(logger.warning, "Disabling exllama because desc_act=True")
                use_exllama = False

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            if gptq_params.quant_method == "gptq":
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
            elif gptq_params.quant_method == "awq":
                g_idx = None

            if self.process_group.size() > 1:
                if g_idx is not None:
                    if (
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
                                [
                                    i // gptq_params.groupsize
                                    for i in range(g_idx.shape[0])
                                ],
                                dtype=torch.int32,
                            ),
                        )
                        and not (g_idx == 0).all()
                    ):
                        # The exllama implementation does not support row tensor parallelism
                        # with act-order, as it would require reordering the input activations
                        # that are split across several GPUs.
                        use_exllama = False

            from text_generation_server.layers.gptq import (
                HAS_EXLLAMA,
                CAN_EXLLAMA,
                GPTQWeight,
            )

            if use_exllama:
                if not HAS_EXLLAMA:
                    if CAN_EXLLAMA:
                        log_once(
                            logger.warning,
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
                        )
                    use_exllama = False
                else:
                    log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

            if use_exllama and gptq_params.groupsize != -1:
                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                scales = self.get_sharded(f"{prefix}.scales", dim=0)
            else:
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")

            if use_exllama and g_idx is not None:
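                # g_idx was sharded along dim 0, so its first entry is no
                # longer group 0; shift it so that this shard's group indices
                # start at zero.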
                g_idx = g_idx - g_idx[0]

            if gptq_params.quant_method == "awq":
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                )

                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                if use_exllama:
                    g_idx = None
                else:
                    g_idx = (
                        torch.arange(
                            qweight.shape[0] * (32 // gptq_params.bits),
                            device=qweight.device,
                        )
                        // gptq_params.groupsize
                    ).to(dtype=torch.int32)

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "awq":
            from text_generation_server.layers.gptq import GPTQWeight

            gptq_params = self._get_gptq_params()

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `awq` weight, make sure the model is already quantized"
                )

            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
            scales = self.get_sharded(f"{prefix}.scales", dim=0)
            g_idx = None
            use_exllama = False

            weight = GPTQWeight(
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                g_idx=g_idx,
                bits=gptq_params.bits,
                groupsize=gptq_params.groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.gptq import GPTQWeight
            from text_generation_server.layers.marlin import (
                MarlinWeight,
                repack_gptq_for_marlin,
            )

            quant_method = getattr(self, "quant_method", "marlin")
            if quant_method == "gptq":
                log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
                gptq_params = self._get_gptq_params()

                try:
                    qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                    )

                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
                if gptq_params.desc_act or gptq_params.groupsize == -1:
                    scales = self.get_tensor(f"{prefix}.scales")
                else:
                    scales = self.get_sharded(f"{prefix}.scales", dim=0)

                sharded_in_features = self.process_group.size() > 1

                weight = repack_gptq_for_marlin(
                    qweight=qweight,
                    scales=scales,
                    g_idx=g_idx,
                    bits=gptq_params.bits,
                    desc_act=gptq_params.desc_act,
                    groupsize=gptq_params.groupsize,
                    sym=gptq_params.sym,
                    sharded_infeatures=sharded_in_features,
                )
            else:
                try:
                    B = self.get_sharded(f"{prefix}.B", dim=0)
                except RuntimeError:
                    raise RuntimeError(
                        "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                    )

                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
                if num_groups == 1:
                    # The number of groups is 1 when groupsize == -1; share the
                    # scales between all shards in this case.
                    s = self.get_tensor(f"{prefix}.s")
                else:
                    s = self.get_sharded(f"{prefix}.s", dim=0)
                weight = MarlinWeight(B=B, s=s)

        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_params(self) -> _GPTQParams:
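        # Quantization parameters may be stored as tensors inside the
        # checkpoint itself; otherwise fall back to the attributes populated
        # by `_set_gptq_params` from the model's config files.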
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            desc_act = False
            sym = True
            quant_method = "gptq"
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
                desc_act = getattr(self, "gptq_desc_act", False)
                quant_method = getattr(self, "quant_method", "gptq")
                sym = getattr(self, "gptq_sym", True)
            except Exception:
                raise e

        return _GPTQParams(
            bits=bits,
            desc_act=desc_act,
            groupsize=groupsize,
            quant_method=quant_method,
            sym=sym,
        )

    def _set_gptq_params(self, model_id, revision):
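        # Quantization metadata may live in `config.json` (under
        # `quantization_config`), in `quantize_config.json`, or in
        # `quant_config.json`; try each of these in turn.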
        filename = "config.json"
        try:
            if os.path.exists(os.path.join(model_id, filename)):
                filename = os.path.join(model_id, filename)
            else:
                filename = hf_hub_download(
                    model_id, filename=filename, revision=revision
                )
            with open(filename, "r") as f:
                data = json.load(f)
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
            # Order is important here, desc_act is missing on some real models
            self.quant_method = data["quantization_config"]["quant_method"]
            self.gptq_sym = data["quantization_config"]["sym"]
            self.gptq_desc_act = data["quantization_config"]["desc_act"]
        except Exception:
            filename = "quantize_config.json"
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    filename = os.path.join(model_id, filename)
                else:
                    filename = hf_hub_download(
                        model_id, filename=filename, revision=revision
                    )
                with open(filename, "r") as f:
                    data = json.load(f)
                self.gptq_bits = data["bits"]
                self.gptq_groupsize = data["group_size"]
                self.gptq_sym = data["sym"]
                self.gptq_desc_act = data["desc_act"]
                if "version" in data and data["version"] == "GEMM":
                    self.quant_method = "awq"
            except Exception:
                filename = "quant_config.json"
                try:
                    if os.path.exists(os.path.join(model_id, filename)):
                        filename = os.path.join(model_id, filename)
                    else:
                        filename = hf_hub_download(
                            model_id, filename=filename, revision=revision
                        )
                    with open(filename, "r") as f:
                        data = json.load(f)
                    self.gptq_bits = data["w_bit"]
                    self.gptq_groupsize = data["q_group_size"]
                    self.gptq_desc_act = data["desc_act"]
                    if "version" in data and data["version"] == "GEMM":
                        self.quant_method = "awq"
                except Exception:
                    pass


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the latter case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
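
    For example::

        >>> _blocks_to_block_sizes(1024, 4)
        [256, 256, 256, 256]
        >>> _blocks_to_block_sizes(1024, [2, 1, 1])
        [512, 256, 256]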
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked size is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks