Unverified commit 360bd67c, authored by Isotr0py, committed by GitHub
Browse files

[Core] Support loading GGUF model (#5191)


Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent ef527be0
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type
import torch
from torch import nn
......@@ -23,6 +24,14 @@ class QuantizeMethodBase(ABC):
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
# Not required functions
def embedding(self, layer: torch.nn.Module, *args,
              **kwargs) -> torch.Tensor:
    """Look up embeddings from the layer for the indices in the input.

    This is an optional hook: the base implementation always raises, and
    quantization methods that can serve embedding layers override it.
    Expects create_weights to have been called before on the layer.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
......@@ -31,6 +40,21 @@ class QuantizeMethodBase(ABC):
return
def method_has_implemented_embedding(
        method_class: Type[QuantizeMethodBase]) -> bool:
    """Report whether ``method_class`` provides its own ``embedding``.

    Not every quantization method implements the embedding operation, so
    callers need a way to check before relying on it. The check compares the
    attribute found on ``method_class`` with the one defined on
    ``QuantizeMethodBase``: if they are the same object, the method was never
    overridden.

    ``inspect.getattr_static`` is used so the raw attribute is fetched
    without triggering descriptors or ``__getattr__`` hooks.

    Args:
        method_class: The quantization method class to inspect.

    Returns:
        True if ``method_class`` exposes an ``embedding`` attribute that is
        not the base implementation, False otherwise.
    """
    inherited = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
    candidate = inspect.getattr_static(method_class, "embedding", None)
    if candidate is None:
        return False
    return candidate is not inherited
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
......
from typing import Any, Dict, List, Optional
import gguf
import torch
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
class GGUFConfig(QuantizationConfig):
    """Quantization config for GGUF checkpoints.

    The config carries no options of its own; it only tells layers which
    GGUF quantize method to use.
    """

    def __init__(self) -> None:
        """Create the config; there is no state to initialise."""

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> str:
        """Return the identifier of this quantization scheme."""
        return "gguf"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes usable with GGUF weights."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability required by this method.

        NOTE(review): presumably an encoded GPU compute capability; confirm
        the encoding against the base class contract.
        """
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config files to search for (none for GGUF)."""
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
        """Build the config, rejecting tensor-parallel setups.

        Args:
            config: Quantization config dictionary; its contents are unused.

        Raises:
            ValueError: If the tensor-parallel world size is greater than 1.
        """
        tp_world_size = get_tensor_model_parallel_world_size()
        if tp_world_size > 1:
            raise ValueError(
                "GGUF quantization hasn't supported tensor parallelism yet.")
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the GGUF method matching the layer kind.

        Linear layers get ``GGUFLinearMethod``, vocab-parallel embeddings get
        ``GGUFEmbeddingMethod``; any other layer gets ``None``. The linear
        check is made first, so it wins if a layer matches both.
        """
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        if isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for GGUF)."""
        return []
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                  qweight_type: int) -> torch.Tensor:
    """Multiply activations by a GGUF-quantized weight.

    Two paths are used depending on the GGML quantization type id:

    * ids below 16 go through the fused ``ggml_mul_mat_a8`` kernel (mmq for
      k-quants);
    * ids 16 and above (IQmatrix) are dequantized first and then multiplied
      with a regular matmul against the transposed weight.

    Args:
        x: Input activations.
        qweight: Quantized weight; the row count is read from ``shape[0]``
            and the packed row width from ``shape[1]``.
        qweight_type: GGML quantization type id of ``qweight``.

    Returns:
        The product of ``x`` and the (de)quantized weight.
    """
    if qweight_type < 16:
        # Fused quantized matmul; no dequantized copy is built here.
        return ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])

    # Recover the logical row width from the packed width: every
    # `type_size` packed units hold `block_size` values.
    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    n_rows = qweight.shape[0]
    n_cols = qweight.shape[1] // type_size * block_size
    weight = ops.ggml_dequantize(qweight, qweight_type, n_rows, n_cols)
    return x @ weight.T
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the GGUF weight and weight-type parameters on ``layer``.

        ``qweight`` is created as an ``UninitializedParameter`` and tagged
        with the logical (unquantized) tensor shape plus empty per-shard
        bookkeeping. ``qweight_type`` is a uint8 tensor with one slot per
        output partition. Both parameters also receive
        ``extra_weight_attrs``.
        """
        logical_shape = (sum(output_partition_sizes),
                         input_size_per_partition)

        # Quantized weight: left uninitialized here (presumably materialized
        # by the weight loader once the stored shape is known — confirm).
        qweight = UninitializedParameter(requires_grad=False)
        qweight_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            "tensor_shape": logical_shape,
            "is_gguf_weight": True,
            "shard_size": {},
            "shard_id": [],
        }
        set_weight_attrs(qweight, qweight_attrs)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        # Quantization type id(s), one entry per output partition.
        type_storage = torch.empty(len(output_partition_sizes),
                                   dtype=torch.uint8)
        qweight_type = Parameter(type_storage, requires_grad=False)
        qweight_type_attrs = {
            "is_gguf_weight_type": True,
            "weight_type": 0,
            "shard_weight_type": {},
            "ignore_warning": True
        }
        set_weight_attrs(qweight_type, qweight_type_attrs)
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the linear layer output from the GGUF weight.

        When the weight carries non-empty ``shard_id`` and ``shard_size``
        metadata, each shard is sliced out of ``qweight``, multiplied with
        its own quantization type, and the results are concatenated along
        axis 1. Otherwise the whole weight is multiplied with the single
        ``weight_type``. ``bias`` is added in place when given.
        """
        qweight = layer.qweight
        shard_sizes = getattr(qweight, "shard_size", None)
        shard_ids = getattr(qweight, "shard_id", None)

        if not (shard_ids and shard_sizes):
            out = _fuse_mul_mat(x, qweight, layer.qweight_type.weight_type)
        else:
            # Fused qkv shards are always processed in q, k, v order.
            if "q" in shard_ids:
                shard_ids = ["q", "k", "v"]
            partials = []
            row_start = 0
            # dequantize shard weights respectively
            for shard_key in shard_ids:
                n_rows = shard_sizes[shard_key][0]
                n_cols = shard_sizes[shard_key][1]
                shard_weight = qweight[row_start:row_start + n_rows,
                                       :n_cols].contiguous()
                shard_type = layer.qweight_type.shard_weight_type[shard_key]
                partials.append(_fuse_mul_mat(x, shard_weight, shard_type))
                row_start += n_rows
            out = torch.cat(partials, axis=1)

        if bias is not None:
            out.add_(bias)
        return out
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module,
                  x: torch.Tensor) -> torch.Tensor:
        """Gather embedding rows for the indices in ``x``.

        Weight types below 2 are looked up directly with
        ``torch.embedding`` (presumably the unquantized float formats —
        confirm against the GGML type table). For every other type only the
        selected rows are gathered and dequantized, and the result is
        reshaped to ``(*x.shape, hidden_size)``.
        """
        table = layer.qweight
        table_type = layer.qweight_type.weight_type
        # Logical row width recovered from the packed row width.
        block_size, type_size = gguf.GGML_QUANT_SIZES[table_type]
        hidden_size = table.shape[1] // type_size * block_size

        if table_type < 2:
            return torch.embedding(table, x)

        flat_ids = x.flatten()
        # Select the packed rows first so only those rows are dequantized.
        packed_rows = torch.index_select(table, dim=0, index=flat_ids)
        rows = ops.ggml_dequantize(packed_rows, table_type, hidden_size,
                                   flat_ids.shape[0])
        return rows.view(*x.shape, hidden_size)
......@@ -3,19 +3,46 @@ from typing import List, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Unquantized method for embeddings."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for embedding layer.

        Registers a non-trainable ``weight`` parameter of shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` on
        ``layer`` and tags it with its input/output dims and
        ``extra_weight_attrs``.
        """
        n_rows = sum(output_partition_sizes)
        storage = torch.empty(n_rows,
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the layer's weight as a plain linear projection."""
        projected = F.linear(x, layer.weight, bias)
        return projected

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        """Look up rows of the layer's weight for the given indices."""
        looked_up = F.embedding(input_, layer.weight)
        return looked_up
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
......@@ -199,7 +226,19 @@ class VocabParallelEmbedding(torch.nn.Module):
if quant_config is not None:
linear_method = quant_config.get_quant_method(self, prefix=prefix)
if linear_method is None:
linear_method = UnquantizedLinearMethod()
linear_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
#
# Compare the instance's own class: `type(self.__class__)` would be the
# metaclass (`type`), which is never VocabParallelEmbedding, so that form
# always evaluates to False and the guard below could never fire. The
# exact-type check (rather than isinstance) is deliberate so that
# subclasses are not treated as embedding layers.
is_embedding_layer = type(self) is VocabParallelEmbedding
linear_method_implements_embedding = method_has_implemented_embedding(
    type(linear_method))
if is_embedding_layer and not linear_method_implements_embedding:
    # Fail early at construction time instead of at the first forward pass.
    raise NotImplementedError(
        f"The class {type(linear_method).__name__} must implement "
        "the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.linear_method: QuantizeMethodBase = linear_method
if params_dtype is None:
......@@ -306,6 +345,14 @@ class VocabParallelEmbedding(torch.nn.Module):
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
# If the parameter is a gguf weight, then load it directly.
if getattr(param, "is_gguf_weight_type", None):
param.data.copy_(loaded_weight)
param.weight_type = loaded_weight.item()
return
elif isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
......@@ -344,7 +391,8 @@ class VocabParallelEmbedding(torch.nn.Module):
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input.long(), self.weight)
output_parallel = self.linear_method.embedding(self,
masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
......@@ -389,6 +437,7 @@ class ParallelLMHead(VocabParallelEmbedding):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config,
prefix)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
......
This diff is collapsed.
This diff is collapsed.
......@@ -238,6 +238,7 @@ class Qwen2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment