Commit a3f8d5dd authored by zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
...@@ -6,7 +6,7 @@ from typing import Any ...@@ -6,7 +6,7 @@ from typing import Any
import torch import torch
import vllm.envs as envs from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -936,7 +936,7 @@ def enable_batch_invariant_mode(): ...@@ -936,7 +936,7 @@ def enable_batch_invariant_mode():
# Batch invariant matmuls are no longer needed after cublas overrides # Batch invariant matmuls are no longer needed after cublas overrides
if not is_torch_equal_or_newer("2.10.0.dev"): if not is_torch_equal_or_newer("2.10.0.dev"):
if ( if (
current_platform.is_device_capability(100) current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(80) or current_platform.is_device_capability(80)
or current_platform.is_device_capability(89) or current_platform.is_device_capability(89)
): ):
...@@ -1004,27 +1004,30 @@ def vllm_is_batch_invariant() -> bool: ...@@ -1004,27 +1004,30 @@ def vllm_is_batch_invariant() -> bool:
return VLLM_BATCH_INVARIANT return VLLM_BATCH_INVARIANT
def override_envs_for_invariance(): def override_envs_for_invariance(
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND attention_backend: AttentionBackendEnum | None,
):
supported_backends = [ supported_backends = [
"FLASH_ATTN", # best supported backend AttentionBackendEnum.FLASH_ATTN, # best supported backend
"FLASHINFER", AttentionBackendEnum.FLASHINFER,
"FLASH_ATTN_MLA", AttentionBackendEnum.FLASH_ATTN_MLA,
"TRITON_MLA", AttentionBackendEnum.TRITON_MLA,
# Not yet supported MLA backends # Not yet supported MLA backends
# "FLASHMLA", # AttentionBackendEnum.FLASHMLA,
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance # AttentionBackendEnum.FLEX_ATTENTION, # IMA issue
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967 # AttentionBackendEnum.FLASHINFER_MLA, # PR #28967
] ]
if curr_attn_backend not in supported_backends: if attention_backend not in supported_backends:
supported_names = [b.name for b in supported_backends]
backend_name = attention_backend.name if attention_backend else None
error = ( error = (
"VLLM batch_invariant mode requires an attention backend in " "VLLM batch_invariant mode requires an attention backend in "
f"{supported_backends}, but got '{curr_attn_backend}'. " f"{supported_names}, but got '{backend_name}'. "
"Please set the 'VLLM_ATTENTION_BACKEND' environment variable " "Please use --attention-backend or attention_config to set "
"to one of the supported backends before enabling batch_invariant." "one of the supported backends before enabling batch_invariant."
) )
raise RuntimeError(error) raise RuntimeError(error)
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]: if attention_backend != supported_backends[0]:
warning = ( warning = (
"You are using a decode-invariant form of batch invariance. " "You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode." "This will not be invariant between prefill and decode."
...@@ -1050,10 +1053,12 @@ def override_envs_for_invariance(): ...@@ -1050,10 +1053,12 @@ def override_envs_for_invariance():
os.environ["VLLM_USE_AOT_COMPILE"] = "0" os.environ["VLLM_USE_AOT_COMPILE"] = "0"
def init_batch_invariance(): def init_batch_invariance(
attention_backend: AttentionBackendEnum | None,
):
# this will hit all the csrc overrides as well # this will hit all the csrc overrides as well
if vllm_is_batch_invariant(): if vllm_is_batch_invariant():
override_envs_for_invariance() override_envs_for_invariance(attention_backend)
enable_batch_invariant_mode() enable_batch_invariant_mode()
# Disable TF32 for batch invariance - it causes non-deterministic rounding # Disable TF32 for batch invariance - it causes non-deterministic rounding
......
...@@ -287,7 +287,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -287,7 +287,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
""" """
DeepGemm supports packed ue8m0 activation scales format in devices == sm100 DeepGemm supports packed ue8m0 activation scales format in devices == sm100
""" """
return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100) return (
is_deep_gemm_e8m0_used()
and current_platform.is_device_capability_family(100)
)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. # Let PrepareAndFinalize::finalize() decide the impl.
......
...@@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config( ...@@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config(
) )
def gptq_marlin_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
weight_bits: int,
group_size: int,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
):
"""
Construct a quant config for gptq marlin quantization.
"""
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)
# Activations are NOT quantized for GPTQ (fp16/bf16)
a_shape = w_shape # Same as weight shape for alignment
# Determine weight dtype
if weight_bits == 4:
weight_dtype = "int4"
elif weight_bits == 8:
weight_dtype = torch.int8
else:
raise ValueError(f"Unsupported weight_bits: {weight_bits}")
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
_a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
_w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
_w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
)
def mxfp4_w4a16_moe_quant_config( def mxfp4_w4a16_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"], w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"],
...@@ -700,6 +736,42 @@ def int4_w4afp8_moe_quant_config( ...@@ -700,6 +736,42 @@ def int4_w4afp8_moe_quant_config(
) )
def awq_marlin_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: torch.Tensor | None,
    w2_zp: torch.Tensor | None,
    weight_bits: int,
    group_size: int,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Build the fused-MoE quant config for AWQ-Marlin quantized experts.

    The activation descriptors are created without a dtype; the weight
    descriptors use ``"int4"`` for 4-bit weights and ``torch.int8`` for
    8-bit weights. A ``group_size`` of ``-1`` records no group shape.

    Raises:
        ValueError: If ``weight_bits`` is neither 4 nor 8.
    """
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    if group_size == -1:
        weight_group = None
    else:
        weight_group = GroupShape(row=1, col=group_size)
    # AWQ keeps activations in fp16/bf16 (not quantized); their descriptor
    # simply reuses the weight group shape for alignment.
    act_group = weight_group

    # Map the bit width onto the dtype tag stored in the weight descriptors.
    if weight_bits not in (4, 8):
        raise ValueError(f"Unsupported weight_bits: {weight_bits}")
    weight_dtype = "int4" if weight_bits == 4 else torch.int8

    act1_desc = FusedMoEQuantDesc(dtype=None, shape=act_group)
    act2_desc = FusedMoEQuantDesc(dtype=None, shape=act_group)
    w1_desc = FusedMoEQuantDesc(
        weight_dtype, weight_group, w1_scale, None, w1_zp, w1_bias
    )
    w2_desc = FusedMoEQuantDesc(
        weight_dtype, weight_group, w2_scale, None, w2_zp, w2_bias
    )
    return FusedMoEQuantConfig(
        _a1=act1_desc,
        _a2=act2_desc,
        _w1=w1_desc,
        _w2=w2_desc,
    )
def biased_moe_quant_config( def biased_moe_quant_config(
w1_bias: torch.Tensor | None, w1_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None, w2_bias: torch.Tensor | None,
......
...@@ -460,7 +460,6 @@ def cutlass_moe_fp8( ...@@ -460,7 +460,6 @@ def cutlass_moe_fp8(
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
parallel_config=None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a a8w8-quantized Mixture of Experts (MoE) layer This function computes a a8w8-quantized Mixture of Experts (MoE) layer
...@@ -538,7 +537,6 @@ def cutlass_moe_fp8( ...@@ -538,7 +537,6 @@ def cutlass_moe_fp8(
c_strides2=c_strides2, c_strides2=c_strides2,
quant_config=quant_config, quant_config=quant_config,
), ),
parallel_config=parallel_config,
) )
return fn( return fn(
......
...@@ -293,7 +293,7 @@ def deep_gemm_moe_fp8( ...@@ -293,7 +293,7 @@ def deep_gemm_moe_fp8(
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None,
apply_router_weight_on_input=False, apply_router_weight_on_input: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a a8w8-quantized Mixture of Experts (MoE) layer This function computes a a8w8-quantized Mixture of Experts (MoE) layer
......
...@@ -84,10 +84,16 @@ def _fwd_kernel_ep_scatter_1( ...@@ -84,10 +84,16 @@ def _fwd_kernel_ep_scatter_1(
m_indices_start_ptr = m_indices + cur_expert_start m_indices_start_ptr = m_indices + cur_expert_start
off_expert = tl.arange(0, BLOCK_E) off_expert = tl.arange(0, BLOCK_E)
# any rows in the per-expert aligned region that do not correspond to
# real tokens are left untouched here and should remain initialized to
# -1 so DeepGEMM can skip them
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
offs = start_m + off_expert
mask = offs < cur_expert_token_num
tl.store( tl.store(
m_indices_start_ptr + start_m + off_expert, m_indices_start_ptr + offs,
cur_expert, cur_expert,
mask=mask,
) )
...@@ -366,12 +372,17 @@ def deepgemm_moe_permute( ...@@ -366,12 +372,17 @@ def deepgemm_moe_permute(
(M_sum, H // block_k), device=device, dtype=torch.float32 (M_sum, H // block_k), device=device, dtype=torch.float32
) )
maybe_has_empty_blocks = (expert_tokens_meta is None) or ( # DeepGEMM uses negative values in m_indices (here expert_ids) to mark
expert_tokens_meta.expert_num_tokens_cpu is None # completely invalid / padded blocks that should be skipped. We always
# initialize expert_ids to -1 so any row that is not explicitly written
# by the scatter kernel will be treated as invalid and skipped by
# DeepGEMM's scheduler.
expert_ids = torch.full(
(M_sum,),
fill_value=-1,
device=device,
dtype=torch.int32,
) )
expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty
expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32) inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)
expert_num_tokens = None expert_num_tokens = None
......
...@@ -903,12 +903,11 @@ def get_moe_configs( ...@@ -903,12 +903,11 @@ def get_moe_configs(
# If no optimized configuration is available, we will use the default # If no optimized configuration is available, we will use the default
# configuration # configuration
logger.warning( logger.warning_once(
( "Using default MoE config. Performance might be sub-optimal! "
"Using default MoE config. Performance might be sub-optimal! " "Config file not found at %s",
"Config file not found at %s" ", ".join(config_file_paths),
), scope="local",
config_file_paths,
) )
return None return None
......
...@@ -43,11 +43,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -43,11 +43,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
shared_experts: torch.nn.Module | None, shared_experts: torch.nn.Module | None,
) -> "FusedMoEModularMethod": ) -> "FusedMoEModularMethod":
parallel_config = getattr(
getattr(moe_layer, "vllm_config", None),
"parallel_config",
None,
)
return FusedMoEModularMethod( return FusedMoEModularMethod(
old_quant_method, old_quant_method,
FusedMoEModularKernel( FusedMoEModularKernel(
...@@ -55,7 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -55,7 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts, shared_experts,
getattr(moe_layer, "shared_experts_stream", None), getattr(moe_layer, "shared_experts_stream", None),
parallel_config=parallel_config, moe_parallel_config=moe_layer.moe_parallel_config,
), ),
) )
......
...@@ -371,7 +371,9 @@ class FusedMoE(CustomOp): ...@@ -371,7 +371,9 @@ class FusedMoE(CustomOp):
# aux_stream() returns None on non-cuda-alike platforms. # aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream() self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None: if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts") logger.info_once(
"Enabled separate cuda stream for MoE shared_experts", scope="local"
)
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
...@@ -891,7 +893,7 @@ class FusedMoE(CustomOp): ...@@ -891,7 +893,7 @@ class FusedMoE(CustomOp):
# Record that the clone will be used by shared_experts_stream # Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone # to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501 # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We dont need shared_output.record_stream(current_stream()) # NOTE: We don't need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output. # because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream) hidden_states_clone.record_stream(self.shared_experts_stream)
...@@ -1222,10 +1224,14 @@ class FusedMoE(CustomOp): ...@@ -1222,10 +1224,14 @@ class FusedMoE(CustomOp):
if full_load: if full_load:
shard_dim += 1 shard_dim += 1
# Materialize GGUF UninitializedParameter # Materialize GGUF UninitializedParameter accounting merged weights
if is_gguf_weight and isinstance(param, UninitializedParameter): if is_gguf_weight and isinstance(param, UninitializedParameter):
# To materialize a tensor, we must have full shape including
# number of experts, making this portion to require `full_load`.
assert full_load
final_shape = list(loaded_weight.shape) final_shape = list(loaded_weight.shape)
if shard_id in ["w1", "w3"]: # w1 and w3 are merged per expert.
if shard_id in {"w1", "w3"}:
final_shape[1] *= 2 final_shape[1] *= 2
final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
param.materialize(final_shape, dtype=loaded_weight.dtype) param.materialize(final_shape, dtype=loaded_weight.dtype)
...@@ -1578,6 +1584,14 @@ class FusedMoE(CustomOp): ...@@ -1578,6 +1584,14 @@ class FusedMoE(CustomOp):
f"EPLB is not supported for {self.quant_method.method_name}." f"EPLB is not supported for {self.quant_method.method_name}."
) )
def valid_grouping() -> bool:
    """Tell whether the experts can be partitioned into expert groups.

    True only when the number of experts (last dimension of
    ``router_logits``) is strictly greater than ``self.num_expert_group``
    and is an exact multiple of it.
    """
    expert_count = router_logits.shape[-1]
    group_count = self.num_expert_group
    return expert_count > group_count and expert_count % group_count == 0
indices_type = self.quant_method.topk_indices_dtype indices_type = self.quant_method.topk_indices_dtype
# Check if we should use a routing simulation strategy # Check if we should use a routing simulation strategy
...@@ -1592,7 +1606,7 @@ class FusedMoE(CustomOp): ...@@ -1592,7 +1606,7 @@ class FusedMoE(CustomOp):
) )
# DeepSeekv2 uses grouped_top_k # DeepSeekv2 uses grouped_top_k
elif self.use_grouped_topk: elif self.use_grouped_topk and valid_grouping():
assert self.topk_group is not None assert self.topk_group is not None
assert self.num_expert_group is not None assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled(): if rocm_aiter_ops.is_fused_moe_enabled():
......
...@@ -10,10 +10,12 @@ from typing import final ...@@ -10,10 +10,12 @@ from typing import final
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
count_expert_num_tokens, count_expert_num_tokens,
...@@ -22,12 +24,12 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -22,12 +24,12 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import ( from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
dbo_enabled, dbo_enabled,
dbo_maybe_run_recv_hook, dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_register_recv_hook,
dbo_yield, dbo_yield,
) )
from vllm.v1.worker.workspace import current_workspace_manager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -661,25 +663,6 @@ def _slice_scales( ...@@ -661,25 +663,6 @@ def _slice_scales(
return None return None
class SharedResizableBuffer:
    """A lazily allocated flat tensor that hands out reshaped views.

    The backing tensor is (re)allocated only when it is missing, too small,
    or was created for a different device or dtype; otherwise the existing
    storage is reused.
    """

    def __init__(self):
        # No storage until the first `get` call.
        self.buffer = None

    def get(
        self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return a view of `shape` over the leading elements of the buffer."""
        assert shape != ()
        wanted = prod(shape)

        backing = self.buffer
        must_allocate = (
            backing is None
            or backing.numel() < wanted
            or backing.device != device
            or backing.dtype != dtype
        )
        if must_allocate:
            # Allocate exactly the requested element count as a 1-D tensor.
            self.buffer = torch.empty(wanted, device=device, dtype=dtype)

        flat = self.buffer[:wanted]
        return flat.view(*shape)
@final @final
class FusedMoEModularKernel(torch.nn.Module): class FusedMoEModularKernel(torch.nn.Module):
""" """
...@@ -694,29 +677,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -694,29 +677,13 @@ class FusedMoEModularKernel(torch.nn.Module):
objects. objects.
""" """
class SharedBuffers:
def __init__(self) -> None:
self.fused_out = SharedResizableBuffer()
self.workspace13 = SharedResizableBuffer()
self.workspace2 = SharedResizableBuffer()
# Persistent buffers that are shared across `FusedMoEModularKernel`
# instances (layers), to save memory and allocations.
#
# We have two sets of buffers to support dual batch overlap (DBO) where each
# microbatch (ubatch) should use its own set of buffers to avoid
# cross-ubatch contamination.
# NOTE that memory is lazily allocated for these buffers, meaning that if
# DBO isn't being used, the second SharedBuffers will be empty.
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
def __init__( def __init__(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute, fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
shared_experts_stream: torch.cuda.Stream | None = None, shared_experts_stream: torch.cuda.Stream | None = None,
parallel_config: ParallelConfig | None = None, moe_parallel_config: FusedMoEParallelConfig | None = None,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
...@@ -724,12 +691,15 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -724,12 +691,15 @@ class FusedMoEModularKernel(torch.nn.Module):
self.shared_experts = shared_experts self.shared_experts = shared_experts
self.shared_experts_stream = shared_experts_stream self.shared_experts_stream = shared_experts_stream
# cache whether this worker is using DP+EP # prefer an explicit FusedMoEParallelConfig when available (from
if parallel_config is None: # FusedMoE layers / tests).
parallel_config = get_current_vllm_config().parallel_config # if not provided, assume this kernel is
# running in a non-DP+EP context
self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
self.is_dp_ep = ( self.is_dp_ep = (
parallel_config.data_parallel_size > 1 moe_parallel_config is not None
and parallel_config.enable_expert_parallel and moe_parallel_config.dp_size > 1
and moe_parallel_config.use_ep
) )
self._post_init_setup() self._post_init_setup()
...@@ -806,10 +776,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -806,10 +776,6 @@ class FusedMoEModularKernel(torch.nn.Module):
assert M_full > 0 and M_chunk > 0 assert M_full > 0 and M_chunk > 0
num_chunks, _ = self._chunk_info(M_full) num_chunks, _ = self._chunk_info(M_full)
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
ubatch_idx = dbo_current_ubatch_id()
buffers = self.shared_buffers[ubatch_idx]
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for # Force worst-case allocation in profiling run for
...@@ -832,14 +798,11 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -832,14 +798,11 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, expert_tokens_meta,
) )
) )
buffers.workspace13.get(
max_workspace_13, device=device, dtype=workspace_dtype current_workspace_manager().get_simultaneous(
) (max_workspace_13, workspace_dtype),
buffers.workspace2.get( (max_workspace_2, workspace_dtype),
max_workspace_2, device=device, dtype=workspace_dtype (max_fused_out_shape, out_dtype),
)
buffers.fused_out.get(
max_fused_out_shape, device=device, dtype=workspace_dtype
) )
# Get intermediate workspace shapes based off the chunked M size. # Get intermediate workspace shapes based off the chunked M size.
...@@ -866,22 +829,23 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -866,22 +829,23 @@ class FusedMoEModularKernel(torch.nn.Module):
# We can reuse the memory between cache1 and cache3 because by the # We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1. # time we need cache3, we're done with cache1.
workspace13 = buffers.workspace13.get(
workspace13_shape, device=device, dtype=workspace_dtype
)
workspace2 = buffers.workspace2.get(
workspace2_shape, device=device, dtype=workspace_dtype
)
# Construct the entire output that can then be processed in chunks. # Construct the entire output that can then be processed in chunks.
# Reuse workspace13 for the output in the non-chunked case as long # Reuse workspace13 for the output in the non-chunked case as long
# as it is large enough. This will not always be the case for standard # as it is large enough. This will not always be the case for standard
# format experts and with experts that have empty workspaces. # format experts and with experts that have empty workspaces.
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
workspace13, workspace2 = current_workspace_manager().get_simultaneous(
(workspace13_shape, workspace_dtype),
(workspace2_shape, workspace_dtype),
)
fused_out = _resize_cache(workspace13, fused_out_shape) fused_out = _resize_cache(workspace13, fused_out_shape)
else: else:
fused_out = buffers.fused_out.get( workspace13, workspace2, fused_out = (
fused_out_shape, device=device, dtype=out_dtype current_workspace_manager().get_simultaneous(
(workspace13_shape, workspace_dtype),
(workspace2_shape, workspace_dtype),
(fused_out_shape, out_dtype),
)
) )
return workspace13, workspace2, fused_out return workspace13, workspace2, fused_out
......
...@@ -30,8 +30,8 @@ class SharedFusedMoE(FusedMoE): ...@@ -30,8 +30,8 @@ class SharedFusedMoE(FusedMoE):
# Disable shared expert overlap if: # Disable shared expert overlap if:
# - we are using eplb, because of correctness issues # - we are using eplb, because of correctness issues
# - we are using flashinfer with DP, since there nothint to gain # - we are using flashinfer with DP, since there nothing to gain
# - we are using marlin kjernels # - we are using marlin kernels
self.use_overlapped = ( self.use_overlapped = (
use_overlapped use_overlapped
and not ( and not (
......
...@@ -470,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -470,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
} }
) )
intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full", intermediate_size_per_partition
)
self.is_k_full = intermediate_size_per_partition == intermediate_size_full
w13_qweight = Parameter( w13_qweight = Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
...@@ -597,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -597,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
) )
replace_parameter(layer, "w2_qweight", marlin_w2_qweight) replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# The modular kernel expects w13_weight and w2_weight,
# but AWQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# Why does this take the intermediate size for size_k? # Why does this take the intermediate size for size_k?
marlin_w13_scales = marlin_moe_permute_scales( marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales, s=layer.w13_scales,
...@@ -661,7 +673,88 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -661,7 +673,88 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
return None from vllm.model_executor.layers.fused_moe.config import (
awq_marlin_moe_quant_config,
)
return awq_marlin_moe_quant_config(
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
weight_bits=self.quant_config.weight_bits,
group_size=self.quant_config.group_size,
w1_zp=getattr(layer, "w13_qzeros", None)
if self.quant_config.zero_point
else None,
w2_zp=getattr(layer, "w2_qzeros", None)
if self.quant_config.zero_point
else None,
w1_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)
def select_gemm_impl(
    self,
    prepare_finalize,
    layer: torch.nn.Module,
):
    """
    Pick the Marlin experts implementation for AWQ-Marlin MoE.

    Only valid when LoRA is enabled; otherwise AWQ relies on its own
    apply() method and this raises NotImplementedError. Returns
    BatchedMarlinExperts when `prepare_finalize` uses the BatchedExperts
    activation format, and MarlinExperts for any other format.
    """
    # Modular kernels are wired in for LoRA only; without LoRA the
    # method's own apply() path is used instead.
    if not self.moe.is_lora_enabled:
        raise NotImplementedError(
            "AWQ-Marlin uses its own apply() method when LoRA is not enabled. "
            "Modular kernels are only used for LoRA support."
        )

    from vllm.model_executor.layers.fused_moe import modular_kernel as mk
    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
        BatchedMarlinExperts,
        MarlinExperts,
    )

    # The quant config has to exist before an experts object can be built.
    assert self.moe_quant_config is not None, (
        "moe_quant_config must be initialized before select_gemm_impl"
    )

    # Optional g_idx tensors: any attribute absent from the layer is None.
    g_idx_kwargs = {
        attr_name: getattr(layer, attr_name, None)
        for attr_name in (
            "w13_g_idx",
            "w2_g_idx",
            "w13_g_idx_sort_indices",
            "w2_g_idx_sort_indices",
        )
    }

    uses_batched_format = (
        prepare_finalize.activation_format
        == mk.FusedMoEActivationFormat.BatchedExperts
    )
    if not uses_batched_format:
        # Standard Marlin experts for AWQ
        return MarlinExperts(
            quant_config=self.moe_quant_config,
            **g_idx_kwargs,
            is_k_full=self.is_k_full,
        )

    # Batched expert format: BatchedMarlinExperts also needs the per-rank
    # token limit and the dispatcher count.
    tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
    assert tokens_per_rank is not None
    return BatchedMarlinExperts(
        max_num_tokens=tokens_per_rank,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=self.moe_quant_config,
        **g_idx_kwargs,
        is_k_full=self.is_k_full,
    )
def apply( def apply(
self, self,
......
...@@ -1266,9 +1266,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1266,9 +1266,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
ab_strides2=self.ab_strides2, ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1, c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2, c_strides2=self.ab_strides1_c_strides2,
parallel_config=getattr(
getattr(layer, "vllm_config", None), "parallel_config", None
),
) )
else: else:
......
...@@ -28,7 +28,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): ...@@ -28,7 +28,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# dont restrict as emulations # don't restrict as emulations
return 80 return 80
def create_weights( def create_weights(
......
...@@ -137,7 +137,7 @@ def get_fp8_moe_backend( ...@@ -137,7 +137,7 @@ def get_fp8_moe_backend(
if ( if (
current_platform.is_cuda() current_platform.is_cuda()
and ( and (
current_platform.is_device_capability(100) current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(90) or current_platform.is_device_capability(90)
) )
and envs.VLLM_USE_FLASHINFER_MOE_FP8 and envs.VLLM_USE_FLASHINFER_MOE_FP8
...@@ -148,7 +148,7 @@ def get_fp8_moe_backend( ...@@ -148,7 +148,7 @@ def get_fp8_moe_backend(
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM return Fp8MoeBackend.FLASHINFER_TRTLLM
else: else:
if block_quant and current_platform.is_device_capability(100): if block_quant and current_platform.is_device_capability_family(100):
raise ValueError( raise ValueError(
"FlashInfer FP8 MoE throughput backend does not " "FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use " "support block quantization. Please use "
...@@ -193,7 +193,7 @@ def get_fp8_moe_backend( ...@@ -193,7 +193,7 @@ def get_fp8_moe_backend(
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
if ( if (
current_platform.is_cuda() current_platform.is_cuda()
and current_platform.is_device_capability(100) and current_platform.is_device_capability_family(100)
and block_quant and block_quant
): ):
logger.info_once( logger.info_once(
...@@ -332,7 +332,10 @@ class Fp8Config(QuantizationConfig): ...@@ -332,7 +332,10 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping, fused_mapping=self.packed_modules_mapping,
): ):
return UnquantizedFusedMoEMethod(layer.moe_config) return UnquantizedFusedMoEMethod(layer.moe_config)
moe_quant_method = Fp8MoEMethod(self, layer) if self.is_checkpoint_fp8_serialized:
moe_quant_method = Fp8MoEMethod(self, layer)
else:
moe_quant_method = Fp8OnlineMoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method return moe_quant_method
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
...@@ -745,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -745,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.orig_dtype = params_dtype layer.orig_dtype = params_dtype
layer.weight_block_size = None layer.weight_block_size = None
if self.quant_config.is_checkpoint_fp8_serialized: assert self.quant_config.is_checkpoint_fp8_serialized
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
if self.block_quant: if self.block_quant:
assert self.weight_block_size is not None assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size layer.weight_block_size = self.weight_block_size
...@@ -773,41 +777,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -773,41 +777,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
f"weight quantization block_k = {block_k}." f"weight quantization block_k = {block_k}."
) )
# If we are doing online quantization, patch the weight loader to call
# `process_weights_after_loading` in a streaming fashion as soon as the
# last weight chunk is loaded.
if not self.quant_config.is_checkpoint_fp8_serialized:
    weight_loader = extra_weight_attrs["weight_loader"]
    # Create a new holder (shallow copy) so that installing the patched
    # loader below does not modify the behavior of any other object that
    # still references the original mapping. A plain assignment would only
    # alias the same dict.
    new_extra_weight_attrs = dict(extra_weight_attrs)

    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        """Run the original loader, then quantize once everything is loaded.

        Returns whatever the wrapped `weight_loader` returns.
        """
        # load the current weight chunk
        res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

        # add a counter to track how many elements we have updated
        if not hasattr(layer, "_loaded_numel"):
            layer._loaded_numel = 0
        layer._loaded_numel += loaded_weight.numel()

        # if we have loaded all of the elements, call
        # process_weights_after_loading
        target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
        if layer._loaded_numel == target_loaded_numel:
            self.process_weights_after_loading(layer)

            # Delete the bookkeeping
            del layer._loaded_numel
            # Prevent the usual `process_weights_after_loading` call
            # from doing anything
            layer._already_called_process_weights_after_loading = True

        return res

    new_extra_weight_attrs["weight_loader"] = patched_weight_loader
    extra_weight_attrs = new_extra_weight_attrs
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
...@@ -875,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -875,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
) )
# If loading fp8 checkpoint, pass the weight loaders. set_weight_attrs(w13_weight_scale, extra_weight_attrs)
# If loading an fp16 checkpoint, do not (we will quantize in set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES # INPUT_SCALES
if self.quant_config.activation_scheme == "static": if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
w13_input_scale = torch.nn.Parameter( w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False torch.ones(num_experts, dtype=torch.float32), requires_grad=False
) )
...@@ -986,45 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -986,45 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv = Parameter( layer.w2_weight_scale_inv = Parameter(
dg_w2_weight_scale_inv, requires_grad=False dg_w2_weight_scale_inv, requires_grad=False
) )
# If checkpoint is fp16, quantize in place.
elif not self.quant_config.is_checkpoint_fp8_serialized:
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
replace_parameter(
layer,
"w13_weight_scale",
torch.ones(
layer.local_num_experts,
dtype=torch.float32,
device=w13_weight.device,
),
)
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else: else:
# Fp8 moe kernels require a single activation scale. # Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ. # We take the max of all the scales in case they differ.
...@@ -1387,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1387,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return result return result
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads an unquantized (e.g. FP16/BF16) checkpoint and quantizes the expert
    weights to FP8 after they have been loaded, using dynamic activation
    scaling. The per-expert weight scales are placeholders (ones) until
    `process_weights_after_loading` computes the real values.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # Online quantization only covers the non-serialized, dynamic,
        # per-tensor case without a flashinfer MoE backend.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register unquantized expert weights and placeholder scales.

        The weights are allocated in `params_dtype`; their weight loader is
        wrapped so that quantization is triggered as soon as the last chunk
        has been loaded.

        Args:
            layer: Module on which parameters and bookkeeping are stored.
            num_experts: Number of experts to allocate weights for.
            hidden_size: Hidden dimension of the model.
            intermediate_size_per_partition: Intermediate dimension handled
                by this partition.
            params_dtype: Dtype of the checkpoint weights before quantization.
            **extra_weight_attrs: Attributes attached to the weight
                parameters; must contain a "weight_loader" callable.

        Raises:
            KeyError: If "weight_loader" is missing from the extra attributes.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loader
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # load the current weight chunk
            res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0
            layer._loaded_numel += loaded_weight.numel()

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        # Build a genuinely new holder so the mapping that was passed in is
        # left untouched and only our copy carries the patched loader.
        extra_weight_attrs = {
            **extra_weight_attrs,
            "weight_loader": patched_weight_loader,
        }

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for the merged w13 weight and one for w2.
        # They start as ones and are overwritten during quantization in
        # `process_weights_after_loading`; no weight loader is attached.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Dynamic activation scheme: there are no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        # Resolved for real in `process_weights_after_loading`.
        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded expert weights to FP8 in place.

        Does nothing if the streaming weight loader already ran this step
        (signalled by `_already_called_process_weights_after_loading`).

        Args:
            layer: Module holding `w13_weight`, `w2_weight` and their scales.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Resolve AITER fused-MoE availability now; `create_weights` only
        # set a False placeholder.
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # The checkpoint is not FP8: quantize each expert and record its scale.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
            )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Reshuffle weights for AITER if needed.
        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
            replace_parameter(layer, "w13_weight", shuffled_w13)
            replace_parameter(layer, "w2_weight", shuffled_w2)

        # Reshuffle weights for MARLIN if needed.
        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
""" """
Supports loading kv-cache scaling factors from FP8 checkpoints. Supports loading kv-cache scaling factors from FP8 checkpoints.
......
...@@ -33,6 +33,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -33,6 +33,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
) )
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -52,6 +53,11 @@ class GGUFConfig(QuantizationConfig): ...@@ -52,6 +53,11 @@ class GGUFConfig(QuantizationConfig):
return "gguf" return "gguf"
def get_supported_act_dtypes(self) -> list[torch.dtype]: def get_supported_act_dtypes(self) -> list[torch.dtype]:
# GGUF dequantization kernels use half precision (fp16) internally.
# bfloat16 has precision issues on Blackwell devices.
if current_platform.has_device_capability(100):
logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.")
return [torch.half, torch.float32]
return [torch.half, torch.bfloat16, torch.float32] return [torch.half, torch.bfloat16, torch.float32]
@classmethod @classmethod
......
...@@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit=is_a_8bit, is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_qweight", marlin_w2_qweight) replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# Repack scales # Repack scales
marlin_w13_scales = marlin_moe_permute_scales( marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales, s=layer.w13_scales,
...@@ -782,7 +790,107 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -782,7 +790,107 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
return None from vllm.model_executor.layers.fused_moe.config import (
gptq_marlin_moe_quant_config,
)
return gptq_marlin_moe_quant_config(
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
weight_bits=self.quant_config.weight_bits,
group_size=self.quant_config.group_size,
w1_zp=getattr(layer, "w13_qzeros", None)
if not self.quant_config.is_sym
else None,
w2_zp=getattr(layer, "w2_qzeros", None)
if not self.quant_config.is_sym
else None,
w1_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)
def select_gemm_impl(
    self,
    prepare_finalize,
    layer: torch.nn.Module,
):
    """Return the modular Marlin experts implementation for GPTQ MoE.

    This path is only taken when LoRA is enabled; otherwise the method
    raises, because GPTQ-Marlin relies on its own `apply()`.

    Args:
        prepare_finalize: Object whose `activation_format` selects between
            the batched and the standard experts implementation.
        layer: Module from which the optional act-order tensors are read.

    Returns:
        A `BatchedMarlinExperts` when the activation format is
        `BatchedExperts`, otherwise a `MarlinExperts`.

    Raises:
        NotImplementedError: If LoRA is disabled or the weights are 8-bit.
    """
    # Modular kernels are reserved for the LoRA case.
    lora_active = self.moe.is_lora_enabled
    if not lora_active:
        raise NotImplementedError(
            "GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
            "Modular kernels are only used for LoRA support."
        )

    # The modular marlin kernels do not support 8-bit weights.
    if self.quant_config.weight_bits == 8:
        raise NotImplementedError(
            "GPTQ-Marlin kernel does not support 8-bit weights."
        )

    from vllm.model_executor.layers.fused_moe import modular_kernel as mk
    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
        BatchedMarlinExperts,
        MarlinExperts,
    )

    # The quant config has to exist before an experts object can be built.
    assert self.moe_quant_config is not None, (
        "moe_quant_config must be initialized before select_gemm_impl"
    )

    def _act_order_tensor(attr_name: str):
        # The g_idx tensors are only forwarded when desc_act is set.
        if not self.quant_config.desc_act:
            return None
        return getattr(layer, attr_name, None)

    # Arguments shared by both experts implementations.
    marlin_kwargs = dict(
        quant_config=self.moe_quant_config,
        w13_g_idx=_act_order_tensor("w13_g_idx"),
        w2_g_idx=_act_order_tensor("w2_g_idx"),
        w13_g_idx_sort_indices=_act_order_tensor("w13_g_idx_sort_indices"),
        w2_g_idx_sort_indices=_act_order_tensor("w2_g_idx_sort_indices"),
        is_k_full=self.is_k_full,
    )

    wants_batched = (
        prepare_finalize.activation_format
        == mk.FusedMoEActivationFormat.BatchedExperts
    )
    if not wants_batched:
        # Standard Marlin experts for GPTQ
        return MarlinExperts(**marlin_kwargs)

    # Batched expert format: the per-rank token cap is additionally required.
    token_cap = prepare_finalize.max_num_tokens_per_rank()
    assert token_cap is not None
    return BatchedMarlinExperts(
        max_num_tokens=token_cap,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        **marlin_kwargs,
    )
def apply( def apply(
self, self,
......
...@@ -17,7 +17,9 @@ class ScaledMMLinearLayerConfig: ...@@ -17,7 +17,9 @@ class ScaledMMLinearLayerConfig:
class ScaledMMLinearKernel(ABC): class ScaledMMLinearKernel(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def get_min_capability(cls) -> int: def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
...@@ -35,6 +37,7 @@ class ScaledMMLinearKernel(ABC): ...@@ -35,6 +37,7 @@ class ScaledMMLinearKernel(ABC):
azp_adj_param_name: str, azp_adj_param_name: str,
) -> None: ) -> None:
assert self.can_implement(c) assert self.can_implement(c)
assert self.is_supported()
self.config = c self.config = c
self.w_q_name = w_q_param_name self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name self.w_s_name = w_s_param_name
......
...@@ -27,7 +27,7 @@ from vllm.platforms import PlatformEnum, current_platform ...@@ -27,7 +27,7 @@ from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available) # in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CPU: [CPUScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel],
} }
...@@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel( ...@@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel(
type[ScaledMMLinearKernel]: Chosen kernel. type[ScaledMMLinearKernel]: Chosen kernel.
""" """
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc[0] * 10 + _cc[1]
failure_reasons = [] failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]: for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
failure_reasons.append( failure_reasons.append(f"{kernel.__name__}: disabled by env var")
f" {kernel.__name__} disabled by environment variable"
)
continue continue
# If the current platform uses compute_capability, # If the current platform uses compute_capability,
# make sure the kernel supports the compute cability. # make sure the kernel supports the compute capability.
if compute_capability is not None: is_supported, reason = kernel.is_supported(compute_capability)
kernel_min_capability = kernel.get_min_capability() if not is_supported:
if ( failure_reasons.append(f"{kernel.__name__}: {reason}")
kernel_min_capability is not None continue
and kernel_min_capability > compute_capability
): can_implement, reason = kernel.can_implement(config)
failure_reasons.append( if not can_implement:
f"{kernel.__name__} requires capability " failure_reasons.append(f"{kernel.__name__}: {reason}")
f"{kernel_min_capability}, current compute capability " continue
f"is {compute_capability}"
)
continue
can_implement, failure_reason = kernel.can_implement(config) return kernel
if can_implement:
return kernel
else:
failure_reasons.append(
f" {kernel.__name__} cannot implement due to: {failure_reason}"
)
raise ValueError( raise ValueError(
"Failed to find a kernel that can implement the " "Failed to find a kernel that can implement the "
......
...@@ -14,17 +14,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig ...@@ -14,17 +14,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod @classmethod
def get_min_capability(cls) -> int: def is_supported(
return 90 cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_rocm(): if not current_platform.is_rocm():
return ( return (
False, False,
"AiterScaledMMLinearKernel requires `aiter` which is not " "AiterScaledMMLinearKernel requires `aiter` which is not "
+ "currently supported on non-ROCm platform.", + "currently supported on non-ROCm platform.",
) )
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc.major * 10 + _cc.minor
if compute_capability is not None and compute_capability < 90:
return False, f"requires capability 90, got {compute_capability}"
try: try:
import aiter # noqa: F401 # deliberately attempt to import aiter import aiter # noqa: F401 # deliberately attempt to import aiter
...@@ -34,8 +38,8 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): ...@@ -34,8 +38,8 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
"AiterScaledMMLinearKernel requires `aiter` which is not " "AiterScaledMMLinearKernel requires `aiter` which is not "
+ "installed on ROCm.", + "installed on ROCm.",
) )
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if not (rocm_aiter_ops.is_linear_enabled()): if not rocm_aiter_ops.is_linear_enabled():
return ( return (
False, False,
"AiterScaledMMLinearKernel is disabled. " "AiterScaledMMLinearKernel is disabled. "
...@@ -44,6 +48,10 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): ...@@ -44,6 +48,10 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.", + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
) )
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric: if not c.input_symmetric:
return ( return (
False, False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment