Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
......@@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, Optional
from pydantic import BaseModel, Field
from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
class CompressionFormat(Enum):
dense = "dense"
......@@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
if layer_name is None:
......@@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in _FUSED_LAYER_NAME_MAPPING:
shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
......
......@@ -9,8 +9,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs
......@@ -18,14 +21,6 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""
......@@ -62,37 +57,10 @@ class FBGEMMFp8Config(QuantizationConfig):
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
def _is_layer_skipped(self, prefix: str) -> bool:
    """Report whether the layer named by ``prefix`` is left unquantized.

    A layer is skipped when its name appears in ``self.ignore_list``. For a
    fused layer (a key of ``_FUSED_LAYER_NAME_MAPPING``), the name of each
    unfused shard is looked up instead and all shards must agree.

    Args:
        prefix: Dotted layer name, e.g. ``model.layers.0.self_attn.q_proj``.

    Returns:
        True if the layer (or every shard of a fused layer) is ignored.

    Raises:
        ValueError: If only some shards of a fused layer are ignored.
    """
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]
    if proj_name not in _FUSED_LAYER_NAME_MAPPING:
        return prefix in self.ignore_list

    # Swap only the trailing component. str.replace() would also rewrite
    # any earlier occurrence of proj_name inside the dotted path.
    parent = prefix[:len(prefix) - len(proj_name)]
    is_skipped = None
    for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]:
        is_shard_skipped = (parent + shard_proj_name) in self.ignore_list
        if is_skipped is None:
            is_skipped = is_shard_skipped
        elif is_shard_skipped != is_skipped:
            raise ValueError(
                f"Detected some but not all shards of {prefix} "
                "are quantized. All shards of fused layers "
                "need to have the same precision.")

    # Only reachable with None if a fused name maps to no shards.
    assert is_skipped is not None
    return is_skipped
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if self._is_layer_skipped(prefix):
if is_layer_skipped(prefix, self.ignore_list):
return UnquantizedLinearMethod()
return FBGEMMFp8LinearMethod(self)
return None
......@@ -105,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
# Set up the linear method with the FBGEMM FP8 config it was created from.
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
# Evaluated once at construction; the stored flag is later passed to
# apply_fp8_linear as its `cutlass_fp8_supported` argument.
self.cutlass_fp8_supported = cutlass_fp8_supported()
def create_weights(
self,
......@@ -172,11 +141,12 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition,
bias=bias)
return apply_fp8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias,
cutlass_fp8_supported=True,
use_per_token_if_dynamic=True)
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)
......@@ -6,17 +6,20 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
fused_moe)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
all_close_1d, apply_fp8_linear, convert_to_channelwise,
create_per_tensor_scale_param, cutlass_fp8_supported,
per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
......@@ -33,6 +36,7 @@ class Fp8Config(QuantizationConfig):
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
......@@ -42,6 +46,7 @@ class Fp8Config(QuantizationConfig):
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
self.ignored_layers = ignored_layers or []
@classmethod
def get_name(cls) -> str:
......@@ -64,14 +69,18 @@ class Fp8Config(QuantizationConfig):
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme)
activation_scheme=activation_scheme,
ignored_layers=ignored_layers)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
......@@ -170,19 +179,29 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.input_scale = None
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else:
# Dequant -> Quant with max scale.
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
......@@ -384,6 +403,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_moe
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
......
......@@ -10,11 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full,
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_gptq_marlin_supported, verify_marlin_supports_shape)
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
......@@ -22,6 +23,12 @@ logger = init_logger(__name__)
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool) -> None:
if desc_act and group_size == -1:
......@@ -29,20 +36,23 @@ class GPTQMarlinConfig(QuantizationConfig):
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.pack_factor = 32 // self.weight_bits # packed into int32
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
f"bits={weight_bits}, sym={is_sym}")
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
# Verify supported on platform.
verify_gptq_marlin_supported(num_bits=self.weight_bits,
group_size=self.group_size,
is_sym=self.is_sym)
verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size)
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized})")
......@@ -79,7 +89,8 @@ class GPTQMarlinConfig(QuantizationConfig):
user_quant) -> Optional[str]:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin")
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "gptq_marlin")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
......@@ -121,11 +132,12 @@ class GPTQMarlinConfig(QuantizationConfig):
or desc_act is None):
return False
return check_gptq_marlin_supported(
num_bits=num_bits,
group_size=group_size,
is_sym=sym,
min_capability=cls.get_min_capability())
if (num_bits, sym) not in cls.TYPE_MAP:
return False
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size,
min_capability=cls.get_min_capability())
class GPTQMarlinLinearMethod(LinearMethodBase):
......@@ -292,7 +304,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.weight_bits)
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from autogptq format to marlin format.
......@@ -318,7 +330,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
......
......@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
......@@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
scalar_types.uint4b8, scalar_types.uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
class GPTQMarlin24Config(QuantizationConfig):
......@@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits: int,
group_size: int,
) -> None:
self.weight_bits = weight_bits
quant_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}.get(weight_bits)
self.group_size = group_size
# Verify
if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
if quant_type is None or \
quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
raise ValueError(
f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
f"Marlin_24 does not support quant_type = {quant_type}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
raise ValueError(
......@@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
"are supported.")
self.quant_type = quant_type
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.weight_bits
self.pack_factor = 32 // self.quant_type.size_bits
# Tile size used by marlin kernels.
self.tile_size = 16
......@@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
self.perm_len = 1024
def __repr__(self) -> str:
return "Marlin24Config(weight_bits={}, group_size={})".format(
self.weight_bits, self.group_size)
return "Marlin24Config(quant_type={}, group_size={})".format(
self.quant_type, self.group_size)
@classmethod
def get_name(cls) -> str:
......@@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace,
self.quant_config.weight_bits,
self.quant_config.quant_type,
size_m, size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
......
......@@ -46,10 +46,8 @@ class BaseKVCacheMethod(QuantizeMethodBase):
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = torch.nn.Parameter(torch.tensor(1.0),
requires_grad=False)
v_scale = torch.nn.Parameter(torch.tensor(1.0),
requires_grad=False)
k_scale = 1.0
v_scale = 1.0
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
......
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
# Tile size used by the QQQ kernels (becomes QQQConfig.tile_size).
MARLIN_QQQ_TILE = 16
# Minimum out_features granularity (becomes QQQConfig.min_n_threads).
MARLIN_QQQ_MIN_THREAD_N = 64
# Minimum in_features granularity (becomes QQQConfig.min_k_threads).
MARLIN_QQQ_MIN_THREAD_K = 128
# Max parallel problems to solve at once (becomes QQQConfig.max_parallel).
MARLIN_QQQ_MAX_PARALLEL = 16
# Values accepted by QQQConfig.__init__; anything else raises ValueError.
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]
class QQQConfig(QuantizationConfig):
    """Quantization settings for QQQ checkpoints.

    Reference: https://arxiv.org/pdf/2406.09904
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        is_sym: bool = True,
    ) -> None:
        """Store the settings and validate them against the supported lists.

        Raises:
            ValueError: If ``weight_bits``, ``group_size`` or ``is_sym`` is
                not one of the values listed in the module-level
                ``MARLIN_QQQ_SUPPORTED_*`` constants.
        """
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.is_sym = is_sym

        # Verify: (setting name, given value, label in message, allowed).
        constraints = (
            ("weight_bits", weight_bits, "weight_bits",
             MARLIN_QQQ_SUPPORTED_NUM_BITS),
            ("group_size", group_size, "group_sizes",
             MARLIN_QQQ_SUPPORTED_GROUP_SIZES),
            ("is_sym", is_sym, "sym", MARLIN_QQQ_SUPPORTED_SYM),
        )
        for setting, given, label, allowed in constraints:
            if given not in allowed:
                raise ValueError(
                    f"QQQ does not support {setting} = {given}. "
                    f"Only {label} = {allowed} are supported.")

        # Number of quantized values packed into one 32 bit word.
        self.pack_factor = 32 // weight_bits

        # Kernel geometry, taken from the module-level constants:
        # tile size, min out_features dim, min in_features dim and the max
        # number of parallel problems (improves large batch performance).
        self.tile_size = MARLIN_QQQ_TILE
        self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N
        self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K
        self.max_parallel = MARLIN_QQQ_MAX_PARALLEL

        # Permutation length used by the QQQ kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return (f"QQQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "qqq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        candidates = ["quant_config.json", "quantize_config.json"]
        return candidates

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
        """Build a config from the ``wbits`` and ``group_size`` entries."""
        bits = cls.get_from_keys(config, ["wbits"])
        group = cls.get_from_keys(config, ["group_size"])
        return cls(bits, group)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QQQLinearMethod"]:
        """Return a QQQ linear method for ``LinearBase`` layers, else None."""
        if not isinstance(layer, LinearBase):
            return None
        return QQQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []
class QQQLinearMethod(LinearMethodBase):
    """Linear method for QQQ.

    Args:
        quant_config: The QQQ quantization config.
    """

    def __init__(self, quant_config: QQQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate the shard shape and register the QQQ tensors on ``layer``.

        Registers, in order, the parameters ``B`` (packed weights),
        ``s_channel``, ``s_group`` and ``workspace``. ``input_size`` and
        ``output_size`` are accepted but not used here.

        Raises:
            ValueError: If ``params_dtype`` is not float16 or the partition
                sizes violate the divisibility constraints of the config.
        """
        cfg = self.quant_config

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        size_k = input_size_per_partition
        size_n = sum(output_partition_sizes)

        def require_divisible(dim_name: str, dim: int, by_name: str,
                              by: int) -> None:
            # Shared shape check; raises on the first violated constraint.
            if dim % by != 0:
                raise ValueError(f"Weight {dim_name} = {dim} "
                                 f"is not divisible by {by_name} = {by}.")

        # Validate output_size_per_partition
        require_divisible("output_size_per_partition", size_n,
                          "min_n_threads", cfg.min_n_threads)
        require_divisible("output_size_per_partition", size_n, "pack_factor",
                          cfg.pack_factor)

        # Validate input_size_per_partition
        require_divisible("input_size_per_partition", size_k,
                          "min_k_threads", cfg.min_k_threads)
        if cfg.group_size != -1:
            require_divisible("input_size_per_partition", size_k,
                              "group_size", cfg.group_size)

        # Check that we have at least 4 tiles horizontally in the shard
        tiles_per_perm = cfg.perm_len // (cfg.tile_size**2)
        if size_n % tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        packed_rows = size_k // cfg.tile_size
        packed_cols = size_n * cfg.tile_size // cfg.pack_factor
        qweight = Parameter(
            torch.empty(packed_rows,
                        packed_cols,
                        device="cuda",
                        dtype=torch.int32),
            requires_grad=False,
        )
        qweight_attrs = {
            "input_dim": 0,
            "output_dim": 1,
            "packed_dim": 1,
            "pack_factor": cfg.pack_factor,
            "marlin_tile_size": cfg.tile_size,
        }
        set_weight_attrs(qweight, qweight_attrs)

        # One float scale per output column.
        s_channel = Parameter(
            torch.empty(1, size_n, device="cuda", dtype=torch.float),
            requires_grad=False,
        )
        set_weight_attrs(s_channel, {"input_dim": None, "output_dim": 1})

        # Group scales: an empty tensor when group_size is -1, otherwise
        # one row of half-precision scales per group.
        if cfg.group_size == -1:
            group_data = torch.tensor([], device="cuda", dtype=torch.half)
            group_attrs = {"input_dim": None, "output_dim": None}
        else:
            group_data = torch.empty(size_k // cfg.group_size,
                                     size_n,
                                     device="cuda",
                                     dtype=torch.half)
            group_attrs = {"input_dim": 0, "output_dim": 1}
        s_group = Parameter(group_data, requires_grad=False)
        set_weight_attrs(s_group, group_attrs)

        # Allocate workspace (Used for internal locking mechanism)
        workspace_len = (size_n // cfg.min_n_threads) * cfg.max_parallel
        workspace = Parameter(
            torch.zeros(workspace_len, device="cuda", dtype=torch.int),
            requires_grad=False,
        )

        # Register each tensor, then attach the caller-supplied attributes.
        registrations = (
            ("B", qweight),
            ("s_channel", s_channel),
            ("s_group", s_group),
            ("workspace", workspace),
        )
        for param_name, param in registrations:
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layer: quantize ``x`` with ``ops.scaled_int8_quant``, call
        ``ops.marlin_qqq_gemm`` and add ``bias`` in place when given.

        The result keeps the leading dimensions of ``x``; its last dimension
        is the column count of the GEMM output.
        """
        # Collapse all leading dimensions so the kernel sees a 2D input.
        flat_x = x.view(-1, x.shape[-1])
        size_m, size_k = flat_x.shape[0], flat_x.shape[1]

        channel_scales = layer.s_channel
        size_n = channel_scales.shape[1]

        x_int8, s_tok = ops.scaled_int8_quant(flat_x)

        flat_out = ops.marlin_qqq_gemm(x_int8, layer.B, s_tok,
                                       channel_scales, layer.s_group,
                                       layer.workspace, size_m, size_n,
                                       size_k)

        out_shape = x.shape[:-1] + (flat_out.shape[1], )
        output = flat_out.view(out_shape)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
......@@ -5,6 +5,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols
......@@ -13,80 +14,78 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
MARLIN_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True
def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: Optional[int],
has_zp: bool) -> Tuple[bool, Optional[str]]:
if min_capability is not None:
# To limit binary size and compile time, the supported types differ between
# the with- and without-runtime-zero-point cases. We support the common
# cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so it's closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
min_capability: Optional[int] = None):
if min_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if device_capability < min_capability:
return (False, "Marlin does not support device_capability = {}"
", the min_capability required is {}".format(
device_capability, min_capability))
if num_bits not in MARLIN_SUPPORTED_NUM_BITS:
return (False, "Marlin does not support weight_bits = {}. "
"Only weight_bits = {} are supported.".format(
num_bits, MARLIN_SUPPORTED_NUM_BITS))
if group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return (False, "Marlin does not support group_size = {}. Only "
"group_sizes = {} are supported.".format(
group_size, MARLIN_SUPPORTED_GROUP_SIZES))
if not has_zp and not is_sym:
return (False,
"Marlin without zero_points must have symmetric quantization")
min_capability = major * 10 + minor
return True, None
if min_capability < 80:
return []
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
# to add `scalar_types.float8_e4m3fn` here
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: int) -> bool:
cond, _ = _check_marlin_supported(num_bits,
group_size,
is_sym,
min_capability,
has_zp=False)
return cond
def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool,
min_capability: int) -> bool:
cond, _ = _check_marlin_supported(num_bits,
group_size,
False,
min_capability,
has_zp=has_zp)
return cond
if min_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
supported_types = query_marlin_supported_quant_types(
has_zp, min_capability)
def verify_gptq_marlin_supported(num_bits: int, group_size: int,
is_sym: bool) -> None:
cond, err_msg = _check_marlin_supported(num_bits,
group_size,
is_sym,
min_capability=None,
has_zp=False)
if not cond:
assert err_msg is not None
raise ValueError("GPTQ" + err_msg)
if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"min_capability = {min_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
return (False, f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True, None
def check_marlin_supported(quant_type: ScalarType,
                           group_size: int,
                           has_zp: bool = False,
                           min_capability: Optional[int] = None) -> bool:
    """Return whether Marlin supports the given quantization setup.

    Thin boolean wrapper around ``_check_marlin_supported``: the first
    element of its result is returned and the accompanying error message is
    discarded.
    """
    is_supported, _reason = _check_marlin_supported(quant_type, group_size,
                                                    has_zp, min_capability)
    return is_supported
def verify_awq_marlin_supported(num_bits: int, group_size: int,
has_zp: bool) -> None:
cond, err_msg = _check_marlin_supported(num_bits,
group_size,
False,
min_capability=None,
has_zp=has_zp)
def verify_marlin_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False) -> None:
cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
if not cond:
assert err_msg is not None
raise ValueError("AWQ" + err_msg)
raise ValueError(err_msg)
def verify_marlin_supports_shape(output_size_per_partition: int,
......@@ -240,11 +239,12 @@ def apply_gptq_marlin_linear(
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
wtype: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
......@@ -255,12 +255,13 @@ def apply_gptq_marlin_linear(
g_idx,
g_idx_sort_indices,
workspace,
num_bits,
wtype,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False)
has_zp=False,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
......@@ -276,10 +277,11 @@ def apply_awq_marlin_linear(
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
......@@ -290,12 +292,13 @@ def apply_awq_marlin_linear(
g_idx,
g_idx_sort_indices,
workspace,
num_bits,
quant_type,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True)
has_zp=True,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
......
......@@ -46,7 +46,8 @@ def apply_fp8_marlin_linear(
return output.reshape(out_shape)
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
strategy: str = "tensor") -> None:
print_warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
......@@ -74,16 +75,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
is_channelwise = (len(layer.weight_scale.shape) > 0
and layer.weight_scale.shape[0] == part_size_n)
if is_channelwise:
scales = layer.weight_scale
else:
scales = layer.weight_scale.repeat(1, part_size_n)
scales = scales.to(layer.orig_dtype).to(device)
scales = layer.weight_scale.to(layer.orig_dtype)
# Permute scales
marlin_scales = marlin_permute_scales(s=scales,
size_k=part_size_k,
......
......@@ -5,10 +5,12 @@ from typing import List
import numpy as np
import torch
from vllm.scalar_type import ScalarType
from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales,
marlin_zero_points)
from .quant_utils import (get_pack_factor, quantize_weights,
quantize_weights_with_zp, sort_weights)
from .quant_utils import (get_pack_factor, gptq_quantize_weights,
quantize_weights, sort_weights)
class MarlinWorkspace:
......@@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int):
return perm
def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
act_order: bool):
size_k, size_n = w.shape
num_bits = quant_type.size_bits
# Normalize group_size
if group_size == -1:
......@@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
assert group_size <= size_k
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
act_order)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
w, quant_type, group_size, act_order)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
......@@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
return res_list
def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType,
group_size: int):
size_k, size_n = w.shape
# Normalize group_size
......@@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
num_groups = size_k // group_size
# Quantize with zp
w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size)
w_ref, q_w, s, zp = quantize_weights(w,
quant_type,
group_size,
zero_points=True)
# Reformat to marlin
weight_perm = get_weight_perm(num_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits)
marlin_zp = marlin_zero_points(zp, num_groups, size_n,
quant_type.size_bits)
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
......
......@@ -6,8 +6,10 @@ from typing import List
import numpy
import torch
from vllm.scalar_type import ScalarType
from .marlin_utils_test import marlin_weights
from .quant_utils import quantize_weights
from .quant_utils import gptq_quantize_weights
# This is PyTorch implementation of main part of reorder_meta()
......@@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False):
print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
assert q_24.shape == (size_k, size_n)
# Remove zp to normalize over 0
max_q_val = (1 << num_bits) - 1
zp = (max_q_val + 1) // 2
q_24_no_zp = q_24 - zp
# Remove bias to normalize over 0
q_24_no_zp = q_24 - wtype.bias
# Compress
q_24_no_zp = q_24_no_zp.t().contiguous()
......@@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
q_24_no_zp)
q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
# Restore zp
q_24_comp = q_24_no_zp_comp + zp
# Restore bias
q_24_comp = q_24_no_zp_comp + wtype.bias
# Resize meta to its actual shape (without moving any data)
meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
......@@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
def marlin_24_quantize(
w: torch.Tensor,
num_bits: int,
quant_type: ScalarType,
group_size: int,
):
size_k, size_n = w.shape
......@@ -441,20 +441,18 @@ def marlin_24_quantize(
w_24, mask_24 = inject_24(w, size_k, size_n)
# Quantize
w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
num_bits,
group_size,
act_order=False)
w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
w_24, quant_type, group_size, act_order=False)
# Compress quantized weight
q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
num_bits)
quant_type)
size_k_comp = size_k // 2
# Reformat to marlin
weight_perm = get_weight_perm_24(num_bits)
weight_perm = get_weight_perm_24(quant_type.size_bits)
marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
num_bits, weight_perm)
quant_type.size_bits, weight_perm)
marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
# Create result
......
from typing import List
import numpy
import torch
from .marlin_utils_test import marlin_permute_weights
from .quant_utils import get_pack_factor, qqq_quantize_weights
def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    """Permute quantized weights with ``perm`` and pack them into 32-bit words.

    Args:
        q_w: Quantized weight tensor; forwarded to ``marlin_permute_weights``
            together with ``size_k``, ``size_n`` and ``perm``.
        size_k: Size of the K dimension of the weight.
        size_n: Size of the N dimension of the weight.
        num_bits: Bit width of one quantized value; ``32 // num_bits`` values
            are packed into each output word.
        perm: Permutation handed to ``marlin_permute_weights``.
        group_size: Quantization group size. ``group_size == size_k`` selects
            the masked packing path.

    Returns:
        An int32 tensor on the device of the permuted ``q_w`` whose second
        dimension is ``pack_factor`` times smaller than the permuted input's.
    """
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)
    if group_size == size_k:
        # Keep only the low `num_bits` bits of every value before shifting so
        # that set high bits cannot spill into neighbouring fields.
        # NOTE(review): presumably needed because per-channel values are
        # signed and wrap around in the uint32 cast above -- confirm.
        # The mask is derived from num_bits (it equals 0xF for 4 bits).
        value_mask = (1 << num_bits) - 1
        for i in range(pack_factor):
            q_packed |= (q_w[:, i::pack_factor] & value_mask) << num_bits * i
    else:
        for i in range(pack_factor):
            q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed
def get_qqq_scale_perms():
    """Build the two index permutations used to reorder QQQ scales.

    Returns:
        A ``(scale_perm, scale_perm_single)`` tuple of integer lists.
        ``scale_perm`` has 64 entries and ``scale_perm_single`` has 32.
    """
    # For each i in 0..7, take the stride-8 sequence i, i + 8, ..., i + 56.
    scale_perm: List[int] = [i + 8 * j for i in range(8) for j in range(8)]

    # For each i in 0..3, shift a fixed set of eight offsets by 2 * i.
    offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    scale_perm_single: List[int] = [
        2 * i + j for i in range(4) for j in offsets
    ]
    return scale_perm, scale_perm_single
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    """Return the weight permutation for the given QQQ quantization scheme.

    Args:
        num_bits: Bit width of the quantized weights; only 4 is supported.
        quant_type: Either ``"per-channel"`` or ``"per-group"``; selects the
            interleave pattern applied to each run of eight indices.

    Returns:
        A 1-D ``torch.Tensor`` holding a permutation of ``range(1024)``.

    Raises:
        AssertionError: If ``quant_type`` is not one of the supported names.
        ValueError: If ``num_bits`` is not 4.
    """
    # Validate the arguments up front so no work is done for bad input.
    assert quant_type in ["per-channel",
                          "per-group"], "not supported quantization type"
    if num_bits != 4:
        raise ValueError("num_bits must be 4, got {}".format(num_bits))

    perm_list: List[int] = []
    for i in range(32):
        col = i // 4
        first_row = 4 * (i % 4)
        # Eight indices: four consecutive rows for block 0, then for block 1.
        perm1: List[int] = [
            16 * row + col + 8 * block for block in (0, 1)
            for row in range(first_row, first_row + 4)
        ]
        # Repeat the pattern at offsets 0, 256, 512 and 768.
        for j in range(4):
            perm_list.extend(p + 256 * j for p in perm1)

    perm = numpy.array(perm_list)

    if quant_type == "per-channel":
        interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
    else:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])

    # Reorder every run of eight indices according to `interleave`.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    """Reorder QQQ group and channel scales with the QQQ scale permutations.

    ``s_channel`` is always permuted. ``s_group`` is permuted only when
    ``group_size`` is neither ``-1`` nor as large as ``size_k``; otherwise it
    is returned untouched.

    Returns:
        The ``(s_group, s_channel)`` pair; every permuted tensor is reshaped
        to ``(-1, size_n)`` and made contiguous.
    """
    group_perm, channel_perm = get_qqq_scale_perms()

    uses_group_scales = group_size != -1 and group_size < size_k
    if uses_group_scales:
        permuted_group = s_group.reshape((-1, len(group_perm)))[:, group_perm]
        s_group = permuted_group.reshape((-1, size_n)).contiguous()

    # The channel scales get the same treatment in both configurations.
    permuted_channel = s_channel.reshape(
        (-1, len(channel_perm)))[:, channel_perm]
    s_channel = permuted_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel
def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Quantize ``w`` with QQQ and convert the result to the Marlin-QQQ layout.

    Args:
        w: 2-D weight tensor; its shape is read as ``(size_k, size_n)``.
        num_bits: Bit width forwarded to the quantization and packing helpers.
        group_size: Quantization group size; ``-1`` means one group spanning
            the whole K dimension.

    Returns:
        A list ``[w_ref, packed_weights, group_scales, channel_scales]`` with
        every tensor moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # Normalize group_size: -1 stands for the full K dimension.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size == size_k:
        quant_type = "per-channel"
    else:
        quant_type = "per-group"

    # Quantize
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)

    # Reformat to marlin_qqq
    perm = get_qqq_weight_perm(num_bits, quant_type)
    packed_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm,
                                    group_size)
    permuted_s_group, permuted_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    # Create result on the device of the input weight.
    outputs = (w_ref, packed_q_w, permuted_s_group, permuted_s_channel)
    return [tensor.to(w.device) for tensor in outputs]
"""This file is used for /tests and /benchmarks"""
from typing import List
import numpy
import torch
SUPPORTED_NUM_BITS = [4, 8]
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    """Tell whether the layer named ``prefix`` is listed in ``ignored_layers``.

    For a fused layer (last path component is a key of
    ``FUSED_LAYER_NAME_MAPPING``) the individual shard names are looked up
    instead of the fused name, and all shards must agree.

    Args:
        prefix: Dotted layer name, e.g. ``model.layers.0.self_attn.q_proj``.
        ignored_layers: Names of the layers that are skipped.

    Returns:
        ``True`` if the layer (or every shard of a fused layer) is ignored,
        ``False`` if it (or every shard) is not.

    Raises:
        ValueError: If only some shards of a fused layer are ignored.
    """
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]
    shard_proj_names = FUSED_LAYER_NAME_MAPPING.get(proj_name)
    if shard_proj_names is None:
        return prefix in ignored_layers

    # Swap only the final path component for each shard name. A plain
    # str.replace would also rewrite earlier components that happen to
    # contain the same text.
    stem = prefix[:len(prefix) - len(proj_name)]
    shard_states = {(stem + shard_proj_name) in ignored_layers
                    for shard_proj_name in shard_proj_names}
    if len(shard_states) > 1:
        raise ValueError(
            f"Detected some but not all shards of {prefix} "
            "are quantized. All shards of fused layers "
            "to have the same precision.")
    return shard_states.pop()
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit into one 32-bit word.

    Asserts that ``num_bits`` is listed in ``SUPPORTED_NUM_BITS`` and that it
    divides 32 evenly.
    """
    unsupported = f"Unsupported num_bits = {num_bits}"
    assert num_bits in SUPPORTED_NUM_BITS, unsupported
    values_per_word, leftover = divmod(32, num_bits)
    assert leftover == 0, unsupported
    return values_per_word
......@@ -36,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
)
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
act_order: bool):
def quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
zero_points: bool = False):
assert quant_type.is_integer(), \
"Floating point quantization may work but has not been tested"
orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
if group_size == -1:
group_size = size_k
assert group_size <= size_k
max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2
# Reshape to [groupsize, -1]
if group_size < size_k:
w = w.reshape((-1, group_size, size_n))
......@@ -61,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w = w.reshape((group_size, -1))
# Compute scale for each group
s = torch.max(torch.abs(w), 0, keepdim=True)[0]
s *= 2 / max_q_val # 2 => symmetric
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
max_q_val = quant_type.max()
min_q_val = quant_type.min()
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
.clamp(min_q_val, max_q_val).int()
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
maybe_w_zp = None
# Quantize
q_w = torch.round(w / s).int()
q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val)
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias
# Restore original shapes
if group_size < size_k:
......@@ -81,10 +140,35 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w = w.reshape((size_k, size_n)).contiguous()
return w
q_w = reshape_w(q_w)
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
s = s.reshape((-1, size_n)).contiguous()
w_s = w_s.reshape((-1, size_n)).contiguous()
if zero_points:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s.to(device=orig_device),
maybe_w_zp,
)
def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
group_size: int, act_order: bool):
size_k, _ = w.shape
assert w.is_floating_point(), "w must be float"
assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
f"Unsupported gptq type = {quant_type}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
# Apply act_order
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
......@@ -95,23 +179,20 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k)
w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s.to(device=orig_device),
g_idx.to(device=orig_device),
rand_perm.to(device=orig_device),
)
return w_ref, w_q, w_s, g_idx, rand_perm
def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
orig_device = w.device
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
f"Unsupported num_bits = {num_bits}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
......@@ -120,33 +201,27 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
group_size = size_k
assert group_size <= size_k
max_q_val = 2**num_bits - 1
min_q_val = 0
# Reshape to [groupsize, -1]
if group_size < size_k:
# Reshape to [groupsize, -1]
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
# Compute scale for each group
max = torch.max(w, 0, keepdim=True)[0]
min = torch.min(w, 0, keepdim=True)[0]
s = (max - min).clamp(min=1e-5) / max_q_val
# Compute zero-point for each group
zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int()
# Quantize
q_w = torch.round(w / s).int() + zp
q_w = torch.clamp(q_w, min_q_val, max_q_val)
max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2
# Compute ref (dequantized)
w_ref = (q_w - zp).half() * s
# Compute scale for each group
s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
s_group *= 2 / max_q_val # 2 => symmetric
# Restore original shapes
if group_size < size_k:
# Quantize
q_w = torch.round(w / s_group).int()
q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s_group
# Restore original shapes
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
......@@ -156,14 +231,39 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)
s = s.reshape((-1, size_n)).contiguous()
zp = zp.reshape((-1, size_n)).contiguous()
# Compute int8 quantization scale for each channel
s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
s_channel /= 127.0
t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
w_ref = t_int8.half() * s_channel
s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
# Fuse scales
s_group = (s_group.reshape(-1, size_n).contiguous() /
s_channel).to(dtype=torch.half)
else:
max_q_val = 2**(num_bits - 1) - 1
# Compute scale for each channel
s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
s_channel /= max_q_val
# Quantize
q_w = torch.round(w / s_channel).int()
q_w = torch.clamp(q_w, -max_q_val, max_q_val)
# Compute ref (dequantized)
w_ref = q_w.half() * s_channel
s_group = torch.tensor([], dtype=torch.half)
# div 2 ** (8 - self.bits)) to offset right shift in unpacking
s_channel /= (2**(8 - num_bits))
s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s.to(device=orig_device),
zp.to(device=orig_device),
s_group.to(device=orig_device),
s_channel.to(device=orig_device),
)
......
......@@ -139,7 +139,7 @@ def apply_fp8_linear(
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_scale,
batch_dim_padding=17,
num_token_padding=17,
use_per_token_if_dynamic=use_per_token_if_dynamic)
per_tensor_weights = (weight_scale.numel() == 1)
......@@ -177,8 +177,9 @@ def apply_fp8_linear(
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=torch.float32)
# Unpad (undo batch_dim_padding)
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
......
from functools import cached_property
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
import torch.jit
......@@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
......@@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
probabilities.
shape = [batch_size, num_speculative_tokens]
seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
......@@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs,
draft_probs,
draft_token_ids,
generators,
seeded_seqs,
))
output_token_ids = self._create_output(
......@@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
......@@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids, generators)
draft_token_ids, seeded_seqs)
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators, k=k)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
generators=generators,
seed_indices=seed_indices,
# this arg is unused when None but torch.jit requires a list
non_seed_indices=non_seed_indices or [],
seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)
return accepted, recovered_token_ids
......@@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
......@@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]
seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators)
if len(seed_indices) == 0:
if not seeded_seqs:
uniform_rand = torch.rand_like(selected_target_probs)
else:
uniform_rand = torch.empty_like(selected_target_probs)
for idx in seed_indices:
uniform_rand[idx, :] = torch.rand(1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generators[idx])
if non_seed_indices:
uniform_rand[non_seed_indices, :] = torch.rand(
len(non_seed_indices),
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(
1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k,
dtype=self.probs_dtype,
device=target_probs.device)
......@@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
return torch.finfo(self.probs_dtype).tiny
# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@staticmethod
def _split_batch_by_seeded(
generators: List[Optional[torch.Generator]],
k: int = 1,
) -> Tuple[List[int], Optional[List[int]]]:
if all(generator is None for generator in generators):
seed_indices: List[int] = []
non_seed_indices: Optional[List[int]] = None
else:
seed_indices, non_seed_indices = [], []
for i, generator in enumerate(generators):
if generator is None:
non_seed_indices.extend(range(k * i, k * (i + 1)))
else:
seed_indices.extend(range(k * i, k * (i + 1)))
return seed_indices, non_seed_indices
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
......@@ -304,9 +282,7 @@ def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
generators: List[Optional[torch.Generator]],
seed_indices: List[int],
non_seed_indices: List[int],
seeded_seqs: Dict[int, torch.Generator],
) -> torch.Tensor:
if num_samples > 1:
......@@ -315,13 +291,20 @@ def _multinomial(
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs)
if len(seed_indices) == 0:
if not seeded_seqs:
q.exponential_(1.0)
else:
q[non_seed_indices].exponential_(1.0)
for idx in seed_indices:
q[idx].exponential_(1.0, generator=generators[idx // k])
non_seeded_indices: List[int] = []
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.extend(list(range(start, end)))
else:
q[start:end].exponential_(1.0, generator=generator)
start = end
q[non_seeded_indices].exponential_(1.0)
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
......@@ -774,6 +774,7 @@ def get_rope(
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
rotary_percent: float = 1.0,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
......@@ -786,6 +787,8 @@ def get_rope(
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if rotary_percent < 1.0:
rotary_dim = int(rotary_dim * rotary_percent)
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args, dtype)
if key in _ROPE_DICT:
......
"""A layer that samples the next tokens from the model's outputs."""
import itertools
from math import inf
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.layers.ops.sample import sample as sample_triton
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
......@@ -220,7 +225,7 @@ def _apply_min_tokens_penalty(
seqs_to_penalize: List[int] = []
for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
if len(seq_data.output_token_ids_array) < min_tokens:
seqs_to_penalize.append(j)
if seqs_to_penalize:
......@@ -774,8 +779,11 @@ def _get_logprobs(
# The next token ids to get the logprob value from.
next_token_ids: List[int] = []
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
largest_num_logprobs = 1
# largest num logprobs in this API. If every logprobs is None, it will be
# set to -1.
largest_num_logprobs = -1
# If beam search is enabled.
use_beam_search = False
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
......@@ -808,6 +816,8 @@ def _get_logprobs(
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)
use_beam_search = use_beam_search or sampling_params.use_beam_search
assert len(next_token_ids) == len(query_indices)
if len(query_indices) == 0:
......@@ -815,35 +825,40 @@ def _get_logprobs(
empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob]
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs = logprobs[[
query_indices_gpu,
next_token_ids_gpu,
]]
ranks = _get_ranks(
logprobs[query_indices_gpu],
next_token_ids_gpu,
)
assert selected_logprobs.shape[0] == ranks.shape[0]
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
dim=-1)
else:
top_logprobs, top_token_ids = None, None
selected_logprobs, ranks = None, None
top_logprobs, top_token_ids = None, None
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# skip the whole logprob calculation.
if largest_num_logprobs >= 0 or use_beam_search:
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids,
device=logprobs.device)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs = logprobs[[
query_indices_gpu,
next_token_ids_gpu,
]]
ranks = _get_ranks(
logprobs[query_indices_gpu],
next_token_ids_gpu,
)
assert selected_logprobs.shape[0] == ranks.shape[0]
selected_logprobs = selected_logprobs.to('cpu')
ranks = ranks.to('cpu')
if top_logprobs is not None and top_token_ids is not None:
top_logprobs = top_logprobs.to('cpu')
top_token_ids = top_token_ids.to('cpu')
# We need to compute top k only if there exists logprobs > 0.
if largest_num_logprobs > 0:
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
dim=-1)
top_logprobs = top_logprobs.to('cpu')
top_token_ids = top_token_ids.to('cpu')
selected_logprobs = selected_logprobs.to('cpu')
ranks = ranks.to('cpu')
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
......@@ -940,46 +955,53 @@ def _get_sampled_logprob_if_needed(
):
"""Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs or 0
num_logprobs = seq_group.sampling_params.logprobs
use_beam_search = seq_group.sampling_params.use_beam_search
sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample:
assert len(next_token_ids) > 0
# Pre-select items from tensor. tolist() is faster than repetitive
# `.item()` calls.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
for idx, (next_token_id,
parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
# Get the logprob of a sampled token.
sampled_logprobs_dict = {
next_token_id: (selected_logprob_items[idx], rank_items[idx])
}
# Get top K logprobs.
if num_logprobs > 0:
top_ids = top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist()
top_probs = top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks = range(1, num_logprobs + 1)
sampled_logprobs_dict.update({
top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(top_ids, top_probs,
top_ranks)
if num_logprobs is None and not use_beam_search:
for next_token_id in next_token_ids:
# Use a dummy logprob
sampled_logprobs.append({next_token_id: Logprob(inf)})
else:
# Pre-select items from tensor. tolist() is faster than repetitive
# `.item()` calls.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
for idx, (next_token_id, parent_id) in enumerate(
zip(next_token_ids, parent_seq_ids)):
# Get the logprob of a sampled token.
sampled_logprobs_dict = {
next_token_id:
(selected_logprob_items[idx], rank_items[idx])
}
if num_logprobs is not None and num_logprobs > 0:
# Get top K logprobs.
top_ids = top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist()
top_probs = top_logprobs[
top_logprob_idx + parent_id, :num_logprobs].tolist()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks = range(1, num_logprobs + 1)
sampled_logprobs_dict.update({
top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(
top_ids, top_probs, top_ranks)
})
sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in
sampled_logprobs_dict.items()
})
sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in
sampled_logprobs_dict.items()
})
# NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
......
from abc import abstractmethod
from typing import List, Optional
from typing import Dict, Optional
import torch
import torch.jit
......@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
raise NotImplementedError
......@@ -7,6 +7,7 @@ import json
import math
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
import huggingface_hub
......@@ -37,7 +38,49 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_tpu
from vllm.utils import is_pin_memory_available, is_tpu
@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily place a module's CPU parameters on ``target_device``.

    On entry, every parameter of ``module`` that currently lives on the CPU
    has its data moved to ``target_device``. On exit, those same parameters
    (matched by name) are copied back into freshly allocated CPU storage,
    pinned when ``is_pin_memory_available()`` says so. Parameters that were
    not on the CPU at entry, and parameters that only exist at exit, are left
    alone. When ``target_device`` is the CPU nothing is moved at all.

    Yields:
        ``module`` itself.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Names of the parameters found on the CPU at entry. Only these are
    # recorded, so every recorded parameter has to go back to the CPU.
    cpu_param_names = set()
    for name, param in module.named_parameters():
        if param.device.type != "cpu":
            # Parameters already on an accelerator are not touched.
            continue
        cpu_param_names.add(name)
        param.data = param.data.to(target_device)

    try:
        yield module

    finally:
        # Restore parameters to the CPU, ignoring new parameters.
        pin_memory = is_pin_memory_available()
        for name, param in module.named_parameters():
            if name not in cpu_param_names:
                continue
            # `torch.empty_like` does not support `pin_memory` argument
            cpu_data = torch.empty_strided(size=param.data.size(),
                                           stride=param.data.stride(),
                                           dtype=param.data.dtype,
                                           layout=param.data.layout,
                                           device="cpu",
                                           pin_memory=pin_memory)
            cpu_data.copy_(param.data)
            param.data = cpu_data
logger = init_logger(__name__)
......@@ -164,7 +207,7 @@ class DefaultModelLoader(BaseModelLoader):
cache_dir=self.load_config.download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
ignore_patterns=self.load_config.ignore_patterns,
ignore_file_pattern=self.load_config.ignore_patterns,
)
else:
model_path = model
......@@ -278,8 +321,9 @@ class DefaultModelLoader(BaseModelLoader):
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config, scheduler_config)
......@@ -294,7 +338,13 @@ class DefaultModelLoader(BaseModelLoader):
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None and quant_method!="awq" and quant_method!="gptq":
quant_method.process_weights_after_loading(module)
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model.eval()
......@@ -705,8 +755,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return hf_weights_files, matched_pattern == "*.safetensors"
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
    """Return a fresh weight iterator over ``hf_weights_files``.

    The safetensors iterator is used when ``use_safetensors`` is true;
    otherwise the files are read with the PyTorch-checkpoint iterator.
    """
    # Select the reader first, then invoke it once on the file list.
    iterator_factory = (safetensors_weights_iterator
                        if use_safetensors else pt_weights_iterator)
    return iterator_factory(hf_weights_files)
def _get_quantized_weights_iterator(
self, model_name_or_path: str, revision: Optional[str]
self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
......@@ -715,6 +771,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
try:
import bitsandbytes
from bitsandbytes.functional import QuantState
if bitsandbytes.__version__ < "0.42.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.")
......@@ -728,17 +785,63 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model_name_or_path, revision)
quant_state_dict = {}
if use_safetensors:
weight_iterator = safetensors_weights_iterator(hf_weights_files)
else:
weight_iterator = pt_weights_iterator(hf_weights_files)
def generator():
def quantized_checkpoint() -> Generator:
    """Yield ``(name, tensor)`` pairs from a pre-quantized checkpoint.

    The checkpoint files are read twice:

    1. The first pass collects every tensor whose name does not end in
       ``.weight`` (the quantization metadata) into ``temp_state_dict``.
    2. The second pass yields only tensors whose name ends in ``.weight``.
       A weight that has a matching ``.quant_state.bitsandbytes__nf4``
       entry is renamed from ``.weight`` to ``.qweight`` and its parsed
       quant state is stored in the enclosing ``quant_state_dict`` under
       the renamed key; every other weight is yielded unchanged.

    Because this is a generator, ``quant_state_dict`` is only filled in as
    the returned iterator is consumed.

    Raises:
        NotImplementedError: if the checkpoint contains fp4 quant state.
    """
    nf4_suffix = ".quant_state.bitsandbytes__nf4"
    fp4_suffix = ".quant_state.bitsandbytes__fp4"

    # First iterate over all quant state weights
    temp_state_dict = {}
    for weight_name, weight_tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors):
        if weight_name.endswith(".weight"):
            continue
        # TODO: only nf4 quantization is supported for now
        if weight_name.endswith(fp4_suffix):
            # Note the trailing space: the two literals are concatenated.
            raise NotImplementedError(
                "Only bitsandbytes_nf4 quantization "
                f"is supported for now. {weight_name} is fp4 quantized")
        temp_state_dict[weight_name] = weight_tensor

    # Closure to parse quant_state for each prequant weight
    def _parse_quant_state(param_name: str,
                           temp_state_dict: Dict) -> QuantState:
        """Build the QuantState for ``param_name`` from collected tensors."""
        # Gather every metadata tensor whose name contains "<param_name>.".
        quant_state = {
            key: tensor
            for key, tensor in temp_state_dict.items()
            if param_name + "." in key
        }
        # bitsandbytes library requires
        # weight.quant_state.bitsandbytes__nf4 in CPU
        nf4_key = param_name + nf4_suffix
        quant_state[nf4_key] = quant_state[nf4_key].cpu().data
        return QuantState.from_dict(quant_state, device="cuda")

    # Second iterate over all prequant and normal weights
    # pre quantized weights would have a quant_state
    for weight_name, weight_tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors):
        # Filter out all weights whose suffix is not ".weight"
        if not weight_name.endswith(".weight"):
            continue
        if weight_name + nf4_suffix in temp_state_dict:
            quant_state = _parse_quant_state(weight_name, temp_state_dict)
            # Rename once; the same key is used for the quant-state entry
            # and for the yielded tensor.
            weight_name = weight_name.replace(".weight", ".qweight")
            quant_state_dict[weight_name] = quant_state
            yield weight_name, weight_tensor
        else:
            yield weight_name, weight_tensor
def generator() -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")
# bitsandbytes requires data in GPU
# bitsandbytes requires data in GPU
loaded_weight = weight_tensor.cuda().data
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
......@@ -752,6 +855,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield weight_name, processed_weight
if pre_quant:
return quantized_checkpoint(), quant_state_dict
return generator(), quant_state_dict
def _load_weights(self, model_config: ModelConfig,
......@@ -769,12 +874,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
logger.info("Loading weights with BitsAndBytes quantization. "
" May take a while ...")
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(model_config.model,
model_config.revision))
is_quantized_checkpoint = False
quant_config = getattr(model_config.hf_config, "quantization_config",
None)
if quant_config is not None and quant_config.get(
'quant_method') == "bitsandbytes":
is_quantized_checkpoint = True
qweight_iterator, quant_state_dict = \
self._get_quantized_weights_iterator(
model_config.model, model_config.revision, is_quantized_checkpoint)
model.load_weights(qweight_iterator)
torch.cuda.empty_cache()
param_dict = dict(model.named_parameters())
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
for quant_param_name in quant_state_dict:
......@@ -812,9 +926,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f"pack_factor not set for parameter {param_name}.")
num_elements = [0] * len(quant_states)
for seq, quant_state in enumerate(quant_states.items()):
for seq, quant_state in quant_states.items():
num_elements[seq] = math.prod(
quant_state[1].shape) // pack_ratio
quant_state.shape) // pack_ratio
offsets = np.concatenate(([0], np.cumsum(num_elements)))
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
......
......@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
logger = init_logger(__name__)
......@@ -118,6 +119,7 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig,
load_config: LoadConfig) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
......@@ -489,6 +491,11 @@ def initialize_dummy_weights(
"""
for param in model.state_dict().values():
if torch.is_floating_point(param):
if current_platform.is_tpu():
# XLA device does not support torch.Generator()
param.uniform_(low, high)
continue
generator = torch.Generator(device=param.data.device)
generator.manual_seed(seed)
if torch.finfo(param.data.dtype).bits < 16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment