Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
......@@ -6,7 +6,7 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
......@@ -936,7 +936,7 @@ def enable_batch_invariant_mode():
# Batch invariant matmuls are no longer needed after cublas overrides
if not is_torch_equal_or_newer("2.10.0.dev"):
if (
current_platform.is_device_capability(100)
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(80)
or current_platform.is_device_capability(89)
):
......@@ -1004,27 +1004,30 @@ def vllm_is_batch_invariant() -> bool:
return VLLM_BATCH_INVARIANT
def override_envs_for_invariance():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
def override_envs_for_invariance(
attention_backend: AttentionBackendEnum | None,
):
supported_backends = [
"FLASH_ATTN", # best supported backend
"FLASHINFER",
"FLASH_ATTN_MLA",
"TRITON_MLA",
AttentionBackendEnum.FLASH_ATTN, # best supported backend
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.TRITON_MLA,
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
# AttentionBackendEnum.FLASHMLA,
# AttentionBackendEnum.FLEX_ATTENTION, # IMA issue
# AttentionBackendEnum.FLASHINFER_MLA, # PR #28967
]
if curr_attn_backend not in supported_backends:
if attention_backend not in supported_backends:
supported_names = [b.name for b in supported_backends]
backend_name = attention_backend.name if attention_backend else None
error = (
"VLLM batch_invariant mode requires an attention backend in "
f"{supported_backends}, but got '{curr_attn_backend}'. "
"Please set the 'VLLM_ATTENTION_BACKEND' environment variable "
"to one of the supported backends before enabling batch_invariant."
f"{supported_names}, but got '{backend_name}'. "
"Please use --attention-backend or attention_config to set "
"one of the supported backends before enabling batch_invariant."
)
raise RuntimeError(error)
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
if attention_backend != supported_backends[0]:
warning = (
"You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
......@@ -1050,10 +1053,12 @@ def override_envs_for_invariance():
os.environ["VLLM_USE_AOT_COMPILE"] = "0"
def init_batch_invariance():
def init_batch_invariance(
attention_backend: AttentionBackendEnum | None,
):
# this will hit all the csrc overrides as well
if vllm_is_batch_invariant():
override_envs_for_invariance()
override_envs_for_invariance(attention_backend)
enable_batch_invariant_mode()
# Disable TF32 for batch invariance - it causes non-deterministic rounding
......
......@@ -287,7 +287,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
DeepGemm supports packed ue8m0 activation scales format in devices == sm100
"""
return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100)
return (
is_deep_gemm_e8m0_used()
and current_platform.is_device_capability_family(100)
)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
......
......@@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config(
)
def gptq_marlin_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    weight_bits: int,
    group_size: int,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
):
    """
    Build the fused-MoE quant config used for GPTQ-Marlin weights.

    Args:
        w1_scale: Scales passed to the first weight descriptor.
        w2_scale: Scales passed to the second weight descriptor.
        weight_bits: Weight bit width; only 4 and 8 are accepted.
        group_size: Column group size; -1 yields no group shape (None).
        w1_zp: Optional zero points for the first weight descriptor.
        w2_zp: Optional zero points for the second weight descriptor.
        w1_bias: Optional bias for the first weight descriptor.
        w2_bias: Optional bias for the second weight descriptor.

    Returns:
        A FusedMoEQuantConfig whose activation descriptors have dtype=None.

    Raises:
        ValueError: If weight_bits is neither 4 nor 8.
    """
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    if group_size == -1:
        group_shape = None
    else:
        group_shape = GroupShape(row=1, col=group_size)

    # GPTQ leaves activations unquantized (dtype=None below); the activation
    # descriptors simply reuse the weight group shape.
    act_shape = group_shape

    if weight_bits not in (4, 8):
        raise ValueError(f"Unsupported weight_bits: {weight_bits}")
    packed_dtype = "int4" if weight_bits == 4 else torch.int8

    act1_desc = FusedMoEQuantDesc(dtype=None, shape=act_shape)
    act2_desc = FusedMoEQuantDesc(dtype=None, shape=act_shape)
    w1_desc = FusedMoEQuantDesc(
        packed_dtype, group_shape, w1_scale, None, w1_zp, w1_bias
    )
    w2_desc = FusedMoEQuantDesc(
        packed_dtype, group_shape, w2_scale, None, w2_zp, w2_bias
    )
    return FusedMoEQuantConfig(_a1=act1_desc, _a2=act2_desc, _w1=w1_desc, _w2=w2_desc)
def mxfp4_w4a16_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
......@@ -700,6 +736,42 @@ def int4_w4afp8_moe_quant_config(
)
def awq_marlin_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: torch.Tensor | None,
    w2_zp: torch.Tensor | None,
    weight_bits: int,
    group_size: int,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Build the fused-MoE quant config used for AWQ-Marlin weights.

    Args:
        w1_scale: Scales passed to the first weight descriptor.
        w2_scale: Scales passed to the second weight descriptor.
        w1_zp: Zero points for the first weight descriptor, or None.
        w2_zp: Zero points for the second weight descriptor, or None.
        weight_bits: Weight bit width; only 4 and 8 are accepted.
        group_size: Column group size; -1 yields no group shape (None).
        w1_bias: Optional bias for the first weight descriptor.
        w2_bias: Optional bias for the second weight descriptor.

    Returns:
        A FusedMoEQuantConfig whose activation descriptors have dtype=None.

    Raises:
        ValueError: If weight_bits is neither 4 nor 8.
    """
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    w_shape = GroupShape(row=1, col=group_size) if group_size != -1 else None

    # AWQ keeps activations in fp16/bf16 (dtype=None below); their descriptors
    # reuse the weight group shape.
    a_shape = w_shape

    # Map the bit width onto the dtype tag stored in the weight descriptors.
    if weight_bits == 8:
        weight_dtype = torch.int8
    elif weight_bits == 4:
        weight_dtype = "int4"
    else:
        raise ValueError(f"Unsupported weight_bits: {weight_bits}")

    descriptors = {
        "_a1": FusedMoEQuantDesc(dtype=None, shape=a_shape),
        "_a2": FusedMoEQuantDesc(dtype=None, shape=a_shape),
        "_w1": FusedMoEQuantDesc(
            weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias
        ),
        "_w2": FusedMoEQuantDesc(
            weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias
        ),
    }
    return FusedMoEQuantConfig(**descriptors)
def biased_moe_quant_config(
w1_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
......
......@@ -460,7 +460,6 @@ def cutlass_moe_fp8(
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
parallel_config=None,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
......@@ -538,7 +537,6 @@ def cutlass_moe_fp8(
c_strides2=c_strides2,
quant_config=quant_config,
),
parallel_config=parallel_config,
)
return fn(
......
......@@ -293,7 +293,7 @@ def deep_gemm_moe_fp8(
expert_map: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
apply_router_weight_on_input=False,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
......
......@@ -84,10 +84,16 @@ def _fwd_kernel_ep_scatter_1(
m_indices_start_ptr = m_indices + cur_expert_start
off_expert = tl.arange(0, BLOCK_E)
# any rows in the per-expert aligned region that do not correspond to
# real tokens are left untouched here and should remain initialized to
# -1 so DeepGEMM can skip them
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
offs = start_m + off_expert
mask = offs < cur_expert_token_num
tl.store(
m_indices_start_ptr + start_m + off_expert,
m_indices_start_ptr + offs,
cur_expert,
mask=mask,
)
......@@ -366,12 +372,17 @@ def deepgemm_moe_permute(
(M_sum, H // block_k), device=device, dtype=torch.float32
)
maybe_has_empty_blocks = (expert_tokens_meta is None) or (
expert_tokens_meta.expert_num_tokens_cpu is None
# DeepGEMM uses negative values in m_indices (here expert_ids) to mark
# completely invalid / padded blocks that should be skipped. We always
# initialize expert_ids to -1 so any row that is not explicitly written
# by the scatter kernel will be treated as invalid and skipped by
# DeepGEMM's scheduler.
expert_ids = torch.full(
(M_sum,),
fill_value=-1,
device=device,
dtype=torch.int32,
)
expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty
expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)
expert_num_tokens = None
......
......@@ -903,12 +903,11 @@ def get_moe_configs(
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
(
logger.warning_once(
"Using default MoE config. Performance might be sub-optimal! "
"Config file not found at %s"
),
config_file_paths,
"Config file not found at %s",
", ".join(config_file_paths),
scope="local",
)
return None
......
......@@ -43,11 +43,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize: FusedMoEPrepareAndFinalize,
shared_experts: torch.nn.Module | None,
) -> "FusedMoEModularMethod":
parallel_config = getattr(
getattr(moe_layer, "vllm_config", None),
"parallel_config",
None,
)
return FusedMoEModularMethod(
old_quant_method,
FusedMoEModularKernel(
......@@ -55,7 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
getattr(moe_layer, "shared_experts_stream", None),
parallel_config=parallel_config,
moe_parallel_config=moe_layer.moe_parallel_config,
),
)
......
......@@ -371,7 +371,9 @@ class FusedMoE(CustomOp):
# aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
logger.info_once(
"Enabled separate cuda stream for MoE shared_experts", scope="local"
)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
......@@ -891,7 +893,7 @@ class FusedMoE(CustomOp):
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We dont need shared_output.record_stream(current_stream())
# NOTE: We don't need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
......@@ -1222,10 +1224,14 @@ class FusedMoE(CustomOp):
if full_load:
shard_dim += 1
# Materialize GGUF UninitializedParameter
# Materialize GGUF UninitializedParameter accounting merged weights
if is_gguf_weight and isinstance(param, UninitializedParameter):
# To materialize a tensor, we must have full shape including
# number of experts, so this portion requires `full_load`.
assert full_load
final_shape = list(loaded_weight.shape)
if shard_id in ["w1", "w3"]:
# w1 and w3 are merged per expert.
if shard_id in {"w1", "w3"}:
final_shape[1] *= 2
final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
param.materialize(final_shape, dtype=loaded_weight.dtype)
......@@ -1578,6 +1584,14 @@ class FusedMoE(CustomOp):
f"EPLB is not supported for {self.quant_method.method_name}."
)
def valid_grouping() -> bool:
    """Report whether the expert count can be split into expert groups.

    True only when the last dimension of ``router_logits`` is strictly
    greater than ``self.num_expert_group`` and divides evenly by it.
    """
    expert_count = router_logits.shape[-1]
    group_count = self.num_expert_group
    return expert_count > group_count and expert_count % group_count == 0
indices_type = self.quant_method.topk_indices_dtype
# Check if we should use a routing simulation strategy
......@@ -1592,7 +1606,7 @@ class FusedMoE(CustomOp):
)
# DeepSeekv2 uses grouped_top_k
elif self.use_grouped_topk:
elif self.use_grouped_topk and valid_grouping():
assert self.topk_group is not None
assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled():
......
......@@ -10,10 +10,12 @@ from typing import final
import torch
import vllm.envs as envs
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
count_expert_num_tokens,
......@@ -22,12 +24,12 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook,
dbo_yield,
)
from vllm.v1.worker.workspace import current_workspace_manager
logger = init_logger(__name__)
......@@ -661,25 +663,6 @@ def _slice_scales(
return None
class SharedResizableBuffer:
    """Lazily allocated flat tensor that is handed out as reshaped views.

    The backing tensor is replaced only when a request needs more elements
    than it holds, or asks for a different device or dtype.
    """

    def __init__(self):
        # No storage until the first `get` call.
        self.buffer = None

    def get(
        self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return a view of `shape` over the leading elements of the buffer.

        Args:
            shape: Requested (non-empty) shape.
            device: Device the returned tensor must live on.
            dtype: Element type of the returned tensor.

        Returns:
            A view over the first ``prod(shape)`` elements of the backing
            tensor; contents are whatever the backing tensor already holds.
        """
        assert shape != ()
        needed = prod(shape)
        backing = self.buffer
        reusable = (
            backing is not None
            and backing.numel() >= needed
            and backing.device == device
            and backing.dtype == dtype
        )
        if not reusable:
            # Allocate exactly the requested element count and keep it.
            backing = torch.empty(needed, device=device, dtype=dtype)
            self.buffer = backing
        return backing[:needed].view(*shape)
@final
class FusedMoEModularKernel(torch.nn.Module):
"""
......@@ -694,29 +677,13 @@ class FusedMoEModularKernel(torch.nn.Module):
objects.
"""
class SharedBuffers:
def __init__(self) -> None:
self.fused_out = SharedResizableBuffer()
self.workspace13 = SharedResizableBuffer()
self.workspace2 = SharedResizableBuffer()
# Persistent buffers that are shared across `FusedMoEModularKernel`
# instances (layers), to save memory and allocations.
#
# We have two sets of buffers to support dual batch overlap (DBO) where each
# microbatch (ubatch) should use its own set of buffers to avoid
# cross-ubatch contamination.
# NOTE that memory is lazily allocated for these buffers, meaning that if
# DBO isn't being used, the second SharedBuffers will be empty.
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None,
shared_experts_stream: torch.cuda.Stream | None = None,
parallel_config: ParallelConfig | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
......@@ -724,12 +691,15 @@ class FusedMoEModularKernel(torch.nn.Module):
self.shared_experts = shared_experts
self.shared_experts_stream = shared_experts_stream
# cache whether this worker is using DP+EP
if parallel_config is None:
parallel_config = get_current_vllm_config().parallel_config
# prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests).
# if not provided, assume this kernel is
# running in a non-DP+EP context
self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
self.is_dp_ep = (
parallel_config.data_parallel_size > 1
and parallel_config.enable_expert_parallel
moe_parallel_config is not None
and moe_parallel_config.dp_size > 1
and moe_parallel_config.use_ep
)
self._post_init_setup()
......@@ -806,10 +776,6 @@ class FusedMoEModularKernel(torch.nn.Module):
assert M_full > 0 and M_chunk > 0
num_chunks, _ = self._chunk_info(M_full)
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
ubatch_idx = dbo_current_ubatch_id()
buffers = self.shared_buffers[ubatch_idx]
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for
......@@ -832,14 +798,11 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta,
)
)
buffers.workspace13.get(
max_workspace_13, device=device, dtype=workspace_dtype
)
buffers.workspace2.get(
max_workspace_2, device=device, dtype=workspace_dtype
)
buffers.fused_out.get(
max_fused_out_shape, device=device, dtype=workspace_dtype
current_workspace_manager().get_simultaneous(
(max_workspace_13, workspace_dtype),
(max_workspace_2, workspace_dtype),
(max_fused_out_shape, out_dtype),
)
# Get intermediate workspace shapes based off the chunked M size.
......@@ -866,22 +829,23 @@ class FusedMoEModularKernel(torch.nn.Module):
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13 = buffers.workspace13.get(
workspace13_shape, device=device, dtype=workspace_dtype
)
workspace2 = buffers.workspace2.get(
workspace2_shape, device=device, dtype=workspace_dtype
)
# Construct the entire output that can then be processed in chunks.
# Reuse workspace13 for the output in the non-chunked case as long
# as it is large enough. This will not always be the case for standard
# format experts and with experts that have empty workspaces.
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
workspace13, workspace2 = current_workspace_manager().get_simultaneous(
(workspace13_shape, workspace_dtype),
(workspace2_shape, workspace_dtype),
)
fused_out = _resize_cache(workspace13, fused_out_shape)
else:
fused_out = buffers.fused_out.get(
fused_out_shape, device=device, dtype=out_dtype
workspace13, workspace2, fused_out = (
current_workspace_manager().get_simultaneous(
(workspace13_shape, workspace_dtype),
(workspace2_shape, workspace_dtype),
(fused_out_shape, out_dtype),
)
)
return workspace13, workspace2, fused_out
......
......@@ -30,8 +30,8 @@ class SharedFusedMoE(FusedMoE):
# Disable shared expert overlap if:
# - we are using eplb, because of correctness issues
# - we are using flashinfer with DP, since there nothint to gain
# - we are using marlin kjernels
# - we are using flashinfer with DP, since there is nothing to gain
# - we are using marlin kernels
self.use_overlapped = (
use_overlapped
and not (
......
......@@ -470,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
}
)
intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full", intermediate_size_per_partition
)
self.is_k_full = intermediate_size_per_partition == intermediate_size_full
w13_qweight = Parameter(
torch.empty(
num_experts,
......@@ -597,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# The modular kernel expects w13_weight and w2_weight,
# but AWQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# Why does this take the intermediate size for size_k?
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
......@@ -661,7 +673,88 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Build the AWQ-Marlin fused-MoE quant config from `layer`.

    Scales are read from ``layer.w13_scales`` / ``layer.w2_scales``. Zero
    points (``w13_qzeros`` / ``w2_qzeros``) are forwarded only when
    ``self.quant_config.zero_point`` is truthy; biases are forwarded when the
    layer has them. Missing optional attributes are passed as None.

    Args:
        layer: The MoE layer holding the processed AWQ parameters.

    Returns:
        The config produced by ``awq_marlin_moe_quant_config``.
    """
    # The previous unconditional `return None` that preceded this body made
    # the config construction unreachable; it has been dropped.
    from vllm.model_executor.layers.fused_moe.config import (
        awq_marlin_moe_quant_config,
    )

    return awq_marlin_moe_quant_config(
        w1_scale=layer.w13_scales,
        w2_scale=layer.w2_scales,
        weight_bits=self.quant_config.weight_bits,
        group_size=self.quant_config.group_size,
        w1_zp=getattr(layer, "w13_qzeros", None)
        if self.quant_config.zero_point
        else None,
        w2_zp=getattr(layer, "w2_qzeros", None)
        if self.quant_config.zero_point
        else None,
        w1_bias=getattr(layer, "w13_bias", None),
        w2_bias=getattr(layer, "w2_bias", None),
    )
def select_gemm_impl(
    self,
    prepare_finalize,
    layer: torch.nn.Module,
):
    """
    Pick the Marlin experts implementation for AWQ-Marlin MoE.

    Only valid when LoRA is enabled; otherwise AWQ relies on its own
    apply() method and this raises NotImplementedError. Returns
    BatchedMarlinExperts when `prepare_finalize` uses the BatchedExperts
    activation format, and MarlinExperts otherwise.
    """
    # Modular kernels are reserved for the LoRA path.
    if not self.moe.is_lora_enabled:
        raise NotImplementedError(
            "AWQ-Marlin uses its own apply() method when LoRA is not enabled. "
            "Modular kernels are only used for LoRA support."
        )

    from vllm.model_executor.layers.fused_moe import modular_kernel as mk
    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
        BatchedMarlinExperts,
        MarlinExperts,
    )

    # The quant config must already exist at this point.
    assert self.moe_quant_config is not None, (
        "moe_quant_config must be initialized before select_gemm_impl"
    )

    # Arguments shared by both experts implementations; the g_idx tensors
    # are optional layer attributes and default to None.
    common_kwargs = dict(
        quant_config=self.moe_quant_config,
        w13_g_idx=getattr(layer, "w13_g_idx", None),
        w2_g_idx=getattr(layer, "w2_g_idx", None),
        w13_g_idx_sort_indices=getattr(layer, "w13_g_idx_sort_indices", None),
        w2_g_idx_sort_indices=getattr(layer, "w2_g_idx_sort_indices", None),
        is_k_full=self.is_k_full,
    )

    batched_format = mk.FusedMoEActivationFormat.BatchedExperts
    if prepare_finalize.activation_format != batched_format:
        # Standard (non-batched) Marlin experts.
        return MarlinExperts(**common_kwargs)

    # Batched expert format: needs the per-rank token cap and dispatcher count.
    max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
    assert max_num_tokens_per_rank is not None
    return BatchedMarlinExperts(
        max_num_tokens=max_num_tokens_per_rank,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        **common_kwargs,
    )
def apply(
self,
......
......@@ -1266,9 +1266,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
parallel_config=getattr(
getattr(layer, "vllm_config", None), "parallel_config", None
),
)
else:
......
......@@ -28,7 +28,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
@classmethod
def get_min_capability(cls) -> int:
# dont restrict as emulations
# don't restrict as emulations
return 80
def create_weights(
......
......@@ -137,7 +137,7 @@ def get_fp8_moe_backend(
if (
current_platform.is_cuda()
and (
current_platform.is_device_capability(100)
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
......@@ -148,7 +148,7 @@ def get_fp8_moe_backend(
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant and current_platform.is_device_capability(100):
if block_quant and current_platform.is_device_capability_family(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use "
......@@ -193,7 +193,7 @@ def get_fp8_moe_backend(
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
if (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and current_platform.is_device_capability_family(100)
and block_quant
):
logger.info_once(
......@@ -332,7 +332,10 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
if self.is_checkpoint_fp8_serialized:
moe_quant_method = Fp8MoEMethod(self, layer)
else:
moe_quant_method = Fp8OnlineMoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention):
......@@ -745,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.orig_dtype = params_dtype
layer.weight_block_size = None
if self.quant_config.is_checkpoint_fp8_serialized:
assert self.quant_config.is_checkpoint_fp8_serialized
params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
......@@ -773,41 +777,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
f"weight quantization block_k = {block_k}."
)
# if we are doing online quantization, patch the weight
# loaded to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded
if not self.quant_config.is_checkpoint_fp8_serialized:
weight_loader = extra_weight_attrs["weight_loader"]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs = extra_weight_attrs
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)
# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True
return res
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
extra_weight_attrs = new_extra_weight_attrs
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
......@@ -875,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
......@@ -986,45 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv = Parameter(
dg_w2_weight_scale_inv, requires_grad=False
)
# If checkpoint is fp16, quantize in place.
elif not self.quant_config.is_checkpoint_fp8_serialized:
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
replace_parameter(
layer,
"w13_weight_scale",
torch.ones(
layer.local_num_experts,
dtype=torch.float32,
device=w13_weight.device,
),
)
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
......@@ -1387,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return result
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads FP16/BF16 model checkpoints (i.e. checkpoints that are not
    FP8-serialized) with dynamic activation scaling. Expert weights are
    created in the original parameter dtype and quantized to FP8, together
    with their per-expert scales, once all weights have been loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer, forwarded to the base-class constructor.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # Preconditions for the online path: unquantized checkpoint, dynamic
        # activation scales, no block quantization, no FlashInfer backend.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register unquantized expert weights and unit scales on `layer`.

        The ``weight_loader`` entry of `extra_weight_attrs` is wrapped so
        that `process_weights_after_loading` runs as soon as the number of
        loaded elements equals the combined size of ``w13_weight`` and
        ``w2_weight``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loaded
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # NOTE: this binds a second name to the same kwargs dict (no copy is
        # made); the dict is local to this call because it comes from
        # **extra_weight_attrs.
        new_extra_weight_attrs = extra_weight_attrs

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # load the current weight chunk
            res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0
            layer._loaded_numel += loaded_weight.numel()

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
        extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        # Allocated in the original params_dtype; replaced by FP8 tensors in
        # process_weights_after_loading.
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One float32 scale per expert for w13 and for w2, initialised to 1.0.
        # process_weights_after_loading overwrites them per expert.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # No input scales are registered (activation scheme is dynamic).
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        # Reset here; the actual value is queried in
        # process_weights_after_loading.
        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded expert weights to FP8 and finalize the layout.

        Returns immediately if the patched weight loader already ran this
        step for `layer`.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Query AITER fused-MoE availability only now, after loading.
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # If checkpoint is fp16, quantize in place.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        # Quantize expert by expert; each call yields the FP8 weights and the
        # scale stored for that expert.
        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
            )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Reshuffle weights for AITER if needed.
        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
            replace_parameter(layer, "w13_weight", shuffled_w13)
            replace_parameter(layer, "w2_weight", shuffled_w2)

        # Reshuffle weights for MARLIN if needed.
        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
......
......@@ -33,6 +33,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
......@@ -52,6 +53,11 @@ class GGUFConfig(QuantizationConfig):
return "gguf"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
    """Return the activation dtypes that may be used with GGUF weights.

    GGUF dequantization kernels use half precision (fp16) internally, and
    bfloat16 has precision issues on Blackwell devices. When the current
    platform reports device capability 100, a one-time warning is logged
    and bfloat16 is left out of the returned list.
    """
    # NOTE(review): presumably has_device_capability(100) also matches
    # devices newer than Blackwell - confirm against the platform API.
    bf16_is_unreliable = current_platform.has_device_capability(100)
    if not bf16_is_unreliable:
        return [torch.half, torch.bfloat16, torch.float32]

    logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.")
    return [torch.half, torch.float32]
@classmethod
......
......@@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
......@@ -782,7 +790,107 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Build the fused-MoE quantization config for a GPTQ-Marlin layer.

    The config is assembled from the layer's scales, the configured weight
    bit-width and group size, and optional zero points and biases.

    Args:
        layer: The MoE layer holding ``w13_scales`` / ``w2_scales`` and,
            optionally, ``w13_qzeros`` / ``w2_qzeros`` and
            ``w13_bias`` / ``w2_bias``.

    Returns:
        The value produced by ``gptq_marlin_moe_quant_config``.
    """
    # A stale unconditional ``return None`` used to sit here, which made
    # everything below unreachable and left the modular kernel without a
    # quantization config.
    from vllm.model_executor.layers.fused_moe.config import (
        gptq_marlin_moe_quant_config,
    )

    # Zero points are only passed for asymmetric quantization.
    uses_zero_points = not self.quant_config.is_sym

    return gptq_marlin_moe_quant_config(
        w1_scale=layer.w13_scales,
        w2_scale=layer.w2_scales,
        weight_bits=self.quant_config.weight_bits,
        group_size=self.quant_config.group_size,
        w1_zp=getattr(layer, "w13_qzeros", None) if uses_zero_points else None,
        w2_zp=getattr(layer, "w2_qzeros", None) if uses_zero_points else None,
        w1_bias=getattr(layer, "w13_bias", None),
        w2_bias=getattr(layer, "w2_bias", None),
    )
def select_gemm_impl(
    self,
    prepare_finalize,
    layer: torch.nn.Module,
):
    """Pick the modular Marlin experts implementation for GPTQ-Marlin MoE.

    This path is ONLY used when LoRA is enabled; without LoRA, GPTQ uses
    its own ``apply()`` method.

    Args:
        prepare_finalize: Object exposing ``activation_format`` and, for
            the batched format, ``max_num_tokens_per_rank()`` and
            ``num_dispatchers()``.
        layer: The MoE layer that may hold the act-order tensors
            (``w13_g_idx``, ``w2_g_idx`` and their sort indices).

    Returns:
        ``BatchedMarlinExperts`` when the activation format is
        ``BatchedExperts``, otherwise ``MarlinExperts``.

    Raises:
        NotImplementedError: If LoRA is not enabled, or if the weights are
            8-bit (unsupported by the modular Marlin kernels).
    """
    # Modular kernels are reserved for LoRA; the plain apply() path is
    # used otherwise.
    if not self.moe.is_lora_enabled:
        raise NotImplementedError(
            "GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
            "Modular kernels are only used for LoRA support."
        )

    # The modular marlin kernels do not support 8-bit weights.
    if self.quant_config.weight_bits == 8:
        raise NotImplementedError(
            "GPTQ-Marlin kernel does not support 8-bit weights."
        )

    from vllm.model_executor.layers.fused_moe import modular_kernel as mk
    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
        BatchedMarlinExperts,
        MarlinExperts,
    )

    # The quant config has to exist before an experts object can be built.
    assert self.moe_quant_config is not None, (
        "moe_quant_config must be initialized before select_gemm_impl"
    )

    def _act_order_tensor(attr_name):
        # The g_idx tensors are only forwarded when desc_act is set.
        if not self.quant_config.desc_act:
            return None
        return getattr(layer, attr_name, None)

    g_idx_kwargs = {
        attr_name: _act_order_tensor(attr_name)
        for attr_name in (
            "w13_g_idx",
            "w2_g_idx",
            "w13_g_idx_sort_indices",
            "w2_g_idx_sort_indices",
        )
    }

    uses_batched_format = (
        prepare_finalize.activation_format
        == mk.FusedMoEActivationFormat.BatchedExperts
    )
    if not uses_batched_format:
        # Standard Marlin experts for GPTQ.
        return MarlinExperts(
            quant_config=self.moe_quant_config,
            is_k_full=self.is_k_full,
            **g_idx_kwargs,
        )

    # Batched expert format (used for Expert Parallelism).
    tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
    assert tokens_per_rank is not None
    return BatchedMarlinExperts(
        max_num_tokens=tokens_per_rank,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=self.moe_quant_config,
        is_k_full=self.is_k_full,
        **g_idx_kwargs,
    )
def apply(
self,
......
......@@ -17,7 +17,9 @@ class ScaledMMLinearLayerConfig:
class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
raise NotImplementedError
@classmethod
......@@ -35,6 +37,7 @@ class ScaledMMLinearKernel(ABC):
azp_adj_param_name: str,
) -> None:
assert self.can_implement(c)
assert self.is_supported()
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
......
......@@ -27,7 +27,7 @@ from vllm.platforms import PlatformEnum, current_platform
# Candidate scaled-mm kernels for each platform, listed in
# priority/performance order (when available).
# Each platform key appears exactly once: a repeated key in a dict literal
# silently discards the earlier entry.
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
    PlatformEnum.CPU: [CPUScaledMMLinearKernel],
    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
    PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
    PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}
......@@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel(
type[ScaledMMLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc[0] * 10 + _cc[1]
failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
failure_reasons.append(
f" {kernel.__name__} disabled by environment variable"
)
failure_reasons.append(f"{kernel.__name__}: disabled by env var")
continue
# If the current platform uses compute_capability,
# make sure the kernel supports the compute capability.
if compute_capability is not None:
kernel_min_capability = kernel.get_min_capability()
if (
kernel_min_capability is not None
and kernel_min_capability > compute_capability
):
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel_min_capability}, current compute capability "
f"is {compute_capability}"
)
# make sure the kernel supports the compute capability.
is_supported, reason = kernel.is_supported(compute_capability)
if not is_supported:
failure_reasons.append(f"{kernel.__name__}: {reason}")
continue
can_implement, reason = kernel.can_implement(config)
if not can_implement:
failure_reasons.append(f"{kernel.__name__}: {reason}")
continue
can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
return kernel
else:
failure_reasons.append(
f" {kernel.__name__} cannot implement due to: {failure_reason}"
)
raise ValueError(
"Failed to find a kernel that can implement the "
......
......@@ -14,17 +14,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "currently supported on non-ROCm platform.",
)
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc.major * 10 + _cc.minor
if compute_capability is not None and compute_capability < 90:
return False, f"requires capability 90, got {compute_capability}"
try:
import aiter # noqa: F401 # deliberately attempt to import aiter
......@@ -34,8 +38,8 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "installed on ROCm.",
)
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if not (rocm_aiter_ops.is_linear_enabled()):
if not rocm_aiter_ops.is_linear_enabled():
return (
False,
"AiterScaledMMLinearKernel is disabled. "
......@@ -44,6 +48,10 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
)
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric:
return (
False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment