"vllm/vscode:/vscode.git/clone" did not exist on "fb0e0d46fc443f08bc2a859b839f0f66c6a7f670"
gguf.py 23.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from types import MappingProxyType
6
7
8
9
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
10
11
12

import gguf
import torch
13
from gguf import GGMLQuantizationType as WeightType
14
15
16
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
17
from vllm.logger import init_logger
18
19
20
21
from vllm.model_executor.layers.fused_moe.activation import (
    MoEActivation,
    apply_moe_activation,
)
22
23
24
25
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
26
27
28
29
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
)
30
31
32
33
34
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
35
from vllm.model_executor.layers.quantization import QuantizationMethods
36
from vllm.model_executor.layers.quantization.base_config import (
37
38
39
    QuantizationConfig,
    QuantizeMethodBase,
)
40
41
42
43
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
    UnquantizedEmbeddingMethod,
    VocabParallelEmbedding,
)
from vllm.model_executor.models.utils import WeightsMapper
45
from vllm.model_executor.utils import set_weight_attrs
46
from vllm.platforms import current_platform
47
from vllm.utils.torch_utils import direct_register_custom_op
48

49
50
logger = init_logger(__name__)

51
52
53
54

class GGUFConfig(QuantizationConfig):
    """Config class for GGUF.

    Args:
        unquantized_modules: Names of modules that must not use the GGUF
            quantized methods. ``None`` is stored as an empty list.
    """

    def __init__(self, unquantized_modules: list[str] | None = None) -> None:
        super().__init__()
        self.unquantized_modules = unquantized_modules or []

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "gguf"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes usable with GGUF on this platform."""
        # GGUF dequantization kernels use half precision (fp16) internally.
        # bfloat16 has precision issues on Blackwell devices.
        if current_platform.has_device_capability(100):
            logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.")
            return [torch.half, torch.float32]
        return [torch.half, torch.bfloat16, torch.float32]

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
        # The passed config is not consulted; a default instance is returned.
        return cls()

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg: dict[str, Any], user_quant: str | None, hf_config=None
    ) -> "QuantizationMethods | None":
        # When user explicitly specifies --quantization gguf, override
        # whatever quantization method is in the HF model config (e.g. fp8).
        if user_quant == "gguf":
            return "gguf"
        return None

    def _is_unquantized(self, prefix: str) -> bool:
        """Return True when the layer at ``prefix`` is listed as unquantized."""
        return is_layer_skipped_gguf(
            prefix, self.unquantized_modules, self.packed_modules_mapping
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the method object used to create and apply ``layer``'s weights.

        Linear and vocab-embedding layers get the unquantized method when
        their prefix is listed in ``unquantized_modules``; otherwise the GGUF
        one. Fused MoE layers always get the GGUF MoE method. Any other layer
        type yields ``None``.
        """
        if isinstance(layer, LinearBase):
            if self._is_unquantized(prefix):
                return UnquantizedLinearMethod()
            return GGUFLinearMethod(self)
        if isinstance(layer, VocabParallelEmbedding):
            if self._is_unquantized(prefix):
                return UnquantizedEmbeddingMethod()
            return GGUFEmbeddingMethod(self)
        if isinstance(layer, FusedMoE):
            # TODO: Select UnquantizedFusedMoEMethod on unquantized layers.
            return GGUFMoEMethod(self, layer.moe_config)
        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """
        Interface for models to update module names referenced in
        quantization configs in order to reflect the vllm model structure

        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
            structure of the qconfig) to vllm model structure
        """
        if self.unquantized_modules is not None:
            self.unquantized_modules = hf_to_vllm_mapper.apply_list(
                self.unquantized_modules
            )


def is_layer_skipped_gguf(
    prefix: str,
    unquantized_modules: list[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
    """Tell whether the layer named ``prefix`` should stay unquantized.

    Args:
        prefix: Dotted name of the layer being queried.
        unquantized_modules: Names of modules recorded as unquantized.
        fused_mapping: Maps a fused projection name (the last component of
            ``prefix``) to the names of the unfused shards it is built from.

    Returns:
        True when the layer (or, for a fused layer, every one of its shards)
        matches an entry of ``unquantized_modules``.

    Raises:
        ValueError: If only some shards of a fused layer are unquantized.
    """
    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    proj_name = prefix.split(".")[-1]
    if proj_name not in fused_mapping:
        return any(module_name in prefix for module_name in unquantized_modules)

    # Swap only the final path component. A plain str.replace would also
    # rewrite earlier components that happen to contain the same text.
    parent = prefix[: len(prefix) - len(proj_name)]

    is_skipped = None
    for shard_proj_name in fused_mapping[proj_name]:
        shard_prefix = parent + shard_proj_name
        is_shard_skipped = any(
            shard_prefix in module_name for module_name in unquantized_modules
        )

        if is_skipped is None:
            is_skipped = is_shard_skipped
        elif is_shard_skipped != is_skipped:
            raise ValueError(
                f"Detected some but not all shards of {prefix} "
                "are quantized. All shards of fused layers "
                "must have the same precision."
            )

    # Only reachable with None when the fused mapping lists no shards.
    assert is_skipped is not None
    return is_skipped
164
165


166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# Weight types stored as plain floating-point tensors (no dequantization).
UNQUANTIZED_TYPES = {
    WeightType.F32,
    WeightType.F16,
    WeightType.BF16,
}

# Standard GGUF quantization formats.
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0, WeightType.Q4_1,
    WeightType.Q5_0, WeightType.Q5_1,
    WeightType.Q8_0, WeightType.Q8_1,
}

# K-quant formats.
KQUANT_TYPES = {
    WeightType.Q2_K, WeightType.Q3_K, WeightType.Q4_K,
    WeightType.Q5_K, WeightType.Q6_K,
}

# I-Matrix quantization formats.
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M, WeightType.IQ1_S,
    WeightType.IQ2_XXS, WeightType.IQ2_XS, WeightType.IQ2_S,
    WeightType.IQ3_XXS, WeightType.IQ3_S,
    WeightType.IQ4_XS, WeightType.IQ4_NL,
}

# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = set().union(STANDARD_QUANT_TYPES, KQUANT_TYPES, IMATRIX_QUANT_TYPES)
MMVQ_QUANT_TYPES = set().union(
    STANDARD_QUANT_TYPES, KQUANT_TYPES, IMATRIX_QUANT_TYPES
)
MMQ_QUANT_TYPES = set().union(STANDARD_QUANT_TYPES, KQUANT_TYPES)


201
202
203
def _fused_mul_mat_gguf(
    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
) -> torch.Tensor:
204
205
206
207
    if qweight_type in IMATRIX_QUANT_TYPES:
        mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
    else:
        mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
208
209
210
    # HACK: when doing chunked prefill we don't generate output tokens
    # so input to logits generator is empty which causes invalid parameter
    if x.shape[0] == 0:
211
        return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
212
213
214
215
    # there is no need to call any kernel for fp16/bf16
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T
    # enable MMVQ in contiguous batching with batch_size=1
216
    if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
217
        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
218
219
220
221
222
    # Use MMQ Kernel if it's available (standard + k-quants)
    elif qweight_type in MMQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
    # If there is no available MMQ kernel, fallback to dequantize
    elif qweight_type in DEQUANT_TYPES:
223
224
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
225
        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
226
227
        y = x @ weight.T
    else:
228
229
230
231
        # Raise an error if the quantization type is not supported.
        # Might be useful if llama.cpp adds a new quantization type.
        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
        qweight_type = WeightType(qweight_type)
232
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
233
234
235
    return y


236
237
238
239
240
def _fused_mul_mat_gguf_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
) -> torch.Tensor:
241
    return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
242
243
244
245
246
247
248
249
250
251
252
253
254
255


# Register the real and fake implementations as a custom op, then expose the
# registered callable under the public name `fused_mul_mat_gguf`.
# Any AttributeError raised here propagates to the importer unchanged; the
# former `except AttributeError: raise` wrapper added nothing.
direct_register_custom_op(
    op_name="_fused_mul_mat_gguf",
    op_func=_fused_mul_mat_gguf,
    fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf


256
257
258
259
260
261
262
263
def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    """Run a mixture-of-experts MLP on GGUF expert weights.

    One of three paths is taken:

    * more than 64 input rows and both weight types in ``MMQ_QUANT_TYPES``:
      the block-aligned ``ggml_moe_a8`` kernel;
    * otherwise, both weight types in ``MMVQ_QUANT_TYPES``: the
      ``ggml_moe_a8_vec`` kernel;
    * otherwise: a per-token, per-expert Python loop built on
      ``fused_mul_mat_gguf``.

    Args:
        x: Input hidden states; the result has the same shape.
        w1: Stacked first-stage expert weights, indexed by expert id.
        w2: Stacked second-stage expert weights, indexed by expert id.
        topk_weights: Routing weight of each selected expert, per token.
        topk_ids: Id of each selected expert, per token.
        qweight_type: GGML quantization type of ``w1``.
        qweight_type2: GGML quantization type of ``w2``.
        activation: Activation name, parsed with ``MoEActivation.from_str``.

    Returns:
        A tensor shaped like ``x`` holding the weighted sum of expert outputs.
    """
    activation_enum = MoEActivation.from_str(activation)

    def act(x: torch.Tensor):
        # The output buffer has half the last dimension of the input.
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        apply_moe_activation(activation_enum, out, x)
        return out

    # lazy import to avoid triggering triton import in CPU backend
    from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size

    out_hidden_states = torch.empty_like(x)
    # unless we have decent expert reuse we are better off running the
    # moe_vec kernel
    if (
        qweight_type2 in MMQ_QUANT_TYPES
        and qweight_type in MMQ_QUANT_TYPES
        and x.shape[0] > 64
    ):
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]
        BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, BLOCK_SIZE, E
        )
        out = ops.ggml_moe_a8(
            x,
            w1,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type,
            N,
            top_k,
            num_tokens,
        )
        out = act(out)
        # Second stage: called with top_k=1 over num_tokens * top_k rows.
        out = ops.ggml_moe_a8(
            out,
            w2,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type2,
            w2.shape[1],
            1,
            num_tokens * top_k,
        )
        # Scale each expert output by its routing weight, then reduce.
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1)
        )
        ops.moe_sum(out, out_hidden_states)
    elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]

        out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens)
        out = act(out)

        out = ops.ggml_moe_a8_vec(
            out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
        )
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1)
        )
        ops.moe_sum(out, out_hidden_states)
    else:
        logger.warning_once(
            "There is no support for fast MoE kernel "
            "for current quantization method. "
            "Falling back to slow implementation. "
        )
        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
            # Keep a leading dimension of size 1 for the matmul helper.
            inp = x[tok].reshape((1,) + x.shape[1:])
            current_hidden_state = None
            for ww, ii in zip(w, idx):
                expert_up = w1[ii]

                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
                out = act(out)

                expert_down = w2[ii]
                current_state = fused_mul_mat_gguf(
                    out, expert_down, qweight_type2
                ).mul_(ww)
                # Accumulate the weighted expert outputs for this token.
                if current_hidden_state is None:
                    current_hidden_state = current_state
                else:
                    current_hidden_state.add_(current_state)
            out_hidden_states[tok] = current_hidden_state
    return out_hidden_states


362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
def _fused_moe_gguf_fake(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    return torch.empty_like(x)


# Register the real and fake implementations as a custom op, then expose the
# registered callable under the public name `fused_moe_gguf`.
# Any AttributeError raised here propagates to the importer unchanged; the
# former `except AttributeError: raise` wrapper added nothing.
direct_register_custom_op(
    op_name="_fused_moe_gguf",
    op_func=_fused_moe_gguf,
    fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf


def _apply_gguf_embedding(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Look up embedding rows from a GGUF weight table.

    Args:
        x: Indices of the rows to fetch.
        qweight: Embedding table, possibly quantized.
        qweight_type: GGML quantization type of ``qweight``.
        hidden_size: Width of one dequantized embedding row.
        dtype: Dtype passed to the dequantization kernel. It is not used
            when the table is unquantized.

    Returns:
        A tensor of shape ``(*x.shape, hidden_size)`` on the dequantize path;
        on the unquantized path, whatever ``torch.embedding`` returns.

    Raises:
        ValueError: If ``qweight_type`` is not a valid GGML quantization type.
        NotImplementedError: If the type is valid but cannot be dequantized.
    """
    if qweight_type in UNQUANTIZED_TYPES:
        return torch.embedding(qweight, x)

    if qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        x_flat = x.flatten()
        assert hidden_size == qweight.shape[1] // type_size * block_size
        # Gather only the requested quantized rows before dequantizing.
        quant = torch.index_select(qweight, dim=0, index=x_flat)
        dequant = ops.ggml_dequantize(
            quant, qweight_type, hidden_size, x_flat.shape[0], dtype
        )
        return dequant.view(*x.shape, hidden_size)

    # Wrap in the IntEnum first so an invalid value fails with ValueError.
    qweight_type = WeightType(qweight_type)
    raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
408
409
410
411
412
413
414


def _apply_gguf_embedding_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
415
    dtype: torch.dtype | None = None,
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
) -> torch.Tensor:
    return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)


# Register the real and fake implementations as a custom op, then expose the
# registered callable under the public name `apply_gguf_embedding`.
# Any AttributeError raised here propagates to the importer unchanged; the
# former `except AttributeError: raise` wrapper added nothing.
direct_register_custom_op(
    op_name="_apply_gguf_embedding",
    op_func=_apply_gguf_embedding,
    fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding


432
433
434
435
436
437
438
439
440
441
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the ``qweight`` and ``qweight_type`` parameters on ``layer``.

        ``qweight`` starts as an uninitialized parameter carrying loader
        bookkeeping attributes. ``qweight_type`` holds one ``uint8`` slot per
        output partition. ``extra_weight_attrs`` are attached to both.
        """
        self.params_dtype = params_dtype
        output_size_per_partition = sum(output_partition_sizes)
        tensor_shape = (output_size_per_partition, input_size_per_partition)

        qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
                "shard_id": [],
                "shard_id_map": {},
            },
        )
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        qweight_type = Parameter(
            torch.empty(len(output_partition_sizes), dtype=torch.uint8),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight_type,
            {
                "is_gguf_weight_type": True,
                "weight_type": 0,
                "shard_weight_type": {},
                "ignore_warning": True,
            },
        )
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Validate the loaded weight type and build the padded weight.

        Raises:
            ValueError: If the layer's weight type is neither unquantized nor
                one of the dequantizable types.
        """
        qweight_type = layer.qweight_type.weight_type
        if qweight_type not in UNQUANTIZED_TYPES and qweight_type not in DEQUANT_TYPES:
            qweight_type = WeightType(qweight_type)
            raise ValueError(
                f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
            )
        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
        # materialize the padded weight parameter for CUDA Graph compatibility.
        self._create_padded_weight_param(layer)

    def _create_padded_weight_param(self, layer: torch.nn.Module):
        """Create padded weight parameter for GGUF MergedLinear layer.

        Does nothing unless more than one shard tensor was collected in
        ``qweight.data_container``.

        Raises:
            ValueError: If the collected shards do not share a single dtype.
        """
        qweight = layer.qweight
        shard_id_map = qweight.shard_id_map
        shard_id = qweight.shard_id
        data_container = qweight.data_container
        if len(data_container) <= 1:
            return

        dtypes = {data.dtype for data in data_container}
        # Raise explicitly: an `assert` would vanish under `python -O`.
        if len(dtypes) != 1:
            raise ValueError(f"Data container has mixed dtypes: {dtypes}")
        dtype = next(iter(dtypes))

        # concat dim0 and pad dim1
        padded_side = max(x.size(1) for x in data_container)
        concat_side = sum(x.size(0) for x in data_container)
        # Pad the quantized weights to dense tensor, and create a map
        # with the location of each shard in the padded tensor.
        padded_data = torch.zeros(
            (concat_side, padded_side), dtype=dtype, device=qweight.device
        )
        # (dim0_start, dim0_end, dim1_size)
        shard_offset_map = dict[str, tuple[int, int, int]]()
        for idx in shard_id:
            id_in_container = shard_id_map[idx]
            shard = data_container[id_in_container]
            start = sum(x.size(0) for x in data_container[:id_in_container])
            end = start + shard.size(0)
            size = shard.size(1)
            padded_data[start:end, :size] = shard
            shard_offset_map[idx] = (start, end, size)

        # The shard tensors are now copied; drop them from the container.
        qweight.data_container.clear()
        padded_param = Parameter(padded_data, requires_grad=False)
        # Carry every attribute of the old parameter over to the new one.
        set_weight_attrs(padded_param, vars(qweight))
        set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
        layer.register_parameter("qweight", padded_param)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the layer output for ``x``; ``bias`` is added in place.

        When the weight has shard ids, each shard is multiplied separately
        with its own weight type and the results are concatenated along
        dim 1.
        """
        shard_id = layer.qweight.shard_id

        if shard_id:
            # dequantize shard weights respectively
            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
            qweight = layer.qweight
            result = []
            for idx in shard_id:
                start, end, offset = layer.qweight.shard_offset_map[idx]
                qweight_type = layer.qweight_type.shard_weight_type[idx]
                # Slice off the padding before handing the shard to the op.
                result.append(
                    fused_mul_mat_gguf(
                        x, qweight[start:end, :offset].contiguous(), qweight_type
                    )
                )
            out = torch.cat(result, dim=1)
        else:
            qweight = layer.qweight
            qweight_type = layer.qweight_type.weight_type
            out = fused_mul_mat_gguf(x, qweight, qweight_type)

        if bias is not None:
            out.add_(bias)
        return out


564
565
566
567
568
569
570
class GGUFMoEMethod(FusedMoEMethodBase):
    """MoE method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
        moe: The fused-MoE config forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: GGUFConfig,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weight and weight-type parameters on ``layer``.

        Registered in order: ``w13_qweight``, ``w13_qweight_type``,
        ``w2_qweight``, ``w2_qweight_type``. ``extra_weight_attrs`` are
        attached to each of them. ``params_dtype`` is not used here.
        """
        weight_shapes = {
            # gate up proj
            "w13": (num_experts, 2 * intermediate_size_per_partition, hidden_size),
            # down proj
            "w2": (num_experts, intermediate_size_per_partition, hidden_size),
        }
        for name, tensor_shape in weight_shapes.items():
            qweight = GGUFUninitializedParameter(requires_grad=False)
            set_weight_attrs(
                qweight,
                {
                    "input_dim": 1,
                    "output_dim": 0,
                    "tensor_shape": tensor_shape,
                    "is_gguf_weight": True,
                    "data_container": [],
                },
            )
            set_weight_attrs(qweight, extra_weight_attrs)
            layer.register_parameter(f"{name}_qweight", qweight)

            qweight_type = Parameter(
                torch.empty(1, dtype=torch.uint8), requires_grad=False
            )
            set_weight_attrs(
                qweight_type,
                {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
            )
            set_weight_attrs(qweight_type, extra_weight_attrs)
            layer.register_parameter(f"{name}_qweight_type", qweight_type)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        # This method supplies no fused-MoE quant config.
        return None

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the GGUF MoE op with the layer's expert weights.

        ``shared_experts_input`` is accepted but not used.

        Raises:
            NotImplementedError: If the layer applies the router weight on
                the input.
        """
        if layer.apply_router_weight_on_input:
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused GGUF MoE method."
            )

        return fused_moe_gguf(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            topk_weights,
            topk_ids,
            layer.w13_qweight_type.weight_type,
            layer.w2_qweight_type.weight_type,
            layer.activation.value,
        )
670
671


672
673
674
675
676
677
678
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Look up rows of the layer's GGUF weight for the indices in ``x``.

        Reads ``self.params_dtype``, which ``create_weights`` sets, so the
        weights must have been created through this method object first.
        """
        qweight = layer.qweight
        qweight_type = layer.qweight_type.weight_type
        # tensor_shape is stored as (output size, input size) at creation;
        # the second entry is passed on as the embedding width.
        hidden_size = qweight.tensor_shape[1]

        return apply_gguf_embedding(
            x, qweight, qweight_type, hidden_size, dtype=self.params_dtype
        )
687
688
689
690


class GGUFUninitializedParameter(UninitializedParameter):
    """Uninitialized parameter used for GGUF weights before loading."""

    # Class the parameter is converted into once it is materialized.
    cls_to_become = Parameter
    # Tensors collected for this weight; annotation only, the list itself is
    # attached per instance when the weight is created.
    data_container: list[torch.Tensor]