Unverified Commit 9f771b3a authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Quantization] add humming quantization kernel (#34556)

parent c9d3c6e6
......@@ -953,8 +953,12 @@ class ModelConfig:
"mxfp4",
"gpt_oss_mxfp4",
"cpu_awq",
"humming",
"gguf",
]
# if the user specifies humming, we should always use humming
if self.quantization == "humming":
overrides = ["humming"] + overrides
quantization_methods = [
q for q in supported_quantization if q not in overrides
]
......
......@@ -152,6 +152,10 @@ if TYPE_CHECKING:
VLLM_RAY_EXTRA_ENV_VARS_TO_COPY: str = ""
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None
VLLM_HUMMING_ONLINE_QUANT_CONFIG: dict[str, Any] | None = None
VLLM_HUMMING_INPUT_QUANT_CONFIG: dict[str, Any] | None = None
VLLM_HUMMING_USE_F16_ACCUM: bool = False
VLLM_HUMMING_MOE_GEMM_TYPE: Literal["indexed", "grouped", "auto"] | None = None
VLLM_MXFP4_USE_MARLIN: bool | None = None
VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
......@@ -285,6 +289,15 @@ def maybe_convert_bool(value: str | None) -> bool | None:
return bool(int(value))
def maybe_convert_json_str_or_file(value: str | None) -> dict[str, Any] | None:
if value is None:
return None
if os.path.exists(value):
with open(value) as f:
return json.load(f)
return json.loads(value)
def disable_compile_cache() -> bool:
    """Report whether ``VLLM_DISABLE_COMPILE_CACHE`` is set to a non-zero integer.

    The variable is read from the process environment on every call and
    defaults to ``"0"`` when unset. A value that is not an integer literal
    makes ``int()`` raise ``ValueError``.
    """
    raw_flag = os.environ.get("VLLM_DISABLE_COMPILE_CACHE", "0")
    return int(raw_flag) != 0
......@@ -1193,6 +1206,25 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MARLIN_INPUT_DTYPE": env_with_choices(
"VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"]
),
# The online quantization dtype for humming kernel
"VLLM_HUMMING_ONLINE_QUANT_CONFIG": lambda: maybe_convert_json_str_or_file(
os.environ.get("VLLM_HUMMING_ONLINE_QUANT_CONFIG", None)
),
# The activation dtype config for humming kernel
"VLLM_HUMMING_INPUT_QUANT_CONFIG": lambda: maybe_convert_json_str_or_file(
os.environ.get("VLLM_HUMMING_INPUT_QUANT_CONFIG", None)
),
# Whether to use fp16 accumulator mma
"VLLM_HUMMING_USE_F16_ACCUM": lambda: maybe_convert_bool(
os.environ.get("VLLM_HUMMING_USE_F16_ACCUM", "0")
),
# Which gemm type to use for humming moe
# if 1, force the use of indexed gemm
# if 0, force the use of grouped gemm
# if unset (None), choose the better gemm type automatically
"VLLM_HUMMING_MOE_GEMM_TYPE": lambda: maybe_convert_bool(
os.environ.get("VLLM_HUMMING_MOE_GEMM_TYPE", None)
),
# Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method
# only supported on Blackwell GPUs and with
# https://github.com/deepseek-ai/DeepEP/pull/341
......
This diff is collapsed.
......@@ -1097,7 +1097,11 @@ class FusedMoE(PluggableLayer):
expert_id: int,
return_success: bool = False,
) -> bool | None:
if self.quant_config and self.quant_config.get_name() == "gpt_oss_mxfp4":
quant_config_name = self.quant_config and self.quant_config.get_name()
if quant_config_name == "humming":
assert hasattr(self.quant_method, "weight_schema")
quant_config_name = self.quant_method.weight_schema.quant_method
if quant_config_name == "gpt_oss_mxfp4":
# (FIXME) for gpt-oss all experts are combined
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch._subclasses.fake_tensor import FakeTensor
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@triton.jit
def moe_fused_mul_sum_kernel(
    inputs_ptr,
    topk_weights_ptr,
    outputs_ptr,
    top_ids_ptr,
    expert_map_ptr,
    num_tokens,
    stride_m,
    has_expert_map: tl.constexpr,
    top_k: tl.constexpr,
    size: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Weighted sum over the ``top_k`` slices of each row of ``inputs``.

    Each program handles a ``BLOCK_M x BLOCK_K`` tile: axis 0 of the launch
    grid walks the ``size`` dimension, axis 1 walks the rows (tokens).
    For every row ``m`` and column ``k`` in the tile it computes

        out[m, k] = sum_n inputs[m * stride_m + n * size + k] * weights[m, n]

    accumulating in float32 and casting to the output element type on store.

    Args:
        inputs_ptr: Base pointer of the input; row ``m`` starts at
            ``m * stride_m`` and slice ``n`` of that row at ``+ n * size``.
        topk_weights_ptr: Base pointer of the weights, addressed as
            ``m * top_k + n``.
        outputs_ptr: Base pointer of the output, addressed as
            ``m * size + k``.
        top_ids_ptr: Base pointer of the expert ids, addressed as
            ``m * top_k + n``. Only read when ``has_expert_map`` is true.
        expert_map_ptr: Lookup table indexed by expert id. Only read when
            ``has_expert_map`` is true; an entry ``< 0`` drops that
            (row, slice) contribution.
        num_tokens: Number of valid rows; rows at or beyond it are masked.
        stride_m: Element stride between consecutive rows of the input.
        has_expert_map: Compile-time switch enabling the expert-map filter.
        top_k: Compile-time number of slices summed per row.
        size: Compile-time number of columns per slice.
        BLOCK_M: Compile-time tile height (rows per program).
        BLOCK_K: Compile-time tile width (columns per program).
    """
    pid_k = tl.program_id(0)
    pid_m = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    # Bounds masks for the ragged last tile in each dimension.
    m_mask = offs_m < num_tokens
    k_mask = offs_k < size
    mask = m_mask[:, None] & k_mask[None, :]
    # NOTE(review): the row offset is offs_m * stride_m; confirm the integer
    # width Triton infers here cannot overflow for very large inputs.
    a_base = inputs_ptr + (offs_m * stride_m)[:, None] + offs_k[None, :]
    b_base = topk_weights_ptr + offs_m * top_k
    acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    for n in tl.static_range(top_k):
        # Per-row weight of slice n; masked rows contribute a weight of 0.
        b_val = tl.load(b_base + n, mask=m_mask, other=0.0).to(tl.float32)
        if has_expert_map:
            # Masked rows fall back to id 0, so the unmasked table lookup
            # below reads entry 0 for them; the result is discarded because
            # `mask` already excludes those rows.
            # NOTE(review): this presumes expert_map has at least one entry.
            id_val = tl.load(top_ids_ptr + offs_m * top_k + n, mask=m_mask, other=0)
            expert_mask = tl.load(expert_map_ptr + id_val) >= 0
            # Slices whose expert maps to a negative entry are loaded as 0.0.
            a_vec = tl.load(
                a_base + n * size,
                mask=mask & expert_mask[:, None],
                other=0.0,
            ).to(tl.float32)
        else:
            a_vec = tl.load(
                a_base + n * size,
                mask=mask,
                other=0.0,
            ).to(tl.float32)
        acc += a_vec * b_val[:, None]
    # Output rows are addressed with a stride of `size` elements.
    out_ptrs = outputs_ptr + (offs_m * size)[:, None] + offs_k[None, :]
    tl.store(
        out_ptrs,
        acc.to(outputs_ptr.dtype.element_ty),
        mask=mask,
    )
def _heuristic_config(
    num_tokens: int,
    top_k: int,
    size: int,
    element_size: int,
):
    """Choose tile sizes and launch parameters for the fused mul-sum kernel.

    The choice depends on the number of tokens, the column count ``size``,
    whether elements are wider than two bytes, and the device capability
    reported by ``current_platform``. ``top_k`` is accepted but not used.

    Returns:
        A tuple ``(BLOCK_M, BLOCK_K, num_warps, num_stages)``.
    """
    wide_dtype = element_size > 2
    sm90_or_newer = current_platform.has_device_capability(90)
    pre_sm80 = not current_platform.has_device_capability(80)

    # Rows per program. Tiny batches always get one row per program.
    if num_tokens <= 4:
        block_m = 1
    elif sm90_or_newer:
        # SM90/SM100+: prefer small tiles + many CTAs.
        block_m = 2 if (wide_dtype or num_tokens <= 128) else 4
    elif num_tokens <= 32:
        block_m = 2
    elif wide_dtype or num_tokens <= 128:
        block_m = 4
    elif num_tokens <= 1024:
        block_m = 16
    else:
        block_m = 8

    # Columns per program: next power of two of `size`, clamped to
    # [256, block_k_cap].
    if wide_dtype:
        block_k_cap = 256
    elif pre_sm80 or sm90_or_newer:
        block_k_cap = 512
    else:
        block_k_cap = 1024
    block_k = max(min(triton.next_power_of_2(size), block_k_cap), 256)

    tile_elems = block_m * block_k
    if wide_dtype:
        num_warps = max(8, min(16, tile_elems // 64))
    else:
        num_warps = max(4, min(16, tile_elems // 256))
    # Everything except SM80..SM89 is limited to 8 warps.
    if pre_sm80 or sm90_or_newer:
        num_warps = min(num_warps, 8)

    # Pre-SM80 always uses 2 stages; otherwise small tiles get 4.
    num_stages = 4 if (not pre_sm80 and tile_elems <= 2048) else 2

    return block_m, block_k, num_warps, num_stages
def moe_fused_mul_sum(
inputs: torch.Tensor,
topk_weights: torch.Tensor,
outputs: torch.Tensor | None = None,
topk_ids: torch.Tensor | None = None,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Fused kernel for MoE (Mixture of Experts) to perform weighted summation
of expert outputs.
Args:
inputs: The output from experts.
Shape: (num_tokens, top_k, hidden_size).
topk_weights: The weights assigned to each expert for each token.
Shape: (num_tokens, top_k).
outputs: Optional pre-allocated output tensor.
Shape: (num_tokens, hidden_size).
topk_ids: Optional indices of the top-k experts. Used when
`expert_map` is provided. Shape: (num_tokens, top_k).
expert_map: Optional mapping for Expert Parallelism. A value < 0
indicates an invalid token/expert pair that will be skipped.
Returns:
The fused weighted sum of expert outputs.
Shape: (num_tokens, hidden_size).
"""
assert inputs.ndim == 3
assert topk_weights.ndim == 2
assert inputs.is_contiguous()
assert topk_weights.is_contiguous()
assert inputs.dtype in (torch.float32, torch.float16, torch.bfloat16)
assert topk_weights.dtype in (torch.float32, torch.float16, torch.bfloat16)
num_tokens, top_k, size = inputs.shape
output_shape = (num_tokens, size)
if outputs is None:
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
assert outputs.shape == output_shape
assert topk_weights.shape == (num_tokens, top_k)
if not isinstance(inputs, FakeTensor):
BLOCK_M, BLOCK_K, num_warps, num_stages = _heuristic_config(
num_tokens,
top_k,
size,
inputs.element_size(),
)
grid = (triton.cdiv(size, BLOCK_K), triton.cdiv(num_tokens, BLOCK_M))
moe_fused_mul_sum_kernel[grid](
inputs,
topk_weights,
outputs,
topk_ids,
expert_map,
num_tokens,
top_k * size,
expert_map is not None,
top_k,
size,
BLOCK_M,
BLOCK_K,
num_warps=num_warps,
num_stages=num_stages,
)
return outputs
......@@ -60,6 +60,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"ModelOptFp8PbWoLinearMethod",
"QuarkLinearMethod",
"ModelOptNvFp4LinearMethod",
"HummingLinearMethod",
]
......@@ -245,6 +246,7 @@ class LinearBase(PluggableLayer):
self,
input_size: int,
output_size: int,
bias: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
......@@ -258,6 +260,7 @@ class LinearBase(PluggableLayer):
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.has_bias = bias
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
......@@ -323,6 +326,7 @@ class ReplicatedLinear(LinearBase):
super().__init__(
input_size,
output_size,
bias,
skip_bias_add,
params_dtype,
quant_config,
......@@ -458,6 +462,7 @@ class ColumnParallelLinear(LinearBase):
super().__init__(
input_size,
output_size,
bias,
skip_bias_add,
params_dtype,
quant_config,
......@@ -483,6 +488,7 @@ class ColumnParallelLinear(LinearBase):
else self.weight_loader
),
)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=params_dtype)
......@@ -817,8 +823,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
shard_size = round(shard_size // param.packed_factor)
shard_offset = round(shard_offset // param.packed_factor)
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset
......@@ -1252,8 +1258,8 @@ class QKVParallelLinear(ColumnParallelLinear):
)
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
shard_size = round(shard_size // param.packed_factor)
shard_offset = round(shard_offset // param.packed_factor)
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
......@@ -1315,8 +1321,8 @@ class QKVParallelLinear(ColumnParallelLinear):
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
shard_size = round(shard_size // param.packed_factor)
shard_offset = round(shard_offset // param.packed_factor)
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
......@@ -1440,6 +1446,7 @@ class RowParallelLinear(LinearBase):
super().__init__(
input_size,
output_size,
bias,
skip_bias_add,
params_dtype,
quant_config,
......
......@@ -22,6 +22,7 @@ QuantizationMethods = Literal[
"gptq_marlin",
"awq_marlin",
"gptq",
"humming",
"compressed-tensors",
"bitsandbytes",
"experts_int8",
......@@ -126,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .gguf import GGUFConfig
from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig
from .humming import HummingConfig
from .inc import INCConfig
from .modelopt import (
ModelOptFp8Config,
......@@ -162,6 +164,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp4": Mxfp4Config,
"gpt_oss_mxfp4": GptOssMxfp4Config,
"cpu_awq": CPUAWQConfig,
"humming": HummingConfig,
"online": OnlineQuantizationConfig,
}
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
def humming_moe_align(
configs: list[int],
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert len(configs) > 0 and len(configs) % 3 == 0
# NOTE: we choose moe_block_size based on
# num_tokens * top_k (= topk_ids.nelement())
shape_m = topk_ids.nelement()
for i in range(len(configs) // 3):
if shape_m > configs[i * 3] and shape_m <= configs[i * 3 + 1]:
block_size = configs[i * 3 + 2]
break
else:
raise ValueError(f"Could not find a matching block_size for shape_m={shape_m}")
return moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
expert_map=expert_map,
pad_sorted_ids=False,
ignore_invalid_experts=True,
)
......@@ -605,8 +605,8 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size)
def _adjust_shard_indexes_for_packing(
shard_size, shard_offset, packed_factor, marlin_tile_size
):
shard_size = shard_size // packed_factor
shard_offset = shard_offset // packed_factor
shard_size = round(shard_size // packed_factor)
shard_offset = round(shard_offset // packed_factor)
if marlin_tile_size is not None:
return _adjust_shard_indexes_for_marlin(
shard_size=shard_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment