# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Optional, Union
5
6
7

import gguf
import torch
8
from gguf import GGMLQuantizationType as WeightType
9
10
11
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
12
from vllm.logger import init_logger
13
14
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
                                                         FusedMoEQuantConfig)
15
16
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                        FusedMoEMethodBase)
17
18
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
19
from vllm.model_executor.layers.quantization import QuantizationMethods
20
21
22
23
24
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
25
from vllm.utils import direct_register_custom_op
26

27
28
logger = init_logger(__name__)

29
30
31
32

class GGUFConfig(QuantizationConfig):
    """Config class for GGUF.

    Args:
        unquantized_modules: Substrings matched against a layer prefix; a
            linear layer whose prefix contains any of them gets the
            unquantized linear method. ``None`` is stored as an empty list.
    """

    def __init__(self,
                 unquantized_modules: Optional[list[str]] = None) -> None:
        super().__init__()
        self.unquantized_modules = unquantized_modules or []

    def __repr__(self) -> str:
        return ("GGUFConfig()")

    def get_name(self) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "gguf"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.half, torch.bfloat16, torch.float32]

    @classmethod
    def get_min_capability(cls) -> int:
        # NOTE(review): presumably CUDA compute capability 6.0 encoded as 60;
        # confirm against QuantizationConfig's contract.
        return 60

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
        """Build a default config; the contents of ``config`` are not read."""
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer`` based on its type.

        Returns ``None`` when the layer is neither a linear layer, a
        vocab-parallel embedding, nor a fused MoE layer.
        """
        if isinstance(layer, LinearBase):
            # Linear layers listed as unquantized keep the plain method.
            if is_layer_skipped_gguf(prefix, self.unquantized_modules):
                return UnquantizedLinearMethod()
            return GGUFLinearMethod(self)
        elif isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        elif isinstance(layer, FusedMoE):
            return GGUFMoEMethod(self, layer.moe_config)
        return None


72
73
74
75
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
    """Tell whether ``prefix`` contains any entry of ``unquantized_modules``.

    Matching is plain substring containment; an empty list never matches.
    """
    for candidate in unquantized_modules:
        if candidate in prefix:
            return True
    return False


76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Weight types that are stored as plain floating point tensors.
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}

# Legacy ("standard") block quantization formats.
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0, WeightType.Q4_1,
    WeightType.Q5_0, WeightType.Q5_1,
    WeightType.Q8_0, WeightType.Q8_1,
}

# K-quant formats.
KQUANT_TYPES = {
    WeightType.Q2_K, WeightType.Q3_K, WeightType.Q4_K,
    WeightType.Q5_K, WeightType.Q6_K,
}

# I-Matrix quantization formats.
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M, WeightType.IQ1_S,
    WeightType.IQ2_XXS, WeightType.IQ2_XS, WeightType.IQ2_S,
    WeightType.IQ3_XXS, WeightType.IQ3_S,
    WeightType.IQ4_XS, WeightType.IQ4_NL,
}

# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES.union(KQUANT_TYPES, IMATRIX_QUANT_TYPES)
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES.union(KQUANT_TYPES,
                                              IMATRIX_QUANT_TYPES)
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES.union(KQUANT_TYPES)


111
112
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
                        qweight_type: int) -> torch.Tensor:
    """Multiply ``x`` by a GGUF weight, choosing the kernel by weight type.

    Args:
        x: Input activations; ``x.shape[0]`` is treated as the row count
            that decides between the MMVQ and MMQ kernels.
        qweight: Weight tensor; ``qweight.shape[0]`` is the output size.
        qweight_type: GGML quantization type id of ``qweight``.

    Returns:
        The matmul result with ``qweight.shape[0]`` output features.

    Raises:
        NotImplementedError: If ``qweight_type`` is a valid GGML type that
            none of the branches below supports.
    """
    # Largest row count for which the MMVQ kernel is chosen over MMQ.
    # NOTE(review): the thresholds look empirically tuned; their origin is
    # not visible from this file.
    if qweight_type in IMATRIX_QUANT_TYPES:
        mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
    else:
        mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
    # HACK: when doing chunked prefill we don't generate output tokens
    # so input to logits generator is empty which causes invalid parameter
    if x.shape[0] == 0:
        return torch.empty(x.shape[0],
                           qweight.shape[0],
                           dtype=x.dtype,
                           device=x.device)
    # there is no need to call any kernel for fp16/bf16
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T
    # enable MMVQ in contiguous batching with batch_size=1
    if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
    # Use MMQ Kernel if it's available (standard + k-quants)
    elif qweight_type in MMQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
    # If there is no available MMQ kernel, fallback to dequantize
    elif qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        # Recover the dequantized column count from the packed row width.
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
        y = x @ weight.T
    else:
        # Raise an error if the quantization type is not supported.
        # Might be useful if llama.cpp adds a new quantization type.
        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
        qweight_type = WeightType(qweight_type)
        raise NotImplementedError(
            f"Unsupported GGUF quantization type: {qweight_type}")
    return y


149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def _fused_mul_mat_gguf_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
) -> torch.Tensor:
    return torch.empty(x.shape[0],
                       qweight.shape[0],
                       dtype=x.dtype,
                       device=x.device)


# Register the matmul as a vLLM custom op and expose the registered handle.
try:
    direct_register_custom_op(op_name="_fused_mul_mat_gguf",
                              op_func=_fused_mul_mat_gguf,
                              mutates_args=[],
                              fake_impl=_fused_mul_mat_gguf_fake)
    fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
except AttributeError:
    # Nothing to recover here: let the failure propagate to the importer.
    raise


173
174
175
176
177
178
179
180
def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    """Run a GGUF-quantized mixture-of-experts MLP.

    Three paths are tried in order: the batched MMQ MoE kernel (both weight
    types in ``MMQ_QUANT_TYPES`` and more than 64 rows in ``x``), the vector
    MoE kernel (both weight types in ``MMVQ_QUANT_TYPES``), and finally a
    per-token, per-expert Python loop built on ``fused_mul_mat_gguf``.

    Args:
        x: Input hidden states; the result has the same shape and dtype.
        w1: Stacked expert weights applied before the activation
            (indexed by expert id along dim 0).
        w2: Stacked expert weights applied after the activation.
        topk_weights: Routing weight of each selected expert per token.
        topk_ids: Selected expert ids per token.
        qweight_type: GGML quantization type id of ``w1``.
        qweight_type2: GGML quantization type id of ``w2``.
        activation: ``"silu"`` or ``"gelu"``.

    Returns:
        The combined expert outputs, shaped like ``x``.

    Raises:
        ValueError: If ``activation`` is neither ``"silu"`` nor ``"gelu"``.
    """

    def act(x: torch.Tensor):
        # Gated activation: the output has half the last dimension of the
        # input.
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if activation == "silu":
            torch.ops._C.silu_and_mul(out, x)
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(out, x)
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        return out

    # lazy import to avoid triggering triton import in CPU backend
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        moe_align_block_size)

    out_hidden_states = torch.empty_like(x)
    # Unless we have decent expert reuse, we are better off running the
    # moe_vec kernel, so the MMQ path is only taken for larger batches.
    if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES
            and x.shape[0] > 64):
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]
        BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)

        sorted_token_ids, expert_ids, num_tokens_post_padded = \
                moe_align_block_size(topk_ids, BLOCK_SIZE, E)
        out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
                              num_tokens_post_padded, qweight_type, N, top_k,
                              num_tokens)
        out = act(out)
        out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
                              num_tokens_post_padded, qweight_type2,
                              w2.shape[1], 1, num_tokens * top_k)
        # Scale each expert output by its routing weight, then sum over
        # the top_k experts into out_hidden_states.
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1))
        ops.moe_sum(out, out_hidden_states)
    elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]

        out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N,
                                  num_tokens)
        out = act(out)

        out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2,
                                  w2.shape[1], num_tokens * top_k)
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1))
        ops.moe_sum(out, out_hidden_states)
    else:
        logger.warning_once("There is no support for fast MoE kernel "
                            "for current quantization method. "
                            "Falling back to slow implementation. ")
        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
            # Keep a leading batch dimension of 1 for the matmul op.
            inp = x[tok].reshape((1, ) + x.shape[1:])
            current_hidden_state = None
            for ww, ii in zip(w, idx):
                expert_up = w1[ii]

                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
                out = act(out)

                expert_down = w2[ii]
                current_state = fused_mul_mat_gguf(out, expert_down,
                                                   qweight_type2).mul_(ww)
                # Accumulate the weighted expert outputs for this token.
                if current_hidden_state is None:
                    current_hidden_state = current_state
                else:
                    current_hidden_state.add_(current_state)
            out_hidden_states[tok] = current_hidden_state
    return out_hidden_states


259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def _fused_moe_gguf_fake(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    return torch.empty_like(x)


# Register the MoE routine as a vLLM custom op and expose the handle.
try:
    direct_register_custom_op(op_name="_fused_moe_gguf",
                              op_func=_fused_moe_gguf,
                              mutates_args=[],
                              fake_impl=_fused_moe_gguf_fake)
    fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
except AttributeError:
    # Nothing to recover here: let the failure propagate to the importer.
    raise


def _apply_gguf_embedding(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Look up embeddings from a GGUF weight, dequantizing only used rows.

    Args:
        x: Index tensor selecting rows of ``qweight``.
        qweight: Embedding table, either plain or GGUF-quantized.
        qweight_type: GGML quantization type id of ``qweight``.
        hidden_size: Dequantized width of one embedding row.
        dtype: Output dtype handed to the dequantize op.

    Returns:
        Embeddings shaped ``(*x.shape, hidden_size)`` on the quantized path,
        or the result of ``torch.embedding`` for unquantized weights.

    Raises:
        NotImplementedError: If ``qweight_type`` is a valid GGML type that is
            neither unquantized nor dequantizable here.
    """
    # Plain float tables need no dequantization.
    if qweight_type in UNQUANTIZED_TYPES:
        return torch.embedding(qweight, x)

    if qweight_type not in DEQUANT_TYPES:
        qweight_type = WeightType(qweight_type)
        raise NotImplementedError(
            f"Unsupported GGUF quantization type: {qweight_type}")

    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    flat_ids = x.flatten()
    # The packed row width must expand to exactly hidden_size elements.
    assert (hidden_size == qweight.shape[1] // type_size * block_size)
    # Gather the packed rows first so only those rows are dequantized.
    packed_rows = torch.index_select(qweight, dim=0, index=flat_ids)
    unpacked = ops.ggml_dequantize(packed_rows, qweight_type, hidden_size,
                                   flat_ids.shape[0], dtype)
    return unpacked.view(*x.shape, hidden_size)


def _apply_gguf_embedding_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)


# Register the embedding lookup as a vLLM custom op and expose the handle.
try:
    direct_register_custom_op(op_name="_apply_gguf_embedding",
                              op_func=_apply_gguf_embedding,
                              mutates_args=[],
                              fake_impl=_apply_gguf_embedding_fake)
    apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
except AttributeError:
    # Nothing to recover here: let the failure propagate to the importer.
    raise


331
332
333
334
335
336
337
338
339
340
341
342
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the ``qweight`` and ``qweight_type`` parameters on ``layer``.

        ``qweight`` starts as an uninitialized placeholder carrying GGUF
        bookkeeping attributes; ``qweight_type`` is a uint8 parameter with one
        entry per output partition. ``params_dtype`` is remembered on the
        method instance.
        """
        self.params_dtype = params_dtype
        output_size_per_partition = sum(output_partition_sizes)

        tensor_shape = (output_size_per_partition, input_size_per_partition)
        qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            qweight, {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
                "shard_id": [],
                "shard_id_map": {},
            })
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        qweight_type = Parameter(torch.empty(len(output_partition_sizes),
                                             dtype=torch.uint8),
                                 requires_grad=False)
        set_weight_attrs(
            qweight_type, {
                "is_gguf_weight_type": True,
                "weight_type": 0,
                "shard_weight_type": {},
                "ignore_warning": True
            })
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Validate the loaded weight type and build the padded weight.

        Raises:
            ValueError: If the layer's weight type is a valid GGML type that
                is neither unquantized nor dequantizable.
        """
        qweight_type = layer.qweight_type.weight_type
        if not (qweight_type in UNQUANTIZED_TYPES
                or qweight_type in DEQUANT_TYPES):
            qweight_type = WeightType(qweight_type)
            raise ValueError(
                f"Unsupported GGUF quantization type {qweight_type} in "
                f"layer {layer}.")
        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
        # materialize the padded weight parameter for CUDA Graph compatibility.
        self._create_padded_weight_param(layer)

    def _create_padded_weight_param(self, layer: torch.nn.Module):
        """Create padded weight parameter for GGUF MergedLinear layer.

        Does nothing unless ``qweight.data_container`` holds more than one
        shard. Otherwise the shards are stacked along dim 0, zero-padded
        along dim 1 to the widest shard, and re-registered as ``qweight``
        together with a ``shard_offset_map`` describing each shard's slice.
        """
        qweight = layer.qweight
        shard_id_map = qweight.shard_id_map
        shard_id = qweight.shard_id
        if len(data_container := qweight.data_container) > 1:
            dtype = {data.dtype for data in data_container}
            # NOTE(review): this is an assert whose message is a ValueError
            # instance, so a failure surfaces as AssertionError.
            assert len(dtype) == 1, ValueError(
                f"Data container has mixed dtypes: {dtype}")
            dtype = next(iter(dtype))
            # concat dim0 and pad dim1
            padded_side = max(x.size(1) for x in data_container)
            concat_side = sum(x.size(0) for x in data_container)
            # Pad the quantized weights to dense tensor, and create a map
            # with the location of each shard in the padded tensor.
            padded_data = torch.zeros((concat_side, padded_side),
                                      dtype=dtype,
                                      device=qweight.device)
            # (dim0_start, dim0_end, dim1_size)
            shard_offset_map = dict[str, tuple[int, int, int]]()
            for idx in shard_id:
                id_in_container = shard_id_map[idx]
                # Row offset = total rows of all shards stored before it.
                start = sum(
                    x.size(0) for x in data_container[:id_in_container])
                end = start + data_container[id_in_container].size(0)
                size = data_container[id_in_container].size(1)
                padded_data[start:end, :size] = data_container[id_in_container]
                shard_offset_map[idx] = (start, end, size)
            qweight.data_container.clear()
            padded_param = Parameter(padded_data, requires_grad=False)
            # Carry over every attribute set on the placeholder parameter.
            set_weight_attrs(padded_param, vars(qweight))
            set_weight_attrs(padded_param,
                             {"shard_offset_map": shard_offset_map})
            layer.register_parameter("qweight", padded_param)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the linear layer output for ``x``.

        When the weight is sharded, each shard is multiplied separately with
        its own weight type and the results are concatenated along dim 1;
        otherwise a single matmul is issued. ``bias`` is added in place when
        given.
        """
        shard_id = layer.qweight.shard_id

        if shard_id:
            # dequantize shard weights respectively
            # Fix the order to q, k, v when this is a QKV weight.
            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
            qweight = layer.qweight
            result = []
            for idx in shard_id:
                start, end, offset = layer.qweight.shard_offset_map[idx]
                qweight_type = layer.qweight_type.shard_weight_type[idx]
                # Slice away the zero padding before calling the kernel.
                result.append(
                    fused_mul_mat_gguf(
                        x, qweight[start:end, :offset].contiguous(),
                        qweight_type))
            out = torch.cat(result, axis=1)
        else:
            qweight = layer.qweight
            qweight_type = layer.qweight_type.weight_type
            out = fused_mul_mat_gguf(x, qweight, qweight_type)
        if bias is not None:
            out.add_(bias)
        return out


452
453
454
455
456
457
458
class GGUFMoEMethod(FusedMoEMethodBase):
    """MoE method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
        moe: The fused MoE config forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: GGUFConfig,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the expert weight placeholders and their type parameters.

        Four parameters are registered on ``layer``, in this order:
        ``w13_qweight``, ``w13_qweight_type``, ``w2_qweight`` and
        ``w2_qweight_type``.
        """

        tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
                        hidden_size)
        # gate/up proj
        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            w13_qweight, {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
            })
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        layer.register_parameter("w13_qweight", w13_qweight)

        w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
                                     requires_grad=False)
        set_weight_attrs(w13_qweight_type, {
            "is_gguf_weight_type": True,
            "weight_type": 0,
            "ignore_warning": True
        })
        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
        layer.register_parameter("w13_qweight_type", w13_qweight_type)

        tensor_shape = (num_experts, intermediate_size_per_partition,
                        hidden_size)
        # down proj
        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            w2_qweight, {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
            })
        set_weight_attrs(w2_qweight, extra_weight_attrs)
        layer.register_parameter("w2_qweight", w2_qweight)

        w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
                                    requires_grad=False)
        set_weight_attrs(w2_qweight_type, {
            "is_gguf_weight_type": True,
            "weight_type": 0,
            "ignore_warning": True
        })

        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
        layer.register_parameter("w2_qweight_type", w2_qweight_type)

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        """Return ``None``: no fused MoE quant config is provided here."""
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Route tokens to experts and run the GGUF MoE computation.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        selected ids and weights are then passed to ``fused_moe_gguf``
        together with the layer's quantized weights and their types.

        Raises:
            NotImplementedError: If ``enable_eplb`` or
                ``apply_router_weight_on_input`` is set.
            AssertionError: If ``self.fused_experts`` is set or ``activation``
                is not ``"silu"``.
        """
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `GGUFMoEMethod` yet.")

        assert activation == "silu", "Only SiLU activation is supported."
        if apply_router_weight_on_input:
            # The trailing space keeps the two adjacent literals from fusing
            # into "forfused".
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused GGUF MoE method.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)
        return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
                              topk_weights, topk_ids,
                              layer.w13_qweight_type.weight_type,
                              layer.w2_qweight_type.weight_type, activation)
578
579


580
581
582
583
584
585
586
587
588
589
590
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module,
                  x: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for the indices in ``x``.

        The hidden size is read from the second entry of the weight's
        recorded ``tensor_shape``; the output dtype is the ``params_dtype``
        stored by ``create_weights``.
        """
        qweight = layer.qweight
        qweight_type = layer.qweight_type.weight_type
        hidden_size = qweight.tensor_shape[1]

        return apply_gguf_embedding(x,
                                    qweight,
                                    qweight_type,
                                    hidden_size,
                                    dtype=self.params_dtype)
598
599
600
601


class GGUFUninitializedParameter(UninitializedParameter):
    """Uninitialized parameter placeholder for GGUF weights.

    Declares ``Parameter`` as the class to become and annotates the
    ``data_container`` list of tensors attached to the placeholder.
    """
    cls_to_become = Parameter
    # Tensors collected for this parameter (one per loaded shard, presumably;
    # the loader that fills it is outside this file).
    data_container: list[torch.Tensor]