gguf.py 20.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Any, Callable, Optional
4
5
6

import gguf
import torch
7
from gguf import GGMLQuantizationType as WeightType
8
9
10
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
11
from vllm.logger import init_logger
12
13
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                        FusedMoEMethodBase)
14
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
15
from vllm.model_executor.layers.quantization import QuantizationMethods
16
17
18
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
21
from vllm.utils import direct_register_custom_op
22

23
24
logger = init_logger(__name__)

25
26
27
28
29

class GGUFConfig(QuantizationConfig):
    """Quantization config for GGUF checkpoints.

    The config object itself is stateless: it parses nothing and reads no
    extra files. Quantization types are tracked per layer through the
    ``qweight_type`` parameters created by the quantize methods it hands out.
    """

    def __init__(self) -> None:
        super().__init__()

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> QuantizationMethods:
        return "gguf"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16, torch.float32]

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # GGUF needs no additional config files.
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
        # There are no options to read; the dict is intentionally ignored.
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the GGUF quantize method matching ``layer``'s type.

        The candidates are checked in order and the first ``isinstance``
        match wins. Unsupported layer types yield ``None``.
        """
        # Built at call time: the method classes are defined later in the
        # module than this config class.
        dispatch = (
            (LinearBase, GGUFLinearMethod),
            (VocabParallelEmbedding, GGUFEmbeddingMethod),
            (FusedMoE, GGUFMoEMethod),
        )
        for layer_cls, method_cls in dispatch:
            if isinstance(layer, layer_cls):
                return method_cls(self)
        return None


64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Plain floating-point weights: used directly, no GGUF kernel required.
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}

# Q*_0 / Q*_1 block-quantized types.
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0, WeightType.Q4_1,
    WeightType.Q5_0, WeightType.Q5_1,
    WeightType.Q8_0, WeightType.Q8_1,
}

# K-quant types.
KQUANT_TYPES = {
    WeightType.Q2_K, WeightType.Q3_K, WeightType.Q4_K,
    WeightType.Q5_K, WeightType.Q6_K,
}

# I-matrix (IQ*) quantization types.
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M, WeightType.IQ1_S,
    WeightType.IQ2_XXS, WeightType.IQ2_XS, WeightType.IQ2_S,
    WeightType.IQ3_XXS, WeightType.IQ3_S,
    WeightType.IQ4_XS, WeightType.IQ4_NL,
}

# Kernel capability sets. The MMQ kernel covers standard and K-quant types
# only; MMVQ and dequantization additionally cover the I-matrix types.
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
MMVQ_QUANT_TYPES = MMQ_QUANT_TYPES | IMATRIX_QUANT_TYPES
DEQUANT_TYPES = MMQ_QUANT_TYPES | IMATRIX_QUANT_TYPES


99
100
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
                        qweight_type: int) -> torch.Tensor:
    """Compute ``x @ W.T`` for a (possibly GGUF-quantized) weight ``W``.

    Dispatch, in priority order:
      * empty ``x``        -> uninitialized ``(0, qweight.shape[0])`` tensor
      * unquantized type   -> plain ``x @ qweight.T``
      * one row, MMVQ type -> ``ops.ggml_mul_mat_vec_a8``
      * MMQ type           -> ``ops.ggml_mul_mat_a8``
      * other quant type   -> dequantize ``qweight`` then matmul

    Raises:
        NotImplementedError: if ``qweight_type`` is in none of the sets.
            (An integer that is not a valid ``WeightType`` surfaces as the
            enum's ValueError instead.)
    """
    num_rows = x.shape[0]
    out_features = qweight.shape[0]

    # HACK: when doing chunked prefill we don't generate output tokens
    # so input to logits generator is empty which causes invalid parameter
    if num_rows == 0:
        return torch.empty(num_rows,
                           out_features,
                           dtype=x.dtype,
                           device=x.device)

    # fp16/bf16/fp32 weights need no GGUF kernel.
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T

    # MMVQ (vector kernel) for a single-row batch.
    if num_rows == 1 and qweight_type in MMVQ_QUANT_TYPES:
        return ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type,
                                       out_features)

    # MMQ kernel for standard + k-quants.
    if qweight_type in MMQ_QUANT_TYPES:
        return ops.ggml_mul_mat_a8(qweight, x, qweight_type, out_features)

    # No MMQ kernel available: dequantize the whole weight and matmul.
    if qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        # Row width in bytes -> logical number of input features.
        in_features = qweight.shape[1] // type_size * block_size
        weight = ops.ggml_dequantize(qweight, qweight_type, out_features,
                                     in_features, x.dtype)
        return x @ weight.T

    # Wrapping in the IntEnum validates the value and gives a readable name.
    raise NotImplementedError(
        f"Unsupported GGUF quantization type: {WeightType(qweight_type)}")


133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def _fused_mul_mat_gguf_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
) -> torch.Tensor:
    return torch.empty(x.shape[0],
                       qweight.shape[0],
                       dtype=x.dtype,
                       device=x.device)


# Register the GGUF matmul as a vLLM custom op, paired with its shape-only
# fake implementation, and expose the dispatcher handle used by callers.
# Any registration error propagates as-is; the previous
# ``except AttributeError: raise error`` wrapper was a no-op.
direct_register_custom_op(
    op_name="_fused_mul_mat_gguf",
    op_func=_fused_mul_mat_gguf,
    mutates_args=[],
    fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf


157
158
159
160
161
162
163
164
def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    """Apply a GGUF-quantized mixture-of-experts MLP to ``x``.

    For each (token, selected expert ``e``) pair this computes
    ``act(x @ w1[e].T) @ w2[e].T``, scales it by the routing weight and sums
    the results over the selected experts. Three implementations are used:

      * the MMQ MoE kernel (``ops.ggml_moe_a8``) when both weight types are
        MMQ-capable and there are more than 64 tokens;
      * the vectorized MoE kernel (``ops.ggml_moe_a8_vec``) when both weight
        types are MMVQ-capable;
      * otherwise a slow per-token, per-expert loop over
        ``fused_mul_mat_gguf``.

    Args:
        x: Hidden states; the kernel paths unpack it as
            ``(num_tokens, hidden_size)``.
        w1: Quantized gate/up expert weights, unpacked as ``(E, N, _)``.
        w2: Quantized down expert weights; ``w2.shape[1]`` is the output
            width.
        topk_weights: Routing weights, viewed as ``(num_tokens, top_k, 1)``.
        topk_ids: Selected expert ids; ``topk_ids.shape[1]`` is ``top_k``.
        qweight_type: GGML quantization type of ``w1``.
        qweight_type2: GGML quantization type of ``w2``.
        activation: ``"silu"`` or ``"gelu"``.

    Returns:
        A tensor allocated with ``torch.empty_like(x)`` holding the weighted
        sum of the expert outputs.

    Raises:
        ValueError: if ``activation`` is not ``"silu"`` or ``"gelu"``.
    """

    def act(h: torch.Tensor) -> torch.Tensor:
        # Gated activation through vLLM's fused *_and_mul kernels. The
        # result's last dimension is half that of ``h``.
        d = h.shape[-1] // 2
        out = torch.empty(h.shape[:-1] + (d, ),
                          dtype=h.dtype,
                          device=h.device)
        if activation == "silu":
            torch.ops._C.silu_and_mul(out, h)
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(out, h)
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        return out

    # lazy import to avoid triggering triton import in CPU backend
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        moe_align_block_size)

    out_hidden_states = torch.empty_like(x)

    # Unless there is decent expert reuse (here: more than 64 tokens), we
    # are better off running the vectorized MoE kernel than the MMQ one.
    if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES
            and x.shape[0] > 64):
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]
        BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)

        sorted_token_ids, expert_ids, num_tokens_post_padded = \
                moe_align_block_size(topk_ids, BLOCK_SIZE, E)
        out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
                              num_tokens_post_padded, qweight_type, N, top_k,
                              num_tokens)
        out = act(out)
        # Down projection: each of the num_tokens * top_k activated rows is
        # passed as its own input row (top_k argument of 1).
        out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
                              num_tokens_post_padded, qweight_type2,
                              w2.shape[1], 1, num_tokens * top_k)
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1))
        ops.moe_sum(out, out_hidden_states)
    elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
        num_tokens, _ = x.shape
        # The expert count is not needed by the vectorized kernel.
        _, N, _ = w1.shape
        top_k = topk_ids.shape[1]

        out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N,
                                  num_tokens)
        out = act(out)

        out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2,
                                  w2.shape[1], num_tokens * top_k)
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1))
        ops.moe_sum(out, out_hidden_states)
    else:
        logger.warning_once("There is no support for fast MoE kernel "
                            "for current quantization method. "
                            "Falling back to slow implementation. ")
        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
            # Keep a leading batch dimension of 1 for fused_mul_mat_gguf.
            inp = x[tok].reshape((1, ) + x.shape[1:])
            current_hidden_state = None
            for ww, ii in zip(w, idx):
                expert_up = w1[ii]
                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
                out = act(out)

                expert_down = w2[ii]
                current_state = fused_mul_mat_gguf(out, expert_down,
                                                   qweight_type2).mul_(ww)
                if current_hidden_state is None:
                    current_hidden_state = current_state
                else:
                    current_hidden_state.add_(current_state)
            out_hidden_states[tok] = current_hidden_state
    return out_hidden_states


243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
def _fused_moe_gguf_fake(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    return torch.empty_like(x)


# Register the GGUF fused-MoE routine as a vLLM custom op, paired with its
# shape-only fake implementation, and expose the dispatcher handle.
# Errors propagate as-is; the previous
# ``except AttributeError: raise error`` wrapper was a no-op.
direct_register_custom_op(
    op_name="_fused_moe_gguf",
    op_func=_fused_moe_gguf,
    mutates_args=[],
    fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf


def _apply_gguf_embedding(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Look up embedding rows for token ids ``x`` from a GGUF weight table.

    Unquantized tables go straight through ``torch.embedding``. For quantized
    tables, only the rows selected by ``x`` are gathered and dequantized to
    ``dtype``, so the full table is never dequantized.

    Returns:
        For quantized tables, a tensor of shape ``(*x.shape, hidden_size)``.
        For unquantized tables, whatever ``torch.embedding(qweight, x)``
        returns.

    Raises:
        NotImplementedError: if ``qweight_type`` is neither unquantized nor
            dequantizable.
    """
    if qweight_type in UNQUANTIZED_TYPES:
        return torch.embedding(qweight, x)

    if qweight_type not in DEQUANT_TYPES:
        raise NotImplementedError(
            f"Unsupported GGUF quantization type: {WeightType(qweight_type)}")

    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    token_ids = x.flatten()
    # The quantized row width (bytes) must decode to exactly hidden_size.
    assert (hidden_size == qweight.shape[1] // type_size * block_size)
    selected_rows = torch.index_select(qweight, dim=0, index=token_ids)
    # NOTE(review): (m, n) are passed as (hidden_size, num_ids), the reverse
    # of the (rows, cols) order used in _fused_mul_mat_gguf. The view below
    # reinterprets the flat result, so this appears to rely on the kernel
    # filling a contiguous buffer; confirm before reordering.
    dequant = ops.ggml_dequantize(selected_rows, qweight_type, hidden_size,
                                  token_ids.shape[0], dtype)
    return dequant.view(*x.shape, hidden_size)


def _apply_gguf_embedding_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)


# Register the GGUF embedding lookup as a vLLM custom op, paired with its
# shape-only fake implementation, and expose the dispatcher handle.
# Errors propagate as-is; the previous
# ``except AttributeError: raise error`` wrapper was a no-op.
direct_register_custom_op(
    op_name="_apply_gguf_embedding",
    op_func=_apply_gguf_embedding,
    mutates_args=[],
    fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding


315
316
317
318
319
320
321
322
323
324
325
326
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Each layer gets a ``qweight`` parameter, which starts as a
    ``GGUFUninitializedParameter``, and a ``qweight_type`` parameter whose
    attributes record the GGML quantization type(s). Merged layers keep
    per-shard data (``data_container`` / ``shard_id`` / ``shard_id_map``).
    After loading, those shards are packed into one zero-padded tensor, and
    ``apply`` slices them back out.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register ``qweight`` and ``qweight_type`` on ``layer``.

        ``tensor_shape`` records the logical
        ``(output_size_per_partition, input_size_per_partition)`` shape.
        ``qweight_type`` holds one uint8 slot per output partition.
        ``params_dtype`` is remembered on this method object.
        """
        self.params_dtype = params_dtype
        output_size_per_partition = sum(output_partition_sizes)

        tensor_shape = (output_size_per_partition, input_size_per_partition)
        qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            qweight, {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
                "shard_id": [],
                "shard_id_map": {},
            })
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        qweight_type = Parameter(torch.empty(len(output_partition_sizes),
                                             dtype=torch.uint8),
                                 requires_grad=False)
        set_weight_attrs(
            qweight_type, {
                "is_gguf_weight_type": True,
                "weight_type": 0,
                "shard_weight_type": {},
                "ignore_warning": True
            })
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Validate the loaded quantization type and pack merged shards.

        Raises:
            ValueError: if ``layer.qweight_type.weight_type`` is neither an
                unquantized nor a dequantizable GGUF type, or if merged
                shards have mixed dtypes.
        """
        qweight_type = layer.qweight_type.weight_type
        if not (qweight_type in UNQUANTIZED_TYPES
                or qweight_type in DEQUANT_TYPES):
            qweight_type = WeightType(qweight_type)
            raise ValueError(
                f"Unsupported GGUF quantization type {qweight_type} in "
                f"layer {layer}.")
        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
        # materialize the padded weight parameter for CUDA Graph compatibility.
        self._create_padded_weight_param(layer)

    def _create_padded_weight_param(self, layer: torch.nn.Module):
        """Create padded weight parameter for GGUF MergedLinear layer.

        When ``qweight.data_container`` holds more than one tensor:
          * the tensors are stacked along dim 0 in container order;
          * each is zero-padded along dim 1 to the widest one;
          * the result replaces ``layer.qweight``.
        The new parameter keeps every attribute of the old ``qweight``, plus
        a ``shard_offset_map`` from each shard id to
        ``(dim0_start, dim0_end, dim1_size)``. Layers with at most one tensor
        in the container are left untouched.

        Raises:
            ValueError: if the container tensors do not share one dtype.
        """
        qweight = layer.qweight
        data_container = qweight.data_container
        if len(data_container) <= 1:
            return

        dtypes = {data.dtype for data in data_container}
        # Explicit raise rather than ``assert``: asserts vanish under -O.
        if len(dtypes) != 1:
            raise ValueError(f"Data container has mixed dtypes: {dtypes}")
        dtype = next(iter(dtypes))

        # row_starts[i] is where container entry i begins along dim 0;
        # row_starts[-1] is the total stacked height.
        row_starts = [0]
        for data in data_container:
            row_starts.append(row_starts[-1] + data.size(0))
        padded_side = max(data.size(1) for data in data_container)
        padded_data = torch.zeros((row_starts[-1], padded_side),
                                  dtype=dtype,
                                  device=qweight.device)

        # (dim0_start, dim0_end, dim1_size); shard ids may be str or int.
        shard_offset_map: dict[Any, tuple[int, int, int]] = {}
        for idx in qweight.shard_id:
            pos = qweight.shard_id_map[idx]
            shard = data_container[pos]
            start, end = row_starts[pos], row_starts[pos + 1]
            size = shard.size(1)
            padded_data[start:end, :size] = shard
            shard_offset_map[idx] = (start, end, size)
        qweight.data_container.clear()
        padded_param = Parameter(padded_data, requires_grad=False)
        set_weight_attrs(padded_param, vars(qweight))
        set_weight_attrs(padded_param,
                         {"shard_offset_map": shard_offset_map})
        layer.register_parameter("qweight", padded_param)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``x @ W.T`` (plus ``bias``) for ``layer``'s GGUF weight.

        For merged layers, each shard is multiplied separately using its own
        quantization type, and the outputs are concatenated along dim 1.
        QKV shards are always emitted in ``q, k, v`` order.
        """
        qweight = layer.qweight
        shard_id = qweight.shard_id

        if shard_id:
            # Force q, k, v output order regardless of the load order.
            order = ["q", "k", "v"] if "q" in shard_id else shard_id
            result = []
            for idx in order:
                start, end, offset = qweight.shard_offset_map[idx]
                qweight_type = layer.qweight_type.shard_weight_type[idx]
                # Drop the zero padding before handing the shard to a kernel.
                result.append(
                    fused_mul_mat_gguf(
                        x, qweight[start:end, :offset].contiguous(),
                        qweight_type))
            out = torch.cat(result, dim=1)
        else:
            out = fused_mul_mat_gguf(x, qweight,
                                     layer.qweight_type.weight_type)
        if bias is not None:
            out.add_(bias)
        return out


436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
class GGUFMoEMethod(FusedMoEMethodBase):
    """MoE method for GGUF.

    Registers quantized gate/up (``w13``) and down (``w2``) expert weights.
    ``apply`` routes through the GGUF fused-MoE custom op.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    @staticmethod
    def _register_gguf_weight(layer: torch.nn.Module, prefix: str,
                              tensor_shape: tuple[int, int, int],
                              extra_weight_attrs: dict[str, Any]) -> None:
        """Register ``{prefix}_qweight`` and ``{prefix}_qweight_type``."""
        qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            qweight, {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
            })
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter(f"{prefix}_qweight", qweight)

        qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
                                 requires_grad=False)
        set_weight_attrs(qweight_type, {
            "is_gguf_weight_type": True,
            "weight_type": 0,
            "ignore_warning": True
        })
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter(f"{prefix}_qweight_type", qweight_type)

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the ``w13`` (gate/up) and ``w2`` (down) GGUF weights.

        Each weight is a ``GGUFUninitializedParameter``. It is paired with a
        one-element uint8 ``*_qweight_type`` parameter, whose ``weight_type``
        attribute starts at 0.
        """
        # gate/up projection
        self._register_gguf_weight(
            layer, "w13",
            (num_experts, 2 * intermediate_size_per_partition, hidden_size),
            extra_weight_attrs)
        # down projection
        self._register_gguf_weight(
            layer, "w2",
            (num_experts, intermediate_size_per_partition, hidden_size),
            extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ):
        """Select experts for ``x`` and run the GGUF fused-MoE op.

        Raises:
            AssertionError: if ``activation`` is not ``"silu"``.
            NotImplementedError: if ``apply_router_weight_on_input`` is set.
        """
        assert activation == "silu", "Only SiLU activation is supported."
        if apply_router_weight_on_input:
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused GGUF MoE method.")

        # NOTE(review): global_num_experts and expert_map are accepted but
        # not used below, so no expert-parallel id remapping is applied.
        # Confirm that EP is unsupported for GGUF.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)
        return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
                              topk_weights, topk_ids,
                              layer.w13_qweight_type.weight_type,
                              layer.w2_qweight_type.weight_type, activation)
540
541


542
543
544
545
546
547
548
549
550
551
552
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Weight registration is inherited from ``GGUFLinearMethod``. Only the
    token lookup is added here.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module,
                  x: torch.Tensor) -> torch.Tensor:
        """Look up token ids ``x`` in ``layer``'s GGUF embedding table.

        ``tensor_shape[1]`` supplies the hidden size, and
        ``self.params_dtype`` is forwarded as the target ``dtype`` of the
        lookup.
        """
        weight = layer.qweight
        return apply_gguf_embedding(
            x,
            weight,
            layer.qweight_type.weight_type,
            weight.tensor_shape[1],
            dtype=self.params_dtype,
        )
560
561
562
563


class GGUFUninitializedParameter(UninitializedParameter):
    cls_to_become = Parameter
564
    data_container: list[torch.Tensor]