Unverified Commit b98c62ba authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix GGUF inference with FP16 unquantized checkpoint (#10675)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent c411def2
...@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional ...@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
import gguf import gguf
import torch import torch
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -49,19 +50,65 @@ class GGUFConfig(QuantizationConfig): ...@@ -49,19 +50,65 @@ class GGUFConfig(QuantizationConfig):
return None return None
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
WeightType.Q4_0,
WeightType.Q4_1,
WeightType.Q5_0,
WeightType.Q5_1,
WeightType.Q8_0,
WeightType.Q8_1,
}
KQUANT_TYPES = {
WeightType.Q2_K,
WeightType.Q3_K,
WeightType.Q4_K,
WeightType.Q5_K,
WeightType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
WeightType.IQ1_M,
WeightType.IQ1_S,
WeightType.IQ2_XXS,
WeightType.IQ2_XS,
WeightType.IQ2_S,
WeightType.IQ3_XXS,
WeightType.IQ3_S,
WeightType.IQ4_XS,
WeightType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor: qweight_type: int) -> torch.Tensor:
# use dequantize mulmat for IQmatrix, mmq for k-quants # there is no need to call any kernel for fp16/bf16
if x.shape[0] == 1: if qweight_type in UNQUANTIZED_TYPES:
# enable mmvq in contiguous batching return x @ qweight.T
# enable MMVQ in contiguous batching with batch_size=1
if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES:
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
elif qweight_type >= 16: # Use MMQ Kernel if it's available (standard + k-quants)
elif qweight_type in MMQ_QUANT_TYPES:
y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape) weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
y = x @ weight.T y = x @ weight.T
else: else:
y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) # Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = WeightType(qweight_type)
raise NotImplementedError(
f"Unsupported GGUF quantization type: {qweight_type}")
return y return y
...@@ -121,9 +168,9 @@ class GGUFLinearMethod(LinearMethodBase): ...@@ -121,9 +168,9 @@ class GGUFLinearMethod(LinearMethodBase):
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
qweight = layer.qweight.unbind(0) qweight = layer.qweight.unbind(0)
result = [] result = []
for id in shard_id: for idx in shard_id:
q_idx = layer.qweight.shard_id_map[id] q_idx = layer.qweight.shard_id_map[idx]
qweight_type = layer.qweight_type.shard_weight_type[id] qweight_type = layer.qweight_type.shard_weight_type[idx]
result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type)) result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
out = torch.cat(result, axis=1) out = torch.cat(result, axis=1)
else: else:
...@@ -163,9 +210,13 @@ class GGUFUninitializedParameter(UninitializedParameter): ...@@ -163,9 +210,13 @@ class GGUFUninitializedParameter(UninitializedParameter):
data_container: List[torch.Tensor] data_container: List[torch.Tensor]
def materialize_nested(self) -> Parameter: def materialize_nested(self) -> Parameter:
dtype = {data.dtype for data in self.data_container}
assert len(dtype) == 1, ValueError(
f"Data container has mixed dtypes: {dtype}")
dtype = next(iter(dtype))
nested_data = torch.nested.nested_tensor(self.data_container, nested_data = torch.nested.nested_tensor(self.data_container,
device=self.device, device=self.device,
dtype=torch.uint8) dtype=dtype)
self.data_container.clear() self.data_container.clear()
param = torch.Tensor._make_subclass(self.cls_to_become, param = torch.Tensor._make_subclass(self.cls_to_become,
nested_data, nested_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment