gguf.py 21.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Callable
from typing import Any, Optional
6
7
8

import gguf
import torch
9
from gguf import GGMLQuantizationType as WeightType
10
11
12
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
13
from vllm.logger import init_logger
14
15
16
17
18
19
20
21
22
23
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
24
from vllm.model_executor.layers.quantization import QuantizationMethods
25
from vllm.model_executor.layers.quantization.base_config import (
26
27
28
29
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
30
from vllm.model_executor.utils import set_weight_attrs
31
from vllm.utils.torch_utils import direct_register_custom_op
32

33
34
# Module-level logger, created via init_logger with this module's name.
logger = init_logger(__name__)

35
36
37
38

class GGUFConfig(QuantizationConfig):
    """Config class for GGUF.

    Args:
        unquantized_modules: Module-name fragments; a linear layer whose
            prefix contains any of them is handled by
            ``UnquantizedLinearMethod`` instead of the GGUF linear method.
            ``None`` (the default) is stored as an empty list.
    """

    def __init__(self, unquantized_modules: list[str] | None = None) -> None:
        super().__init__()
        self.unquantized_modules = unquantized_modules or []

    def __repr__(self) -> str:
        # The representation is a fixed string; it does not list
        # ``unquantized_modules``.
        return "GGUFConfig()"

    def get_name(self) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "gguf"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted with GGUF weights."""
        return [torch.half, torch.bfloat16, torch.float32]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): 60 presumably encodes compute capability 6.0 -- confirm
        against the base-class contract.
        """
        return 60

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config files to look for (none for GGUF)."""
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
        """Build a default config; the contents of ``config`` are not read."""
        return cls()

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Select the quantization method object for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: The module's name prefix, matched against
                ``unquantized_modules`` for linear layers.

        Returns:
            A method instance for linear, vocab-embedding and fused-MoE
            layers, or ``None`` for any other layer type.
        """
        if isinstance(layer, LinearBase):
            # Linear layers explicitly listed as unquantized bypass GGUF.
            if is_layer_skipped_gguf(prefix, self.unquantized_modules):
                return UnquantizedLinearMethod()
            return GGUFLinearMethod(self)
        elif isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        elif isinstance(layer, FusedMoE):
            return GGUFMoEMethod(self, layer.moe_config)
        return None


78
79
80
81
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
    """Tell whether a layer should be left unquantized.

    Returns True as soon as one entry of ``unquantized_modules`` occurs as a
    substring of ``prefix``; returns False when none does (including when the
    list is empty).
    """
    for candidate in unquantized_modules:
        if candidate in prefix:
            return True
    return False


82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Weight types that are plain floating point; the matmul and embedding paths
# below use them directly without calling any GGML kernel.
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
# The "standard" Q4/Q5/Q8 quantization types.
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0,
    WeightType.Q4_1,
    WeightType.Q5_0,
    WeightType.Q5_1,
    WeightType.Q8_0,
    WeightType.Q8_1,
}
# The K-quant types (Q2_K .. Q6_K).
KQUANT_TYPES = {
    WeightType.Q2_K,
    WeightType.Q3_K,
    WeightType.Q4_K,
    WeightType.Q5_K,
    WeightType.Q6_K,
}
# The I-Matrix (IQ*) quantization types.
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M,
    WeightType.IQ1_S,
    WeightType.IQ2_XXS,
    WeightType.IQ2_XS,
    WeightType.IQ2_S,
    WeightType.IQ3_XXS,
    WeightType.IQ3_S,
    WeightType.IQ4_XS,
    WeightType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
# DEQUANT_TYPES: types that can go through the dequantize-then-matmul fallback.
# MMVQ_QUANT_TYPES: types eligible for the MMVQ (small-batch) kernel.
# MMQ_QUANT_TYPES: types eligible for the MMQ kernel (no I-Matrix types).
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


117
118
119
def _fused_mul_mat_gguf(
    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
) -> torch.Tensor:
    """Multiply activations by a GGUF weight, picking a kernel by type/batch.

    Args:
        x: Activations; the first dimension is the number of rows (tokens).
            NOTE(review): assumed 2-D ``(num_tokens, in_features)`` -- confirm.
        qweight: The (possibly quantized) weight; ``qweight.shape[0]`` is the
            output feature count.
        qweight_type: GGML quantization type id of ``qweight``.

    Returns:
        ``x`` multiplied by the transposed (dequantized) weight.

    Raises:
        NotImplementedError: If ``qweight_type`` is a valid GGML type that none
            of the paths below supports.
        ValueError: If ``qweight_type`` is not a valid GGML type id at all
            (raised by the ``WeightType`` enum conversion).
    """
    # HACK: when doing chunked prefill we don't generate output tokens
    # so input to logits generator is empty which causes invalid parameter
    if x.shape[0] == 0:
        return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
    # there is no need to call any kernel for fp16/bf16
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T

    # Largest row count for which the MMVQ kernel is used; the limit depends
    # on the quantization family and on whether the weight has more than 5120
    # output rows.
    if qweight_type in IMATRIX_QUANT_TYPES:
        mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
    else:
        mmvq_safe = 2 if qweight.shape[0] > 5120 else 6

    # Small batches (up to ``mmvq_safe`` rows) go through the MMVQ kernel.
    if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
    # Use MMQ Kernel if it's available (standard + k-quants)
    elif qweight_type in MMQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
    # If there is no available MMQ kernel, fallback to dequantize
    elif qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        # Recover the dequantized column count from the packed byte width.
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
        y = x @ weight.T
    else:
        # Raise an error if the quantization type is not supported.
        # Might be useful if llama.cpp adds a new quantization type.
        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
        qweight_type = WeightType(qweight_type)
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
    return y


152
153
154
155
156
def _fused_mul_mat_gguf_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
) -> torch.Tensor:
    """Shape-only stand-in for ``_fused_mul_mat_gguf``.

    Allocates an uninitialized ``(x.shape[0], qweight.shape[0])`` tensor with
    the dtype and device of ``x``; no computation is performed and
    ``qweight_type`` is not read.
    """
    return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
158
159
160
161
162
163
164
165
166
167
168
169
170
171


# Register the matmul as a custom op together with its fake implementation,
# then expose the registered op under a public module-level name.
# (The former ``except AttributeError as error: raise error`` wrapper only
# re-raised the same exception, so errors propagate directly now.)
direct_register_custom_op(
    op_name="_fused_mul_mat_gguf",
    op_func=_fused_mul_mat_gguf,
    fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf


172
173
174
175
176
177
178
179
def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    """Run a GGUF-quantized mixture-of-experts block.

    Args:
        x: Input hidden states; unpacked as ``(num_tokens, hidden)``.
        w1: Stacked first-projection expert weights; unpacked as ``(E, N, _)``.
        w2: Stacked second-projection expert weights; ``w2.shape[1]`` is the
            output width.
        topk_weights: Routing weights, viewed as ``(num_tokens, top_k, 1)``.
        topk_ids: Selected expert ids; ``topk_ids.shape[1]`` is ``top_k``.
        qweight_type: GGML quantization type id of ``w1``.
        qweight_type2: GGML quantization type id of ``w2``.
        activation: ``"silu"`` or ``"gelu"``.

    Returns:
        A tensor shaped like ``x`` holding the weighted expert outputs.

    Raises:
        ValueError: If ``activation`` is neither ``"silu"`` nor ``"gelu"``.
    """

    def act(t: torch.Tensor):
        # Gated activation: the output keeps half of the last dimension.
        d = t.shape[-1] // 2
        output_shape = t.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=t.dtype, device=t.device)
        if activation == "silu":
            torch.ops._C.silu_and_mul(out, t)
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(out, t)
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        return out

    # lazy import to avoid triggering triton import in CPU backend
    from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size

    out_hidden_states = torch.empty_like(x)
    # Unless there is decent expert reuse (more than 64 tokens here), we are
    # better off running the moe_vec kernel.
    if (
        qweight_type2 in MMQ_QUANT_TYPES
        and qweight_type in MMQ_QUANT_TYPES
        and x.shape[0] > 64
    ):
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]
        BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, BLOCK_SIZE, E
        )
        out = ops.ggml_moe_a8(
            x,
            w1,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type,
            N,
            top_k,
            num_tokens,
        )
        out = act(out)
        # Second projection: each (token, expert) pair is now its own row, so
        # top_k is 1 and the row count is num_tokens * top_k.
        out = ops.ggml_moe_a8(
            out,
            w2,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type2,
            w2.shape[1],
            1,
            num_tokens * top_k,
        )
        # Scale each expert output by its routing weight, in place.
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1)
        )
        # NOTE(review): presumably reduces over the top_k axis into
        # out_hidden_states -- confirm against ops.moe_sum.
        ops.moe_sum(out, out_hidden_states)
    elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]

        out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens)
        out = act(out)

        out = ops.ggml_moe_a8_vec(
            out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
        )
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1)
        )
        ops.moe_sum(out, out_hidden_states)
    else:
        logger.warning_once(
            "There is no support for fast MoE kernel "
            "for current quantization method. "
            "Falling back to slow implementation. "
        )
        # Slow path: one token at a time, one selected expert at a time.
        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
            # Keep a leading row dimension of 1 for the matmul op.
            inp = x[tok].reshape((1,) + x.shape[1:])
            current_hidden_state = None
            for ww, ii in zip(w, idx):
                expert_up = w1[ii]

                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
                out = act(out)

                expert_down = w2[ii]
                current_state = fused_mul_mat_gguf(
                    out, expert_down, qweight_type2
                ).mul_(ww)
                # Accumulate the weighted expert outputs for this token.
                if current_hidden_state is None:
                    current_hidden_state = current_state
                else:
                    current_hidden_state.add_(current_state)
            out_hidden_states[tok] = current_hidden_state
    return out_hidden_states


281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
def _fused_moe_gguf_fake(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    """Shape-only stand-in for ``_fused_moe_gguf``.

    Only ``x`` is read: the result is an uninitialized tensor with the same
    shape, dtype and device as ``x``. All other arguments are ignored.
    """
    placeholder = torch.empty_like(x)
    return placeholder


# Register the MoE routine as a custom op together with its fake
# implementation, then expose the registered op under a public name.
# (The former ``except AttributeError as error: raise error`` wrapper only
# re-raised the same exception, so errors propagate directly now.)
direct_register_custom_op(
    op_name="_fused_moe_gguf",
    op_func=_fused_moe_gguf,
    fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf


def _apply_gguf_embedding(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Look up embedding rows from a GGUF weight, dequantizing if needed.

    Args:
        x: Index tensor selecting rows of ``qweight``.
        qweight: The (possibly quantized) embedding table.
        qweight_type: GGML quantization type id of ``qweight``.
        hidden_size: Width of one dequantized embedding row.
        dtype: Dtype passed to the dequantize kernel; not used for
            unquantized tables.

    Returns:
        A tensor of shape ``(*x.shape, hidden_size)`` on the quantized path;
        the result of ``torch.embedding`` on the unquantized path.

    Raises:
        NotImplementedError: If ``qweight_type`` is a valid GGML type that is
            neither unquantized nor dequantizable here.
    """
    if qweight_type in UNQUANTIZED_TYPES:
        return torch.embedding(qweight, x)
    elif qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        x_flat = x.flatten()
        # The packed byte width must expand to exactly ``hidden_size`` values.
        assert hidden_size == qweight.shape[1] // type_size * block_size
        # Gather only the requested rows, then dequantize just those.
        quant = torch.index_select(qweight, dim=0, index=x_flat)
        dequant = ops.ggml_dequantize(
            quant, qweight_type, hidden_size, x_flat.shape[0], dtype
        )
        # Restore the original index shape in front of the hidden dimension.
        return dequant.view(*x.shape, hidden_size)
    else:
        # Converting to the enum both validates the id and names it in the
        # error message.
        qweight_type = WeightType(qweight_type)
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
327
328
329
330
331
332
333


def _apply_gguf_embedding_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
334
    dtype: torch.dtype | None = None,
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
) -> torch.Tensor:
    return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)


# Register the embedding lookup as a custom op together with its fake
# implementation, then expose the registered op under a public name.
# (The former ``except AttributeError as error: raise error`` wrapper only
# re-raised the same exception, so errors propagate directly now.)
direct_register_custom_op(
    op_name="_apply_gguf_embedding",
    op_func=_apply_gguf_embedding,
    fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding


351
352
353
354
355
356
357
358
359
360
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the ``qweight`` and ``qweight_type`` parameters on ``layer``.

        ``qweight`` starts as an uninitialized parameter carrying loader
        metadata; ``qweight_type`` is a uint8 parameter with one slot per
        output partition. ``params_dtype`` is remembered on ``self``.
        ``input_size`` and ``output_size`` are not used here.
        """
        self.params_dtype = params_dtype
        output_size_per_partition = sum(output_partition_sizes)

        tensor_shape = (output_size_per_partition, input_size_per_partition)
        qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                # The three entries below start empty here and are read back
                # in _create_padded_weight_param and apply.
                "data_container": [],
                "shard_id": [],
                "shard_id_map": {},
            },
        )
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        qweight_type = Parameter(
            torch.empty(len(output_partition_sizes), dtype=torch.uint8),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight_type,
            {
                "is_gguf_weight_type": True,
                "weight_type": 0,
                "shard_weight_type": {},
                "ignore_warning": True,
            },
        )
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Validate the loaded weight type and build the padded weight.

        Raises:
            ValueError: If the layer's weight type is a valid GGML type that
                is neither unquantized nor dequantizable.
        """
        qweight_type = layer.qweight_type.weight_type
        if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES):
            qweight_type = WeightType(qweight_type)
            raise ValueError(
                f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
            )
        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
        # materialize the padded weight parameter for CUDA Graph compatibility.
        self._create_padded_weight_param(layer)

    def _create_padded_weight_param(self, layer: torch.nn.Module):
        """Create padded weight parameter for GGUF MergedLinear layer.

        Does nothing unless more than one shard tensor was collected in
        ``qweight.data_container``. Otherwise the shards are stacked along
        dim 0, zero-padded along dim 1 to the widest shard, and registered as
        the layer's new ``qweight`` with a ``shard_offset_map`` attribute.

        Raises:
            ValueError: If the collected shards do not share one dtype.
        """
        qweight = layer.qweight
        shard_id_map = qweight.shard_id_map
        shard_id = qweight.shard_id
        data_container = qweight.data_container
        if len(data_container) <= 1:
            return

        dtypes = {data.dtype for data in data_container}
        # Raise explicitly: an ``assert`` is skipped under ``python -O``.
        if len(dtypes) != 1:
            raise ValueError(f"Data container has mixed dtypes: {dtypes}")
        dtype = next(iter(dtypes))
        # concat dim0 and pad dim1
        padded_side = max(shard.size(1) for shard in data_container)
        concat_side = sum(shard.size(0) for shard in data_container)
        # Pad the quantized weights to dense tensor, and create a map
        # with the location of each shard in the padded tensor.
        padded_data = torch.zeros(
            (concat_side, padded_side), dtype=dtype, device=qweight.device
        )
        # (dim0_start, dim0_end, dim1_size)
        shard_offset_map: dict[str, tuple[int, int, int]] = {}
        for idx in shard_id:
            id_in_container = shard_id_map[idx]
            shard = data_container[id_in_container]
            # A shard's first row is the total row count of those before it.
            start = sum(prev.size(0) for prev in data_container[:id_in_container])
            end = start + shard.size(0)
            size = shard.size(1)
            padded_data[start:end, :size] = shard
            shard_offset_map[idx] = (start, end, size)
        # Drop the per-shard tensors now that they live in ``padded_data``.
        data_container.clear()
        padded_param = Parameter(padded_data, requires_grad=False)
        # Carry every attribute of the old parameter over to the new one.
        set_weight_attrs(padded_param, vars(qweight))
        set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
        layer.register_parameter("qweight", padded_param)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear output for ``x``, adding ``bias`` if given.

        Sharded weights are multiplied shard by shard (each shard may have its
        own quantization type) and the partial outputs are concatenated along
        dim 1; otherwise a single fused matmul is issued.
        """
        shard_id = layer.qweight.shard_id

        if shard_id:
            # dequantize shard weights respectively
            # When a "q" shard is present the order is forced to q, k, v.
            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
            qweight = layer.qweight
            result = []
            for idx in shard_id:
                start, end, offset = layer.qweight.shard_offset_map[idx]
                qweight_type = layer.qweight_type.shard_weight_type[idx]
                # Slice away the zero padding before calling the kernel.
                result.append(
                    fused_mul_mat_gguf(
                        x, qweight[start:end, :offset].contiguous(), qweight_type
                    )
                )
            out = torch.cat(result, dim=1)
        else:
            qweight = layer.qweight
            qweight_type = layer.qweight_type.weight_type
            out = fused_mul_mat_gguf(x, qweight, qweight_type)
        if bias is not None:
            out.add_(bias)
        return out


483
484
485
486
487
488
489
class GGUFMoEMethod(FusedMoEMethodBase):
    """MoE method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
        moe: The fused-MoE config, forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: GGUFConfig,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weight and weight-type parameters on ``layer``.

        Four parameters are registered: ``w13_qweight``, ``w13_qweight_type``,
        ``w2_qweight`` and ``w2_qweight_type``. The weights start as
        uninitialized parameters carrying loader metadata. ``params_dtype`` is
        not used here.
        """
        tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
        # gate up proj
        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            w13_qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
            },
        )
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        layer.register_parameter("w13_qweight", w13_qweight)

        w13_qweight_type = Parameter(
            torch.empty(1, dtype=torch.uint8), requires_grad=False
        )
        set_weight_attrs(
            w13_qweight_type,
            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
        )
        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
        layer.register_parameter("w13_qweight_type", w13_qweight_type)

        tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size)
        # gate down proj
        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            w2_qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
            },
        )
        set_weight_attrs(w2_qweight, extra_weight_attrs)
        layer.register_parameter("w2_qweight", w2_qweight)

        w2_qweight_type = Parameter(
            torch.empty(1, dtype=torch.uint8), requires_grad=False
        )
        set_weight_attrs(
            w2_qweight_type,
            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
        )

        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
        layer.register_parameter("w2_qweight_type", w2_qweight_type)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return ``None``: this method provides no fused-MoE quant config."""
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route ``x`` to experts and run the GGUF fused MoE op.

        Raises:
            NotImplementedError: If ``enable_eplb`` or
                ``apply_router_weight_on_input`` is set.
            AssertionError: If ``activation`` is not ``"silu"``.
        """
        if enable_eplb:
            raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")

        assert activation == "silu", "Only SiLU activation is supported."
        if apply_router_weight_on_input:
            # The trailing space keeps the two string pieces from running
            # together ("...supported forfused...").
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused GGUF MoE method."
            )

        # The third value returned by select_experts is not needed here.
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )
        return fused_moe_gguf(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            topk_weights,
            topk_ids,
            layer.w13_qweight_type.weight_type,
            layer.w2_qweight_type.weight_type,
            activation,
        )
622
623


624
625
626
627
628
629
630
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for the indices ``x`` from the layer's weight.

        The hidden size comes from the second entry of the weight's
        ``tensor_shape`` attribute, and the output dtype requested from the op
        is ``self.params_dtype``.
        """
        qweight = layer.qweight
        qweight_type = layer.qweight_type.weight_type
        hidden_size = qweight.tensor_shape[1]

        return apply_gguf_embedding(
            x, qweight, qweight_type, hidden_size, dtype=self.params_dtype
        )
639
640
641
642


class GGUFUninitializedParameter(UninitializedParameter):
    """Uninitialized parameter used as the placeholder for GGUF weights."""

    # Class this placeholder is meant to turn into once it is materialized.
    cls_to_become = Parameter
    # Tensors collected for this parameter; set to an empty list via
    # set_weight_attrs when the weights are created.
    data_container: list[torch.Tensor]