gguf.py 22.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from types import MappingProxyType
6
from typing import Any, Optional
7
8
9

import gguf
import torch
10
from gguf import GGMLQuantizationType as WeightType
11
12
13
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
14
from vllm.logger import init_logger
15
16
17
18
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
19
20
21
22
23
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
)
24
25
26
27
28
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
29
from vllm.model_executor.layers.quantization import QuantizationMethods
30
from vllm.model_executor.layers.quantization.base_config import (
31
32
33
    QuantizationConfig,
    QuantizeMethodBase,
)
34
35
36
37
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
    UnquantizedEmbeddingMethod,
    VocabParallelEmbedding,
)
from vllm.model_executor.models.utils import WeightsMapper
39
from vllm.model_executor.utils import set_weight_attrs
40
from vllm.platforms import current_platform
41
from vllm.utils.torch_utils import direct_register_custom_op
42

43
44
logger = init_logger(__name__)

45
46
47
48

class GGUFConfig(QuantizationConfig):
    """Config class for GGUF."""

49
    def __init__(self, unquantized_modules: list[str] | None = None) -> None:
50
        super().__init__()
51
        self.unquantized_modules = unquantized_modules or []
52
53

    def __repr__(self) -> str:
54
        return "GGUFConfig()"
55

56
    def get_name(self) -> QuantizationMethods:
57
58
        return "gguf"

59
    def get_supported_act_dtypes(self) -> list[torch.dtype]:
60
61
62
63
64
        # GGUF dequantization kernels use half precision (fp16) internally.
        # bfloat16 has precision issues on Blackwell devices.
        if current_platform.has_device_capability(100):
            logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.")
            return [torch.half, torch.float32]
65
        return [torch.half, torch.bfloat16, torch.float32]
66
67
68
69
70
71

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
72
    def get_config_filenames(cls) -> list[str]:
73
74
75
        return []  # no extra configs.

    @classmethod
76
    def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
77
78
        return cls()

79
80
81
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
82
        if isinstance(layer, LinearBase):
83
84
85
            if is_layer_skipped_gguf(
                prefix, self.unquantized_modules, self.packed_modules_mapping
            ):
86
                return UnquantizedLinearMethod()
87
88
            return GGUFLinearMethod(self)
        elif isinstance(layer, VocabParallelEmbedding):
89
90
91
92
            if is_layer_skipped_gguf(
                prefix, self.unquantized_modules, self.packed_modules_mapping
            ):
                return UnquantizedEmbeddingMethod()
93
            return GGUFEmbeddingMethod(self)
94
        elif isinstance(layer, FusedMoE):
95
            # TODO: Select UnquantizedFusedMoEMethod on unquantized layers.
96
            return GGUFMoEMethod(self, layer.moe_config)
97
98
        return None

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """
        Interface for models to update module names referenced in
        quantization configs in order to reflect the vllm model structure

        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
            structure of the qconfig) to vllm model structure
        """
        if self.unquantized_modules is not None:
            self.unquantized_modules = hf_to_vllm_mapper.apply_list(
                self.unquantized_modules
            )


def is_layer_skipped_gguf(
    prefix: str,
    unquantized_modules: list[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
):
    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    proj_name = prefix.split(".")[-1]
    if proj_name in fused_mapping:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in fused_mapping[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = any(
                shard_prefix in module_name for module_name in unquantized_modules
            )

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "to have the same precision."
                )
    else:
        is_skipped = any(module_name in prefix for module_name in unquantized_modules)

    assert is_skipped is not None
    return is_skipped
148
149


150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0,
    WeightType.Q4_1,
    WeightType.Q5_0,
    WeightType.Q5_1,
    WeightType.Q8_0,
    WeightType.Q8_1,
}
KQUANT_TYPES = {
    WeightType.Q2_K,
    WeightType.Q3_K,
    WeightType.Q4_K,
    WeightType.Q5_K,
    WeightType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M,
    WeightType.IQ1_S,
    WeightType.IQ2_XXS,
    WeightType.IQ2_XS,
    WeightType.IQ2_S,
    WeightType.IQ3_XXS,
    WeightType.IQ3_S,
    WeightType.IQ4_XS,
    WeightType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


185
186
187
def _fused_mul_mat_gguf(
    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
) -> torch.Tensor:
188
189
190
191
    if qweight_type in IMATRIX_QUANT_TYPES:
        mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
    else:
        mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
192
193
194
    # HACK: when doing chunked prefill we don't generate output tokens
    # so input to logits generator is empty which causes invalid parameter
    if x.shape[0] == 0:
195
        return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
196
197
198
199
    # there is no need to call any kernel for fp16/bf16
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T
    # enable MMVQ in contiguous batching with batch_size=1
200
    if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
201
        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
202
203
204
205
206
    # Use MMQ Kernel if it's available (standard + k-quants)
    elif qweight_type in MMQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
    # If there is no available MMQ kernel, fallback to dequantize
    elif qweight_type in DEQUANT_TYPES:
207
208
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
209
        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
210
211
        y = x @ weight.T
    else:
212
213
214
215
        # Raise an error if the quantization type is not supported.
        # Might be useful if llama.cpp adds a new quantization type.
        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
        qweight_type = WeightType(qweight_type)
216
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
217
218
219
    return y


220
221
222
223
224
def _fused_mul_mat_gguf_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
) -> torch.Tensor:
225
    return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
226
227
228
229
230
231
232
233
234
235
236
237
238
239


try:
    direct_register_custom_op(
        op_name="_fused_mul_mat_gguf",
        op_func=_fused_mul_mat_gguf,
        fake_impl=_fused_mul_mat_gguf_fake,
    )
    fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf

except AttributeError as error:
    raise error


240
241
242
243
244
245
246
247
def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
248
    activation: str,
249
) -> torch.Tensor:
250
251
    def act(x: torch.Tensor):
        d = x.shape[-1] // 2
252
        output_shape = x.shape[:-1] + (d,)
253
254
255
256
257
258
259
260
261
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if activation == "silu":
            torch.ops._C.silu_and_mul(out, x)
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(out, x)
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        return out

262
    # lazy import to avoid triggering triton import in CPU backend
263
    from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
264

265
    out_hidden_states = torch.empty_like(x)
266
    # unless we decent expert reuse we are better off running moe_vec kernel
267
268
269
270
271
    if (
        qweight_type2 in MMQ_QUANT_TYPES
        and qweight_type in MMQ_QUANT_TYPES
        and x.shape[0] > 64
    ):
272
273
274
275
276
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]
        BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)

277
278
279
280
281
282
283
284
285
286
287
288
289
290
        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, BLOCK_SIZE, E
        )
        out = ops.ggml_moe_a8(
            x,
            w1,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type,
            N,
            top_k,
            num_tokens,
        )
291
        out = act(out)
292
293
294
295
296
297
298
299
300
301
302
        out = ops.ggml_moe_a8(
            out,
            w2,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type2,
            w2.shape[1],
            1,
            num_tokens * top_k,
        )
303
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
304
305
            topk_weights.view(num_tokens, top_k, 1)
        )
306
        ops.moe_sum(out, out_hidden_states)
307
308
309
310
311
    elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]

312
        out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens)
313
314
        out = act(out)

315
316
317
        out = ops.ggml_moe_a8_vec(
            out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
        )
318
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
319
320
            topk_weights.view(num_tokens, top_k, 1)
        )
321
        ops.moe_sum(out, out_hidden_states)
322
    else:
323
324
325
326
327
        logger.warning_once(
            "There is no support for fast MoE kernel "
            "for current quantization method. "
            "Falling back to slow implementation. "
        )
328
        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
329
            inp = x[tok].reshape((1,) + x.shape[1:])
330
331
332
333
            current_hidden_state = None
            for ww, ii in zip(w, idx):
                expert_up = w1[ii]

334
                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
335
336
337
                out = act(out)

                expert_down = w2[ii]
338
339
340
                current_state = fused_mul_mat_gguf(
                    out, expert_down, qweight_type2
                ).mul_(ww)
341
342
343
344
345
346
347
348
                if current_hidden_state is None:
                    current_hidden_state = current_state
                else:
                    current_hidden_state.add_(current_state)
            out_hidden_states[tok] = current_hidden_state
    return out_hidden_states


349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
def _fused_moe_gguf_fake(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    return torch.empty_like(x)


try:
    direct_register_custom_op(
        op_name="_fused_moe_gguf",
        op_func=_fused_moe_gguf,
        fake_impl=_fused_moe_gguf_fake,
    )
    fused_moe_gguf = torch.ops.vllm._fused_moe_gguf

except AttributeError as error:
    raise error


def _apply_gguf_embedding(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
379
    dtype: torch.dtype | None = None,
380
381
382
383
384
385
) -> torch.Tensor:
    if qweight_type in UNQUANTIZED_TYPES:
        return torch.embedding(qweight, x)
    elif qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        x_flat = x.flatten()
386
        assert hidden_size == qweight.shape[1] // type_size * block_size
387
        quant = torch.index_select(qweight, dim=0, index=x_flat)
388
389
390
        dequant = ops.ggml_dequantize(
            quant, qweight_type, hidden_size, x_flat.shape[0], dtype
        )
391
392
393
        return dequant.view(*x.shape, hidden_size)
    else:
        qweight_type = WeightType(qweight_type)
394
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
395
396
397
398
399
400
401


def _apply_gguf_embedding_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
402
    dtype: torch.dtype | None = None,
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
) -> torch.Tensor:
    return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)


try:
    direct_register_custom_op(
        op_name="_apply_gguf_embedding",
        op_func=_apply_gguf_embedding,
        fake_impl=_apply_gguf_embedding_fake,
    )
    apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding

except AttributeError as error:
    raise error


419
420
421
422
423
424
425
426
427
428
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

429
430
431
432
433
434
435
436
437
438
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
439
        self.params_dtype = params_dtype
440
441
442
        output_size_per_partition = sum(output_partition_sizes)

        tensor_shape = (output_size_per_partition, input_size_per_partition)
443
        qweight = GGUFUninitializedParameter(requires_grad=False)
444
        set_weight_attrs(
445
446
            qweight,
            {
447
448
449
450
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
451
                "data_container": [],
452
                "shard_id": [],
453
                "shard_id_map": {},
454
455
            },
        )
456
457
458
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

459
460
461
462
        qweight_type = Parameter(
            torch.empty(len(output_partition_sizes), dtype=torch.uint8),
            requires_grad=False,
        )
463
        set_weight_attrs(
464
465
            qweight_type,
            {
466
467
468
                "is_gguf_weight_type": True,
                "weight_type": 0,
                "shard_weight_type": {},
469
470
471
                "ignore_warning": True,
            },
        )
472
473
474
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

475
476
    def process_weights_after_loading(self, layer: torch.nn.Module):
        qweight_type = layer.qweight_type.weight_type
477
        if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES):
478
479
            qweight_type = WeightType(qweight_type)
            raise ValueError(
480
481
                f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
            )
482
483
484
485
486
487
488
489
490
491
492
493
        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
        # materialize the padded weight parameter for CUDA Graph compatibility.
        self._create_padded_weight_param(layer)

    def _create_padded_weight_param(self, layer: torch.nn.Module):
        """Create padded weight parameter for GGUF MergedLinear layer."""
        qweight = layer.qweight
        shard_id_map = qweight.shard_id_map
        shard_id = qweight.shard_id
        if len(data_container := qweight.data_container) > 1:
            dtype = {data.dtype for data in data_container}
            assert len(dtype) == 1, ValueError(
494
495
                f"Data container has mixed dtypes: {dtype}"
            )
496
497
498
499
500
501
            dtype = next(iter(dtype))
            # concat dim0 and pad dim1
            padded_side = max(x.size(1) for x in data_container)
            concat_side = sum(x.size(0) for x in data_container)
            # Pad the quantized weights to dense tensor, and create a map
            # with the location of each shard in the padded tensor.
502
503
504
            padded_data = torch.zeros(
                (concat_side, padded_side), dtype=dtype, device=qweight.device
            )
505
506
507
508
            # (dim0_start, dim0_end, dim1_size)
            shard_offset_map = dict[str, tuple[int, int, int]]()
            for idx in shard_id:
                id_in_container = shard_id_map[idx]
509
                start = sum(x.size(0) for x in data_container[:id_in_container])
510
511
512
513
514
515
516
                end = start + data_container[id_in_container].size(0)
                size = data_container[id_in_container].size(1)
                padded_data[start:end, :size] = data_container[id_in_container]
                shard_offset_map[idx] = (start, end, size)
            qweight.data_container.clear()
            padded_param = Parameter(padded_data, requires_grad=False)
            set_weight_attrs(padded_param, vars(qweight))
517
            set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
518
519
            layer.register_parameter("qweight", padded_param)

520
521
522
523
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
524
        bias: torch.Tensor | None = None,
525
    ) -> torch.Tensor:
526
        shard_id = layer.qweight.shard_id
527

528
        if shard_id:
529
530
            # dequantize shard weights respectively
            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
531
            qweight = layer.qweight
532
            result = []
533
            for idx in shard_id:
534
                start, end, offset = layer.qweight.shard_offset_map[idx]
535
                qweight_type = layer.qweight_type.shard_weight_type[idx]
536
537
                result.append(
                    fused_mul_mat_gguf(
538
539
540
                        x, qweight[start:end, :offset].contiguous(), qweight_type
                    )
                )
541
542
543
544
            out = torch.cat(result, axis=1)
        else:
            qweight = layer.qweight
            qweight_type = layer.qweight_type.weight_type
545
            out = fused_mul_mat_gguf(x, qweight, qweight_type)
546
547
548
549
550
        if bias is not None:
            out.add_(bias)
        return out


551
552
553
554
555
556
557
class GGUFMoEMethod(FusedMoEMethodBase):
    """MoE method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

558
559
560
561
562
563
    def __init__(
        self,
        quant_config: GGUFConfig,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
564
565
        self.quant_config = quant_config

566
567
568
569
570
571
572
573
574
575
576
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
        # gate up proj
577
578
        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
579
580
            w13_qweight,
            {
581
582
583
584
585
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
586
587
            },
        )
588
589
590
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        layer.register_parameter("w13_qweight", w13_qweight)

591
592
593
594
595
596
597
        w13_qweight_type = Parameter(
            torch.empty(1, dtype=torch.uint8), requires_grad=False
        )
        set_weight_attrs(
            w13_qweight_type,
            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
        )
598
599
600
        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
        layer.register_parameter("w13_qweight_type", w13_qweight_type)

601
602
        tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size)
        # gate down proj
603
604
        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
605
606
            w2_qweight,
            {
607
608
609
610
611
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
612
613
            },
        )
614
615
616
        set_weight_attrs(w2_qweight, extra_weight_attrs)
        layer.register_parameter("w2_qweight", w2_qweight)

617
618
619
620
621
622
623
        w2_qweight_type = Parameter(
            torch.empty(1, dtype=torch.uint8), requires_grad=False
        )
        set_weight_attrs(
            w2_qweight_type,
            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
        )
624
625
626
627

        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
        layer.register_parameter("w2_qweight_type", w2_qweight_type)

628
    def get_fused_moe_quant_config(
629
        self, layer: torch.nn.Module
630
    ) -> FusedMoEQuantConfig | None:
631
632
        return None

633
634
    def apply(
        self,
635
        layer: FusedMoE,
636
        router: FusedMoERouter,
637
638
        x: torch.Tensor,
        router_logits: torch.Tensor,
639
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
640
641
        assert layer.activation == "silu", "Only SiLU activation is supported."
        if layer.apply_router_weight_on_input:
642
643
            raise NotImplementedError(
                "Apply router weight on input is not supported for"
644
645
                "fused GGUF MoE method."
            )
646

647
        topk_weights, topk_ids = router.select_experts(
648
649
            hidden_states=x,
            router_logits=router_logits,
650
651
652
653
654
655
656
657
658
        )
        return fused_moe_gguf(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            topk_weights,
            topk_ids,
            layer.w13_qweight_type.weight_type,
            layer.w2_qweight_type.weight_type,
659
            layer.activation,
660
        )
661
662


663
664
665
666
667
668
669
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

670
    def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
671
672
        qweight = layer.qweight
        qweight_type = layer.qweight_type.weight_type
673
        hidden_size = qweight.tensor_shape[1]
674

675
676
677
        return apply_gguf_embedding(
            x, qweight, qweight_type, hidden_size, dtype=self.params_dtype
        )
678
679
680
681


class GGUFUninitializedParameter(UninitializedParameter):
    cls_to_become = Parameter
682
    data_container: list[torch.Tensor]