Commit 3b9aa746 authored by zhangqha
Browse files

Merge branch 'v0.15.1-dev' into 'v0.15.1-dev-lxh'

# Conflicts:
#   vllm/model_executor/layers/fused_moe/fused_moe.py
parents 03a3c522 02a1e691
...@@ -921,12 +921,12 @@ class rocm_aiter_ops: ...@@ -921,12 +921,12 @@ class rocm_aiter_ops:
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
@classmethod @classmethod
@if_aiter_supported # @if_aiter_supported
def is_fused_moe_enabled(cls) -> bool: def is_fused_moe_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FMOE_ENABLED return cls._AITER_ENABLED and cls._FMOE_ENABLED
@classmethod @classmethod
@if_aiter_supported # @if_aiter_supported
def is_fusion_moe_shared_experts_enabled(cls) -> bool: def is_fusion_moe_shared_experts_enabled(cls) -> bool:
return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED
......
...@@ -1055,7 +1055,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1055,7 +1055,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use aiter triton fp4 bmm kernel # Whether to use aiter triton fp4 bmm kernel
# By default is enabled. # By default is enabled.
"VLLM_ROCM_USE_AITER_FP4BMM": lambda: ( "VLLM_ROCM_USE_AITER_FP4BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "True").lower() in ("true", "1") os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "False").lower() in ("true", "1")
), ),
# Use AITER triton unified attention for V1 attention # Use AITER triton unified attention for V1 attention
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
......
...@@ -215,6 +215,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -215,6 +215,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_and_maybe_dequant_weights, get_and_maybe_dequant_weights,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import get_gcn_arch_name
from vllm.utils.flashinfer import has_nvidia_artifactory from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
...@@ -2115,7 +2116,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2115,7 +2116,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
scale=layer._k_scale, scale=layer._k_scale,
) )
if fp8_attention: if fp8_attention and get_gcn_arch_name() == "gfx938":
kv_cache = kv_cache.view(current_platform.fp8_dtype()) kv_cache = kv_cache.view(current_platform.fp8_dtype())
if has_prefill: if has_prefill:
...@@ -2185,7 +2186,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2185,7 +2186,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (N, B, L) to (B, N, L) # Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
if fp8_attention: if fp8_attention and get_gcn_arch_name() == "gfx938":
assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
assert decode_ql_nope.shape[1] == decode_q_pe.shape[1] assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
decode_q = self._decode_concat_quant_fp8_op( decode_q = self._decode_concat_quant_fp8_op(
......
...@@ -1613,8 +1613,8 @@ def fused_experts( ...@@ -1613,8 +1613,8 @@ def fused_experts(
quant_config: FusedMoEQuantConfig | None = None, quant_config: FusedMoEQuantConfig | None = None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: Optional[torch.Tensor] = None, i_q: torch.Tensor | None = None,
i_s: Optional[torch.Tensor] = None, **_ i_s: torch.Tensor | None = None # TODO:wjl
) -> torch.Tensor: ) -> torch.Tensor:
if quant_config is None: if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
......
...@@ -335,8 +335,10 @@ class GroupedTopKRouter(BaseRouter): ...@@ -335,8 +335,10 @@ class GroupedTopKRouter(BaseRouter):
rocm_aiter_grouped_topk, rocm_aiter_grouped_topk,
num_fused_shared_experts=self.num_fused_shared_experts, num_fused_shared_experts=self.num_fused_shared_experts,
) )
enable_shared_experts_fusion = True
else: else:
grouped_topk_impl = grouped_topk grouped_topk_impl = grouped_topk
enable_shared_experts_fusion = False
if self.use_fused_gate: if self.use_fused_gate:
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
...@@ -347,7 +349,7 @@ class GroupedTopKRouter(BaseRouter): ...@@ -347,7 +349,7 @@ class GroupedTopKRouter(BaseRouter):
self.num_expert_group, self.num_expert_group,
self.topk_group, self.topk_group,
self.top_k, self.top_k,
0, self.num_fused_shared_experts if enable_shared_experts_fusion else 0,
self.routed_scaling_factor, self.routed_scaling_factor,
) )
else: else:
...@@ -358,7 +360,7 @@ class GroupedTopKRouter(BaseRouter): ...@@ -358,7 +360,7 @@ class GroupedTopKRouter(BaseRouter):
self.topk_group, self.topk_group,
self.top_k, self.top_k,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
n_share_experts_fusion=0, n_share_experts_fusion = (self.num_fused_shared_experts if enable_shared_experts_fusion else 0),
) )
else: else:
topk_weights, topk_ids = grouped_topk_impl( topk_weights, topk_ids = grouped_topk_impl(
......
...@@ -335,7 +335,7 @@ class FusedRMSNormQuant(nn.Module): ...@@ -335,7 +335,7 @@ class FusedRMSNormQuant(nn.Module):
quant_dtype: torch.dtype = torch.int8, quant_dtype: torch.dtype = torch.int8,
update_input: Optional[bool] = True update_input: Optional[bool] = True
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
i_q, i_s = torch.ops.vllm.fused_rmsquant(input=x, i_q, i_s = torch.ops.vllm.fused_rmsquant_customer_impl(input=x,
weight=self.weight, weight=self.weight,
epsilon=self.variance_epsilon, epsilon=self.variance_epsilon,
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
...@@ -383,9 +383,9 @@ def fused_rmsquant_fake( ...@@ -383,9 +383,9 @@ def fused_rmsquant_fake(
# customer_lib = Library("customer_", "FRAGMENT") # customer_lib = Library("customer_", "FRAGMENT")
direct_register_custom_op( direct_register_custom_op(
op_name="fused_rmsquant", op_name="fused_rmsquant_customer_impl",
op_func=fused_rmsquant_impl, op_func=fused_rmsquant_impl,
mutates_args=[], mutates_args=["input", "residual"],
fake_impl=fused_rmsquant_fake, fake_impl=fused_rmsquant_fake,
) )
......
...@@ -711,32 +711,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -711,32 +711,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp: If true, all weights matrix won't be sharded, this layer disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear. will be treated as a "Replicated" MergedLinear.
""" """
def forward(
    self,
    input_,
    *,
    iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
    """Apply the layer's quant method to ``input_`` and optionally all-gather.

    Args:
        input_: Input handed to ``self.quant_method.apply``.
        iqis: Optional pair forwarded as ``input_quant_args`` when
            ``envs.USE_FUSED_RMS_QUANT`` is truthy (presumably the
            pre-quantized input and its scales -- confirm against the
            quant method).

    Returns:
        Only the output when ``self.return_bias`` is false; otherwise
        ``(output, bias)`` where ``bias`` is ``self.bias`` if
        ``self.skip_bias_add`` is set, else ``None``.
    """
    # With skip_bias_add the bias is not passed to the matmul; it is
    # handed back to the caller at the end instead.
    fused_bias = None if self.skip_bias_add else self.bias

    assert self.quant_method is not None

    # Only pass the extra keyword when the fused RMS-quant path is
    # enabled and pre-quantized arguments were actually supplied.
    extra_kwargs = {}
    if envs.USE_FUSED_RMS_QUANT and iqis is not None:
        extra_kwargs["input_quant_args"] = iqis
    output_parallel = self.quant_method.apply(
        self, input_, fused_bias, **extra_kwargs
    )

    # All-gather across the partitions only when requested and when
    # there is more than one tensor-parallel rank.
    if self.gather_output and self.tp_size > 1:
        output = tensor_model_parallel_all_gather(output_parallel)
    else:
        output = output_parallel

    if self.return_bias:
        deferred_bias = self.bias if self.skip_bias_add else None
        return output, deferred_bias
    return output
def __init__( def __init__(
self, self,
input_size: int, input_size: int,
......
...@@ -1256,7 +1256,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1256,7 +1256,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None, i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
......
...@@ -307,6 +307,8 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -307,6 +307,8 @@ class SlimQuantW4A8Int8MoEMethod:
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts( return fused_experts(
......
...@@ -224,6 +224,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -224,6 +224,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
......
...@@ -49,7 +49,7 @@ def sparse_attn_indexer( ...@@ -49,7 +49,7 @@ def sparse_attn_indexer(
if not isinstance(attn_metadata, dict): if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run # Reserve workspace for indexer during profiling run
current_workspace_manager().get_simultaneous( current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else torch.bfloat16), ((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else k.dtype,),
((total_seq_lens, 4), torch.uint8), ((total_seq_lens, 4), torch.uint8),
) )
return sparse_attn_indexer_fake( return sparse_attn_indexer_fake(
......
...@@ -324,8 +324,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -324,8 +324,12 @@ class DeepseekV2MoE(nn.Module):
self.experts = SharedFusedMoE( self.experts = SharedFusedMoE(
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
gate=self.gate, gate=self.gate,
num_experts=config.n_routed_experts, # num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, # top_k=config.num_experts_per_tok,
num_experts=config.n_routed_experts
+ (config.n_shared_experts if self.is_fusion_moe_shared_experts_enabled else 0),
top_k = config.num_experts_per_tok
+ (config.n_shared_experts if self.is_fusion_moe_shared_experts_enabled else 0),
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False, reduce_results=False,
......
...@@ -121,6 +121,10 @@ def on_gfx9() -> bool: ...@@ -121,6 +121,10 @@ def on_gfx9() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950", "gfx928", "gfx936", "gfx938"]) return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950", "gfx928", "gfx936", "gfx938"])
@cache
def get_gcn_arch_name() -> str:
    """Return the GCN architecture name of the "cuda" device.

    Reads ``gcnArchName`` from the device properties and keeps only the
    text before the first ``:`` (the remainder is presumably a list of
    target feature flags -- confirm). The result is memoized by
    ``@cache``, so the device is queried once per process.
    """
    device_props = torch.cuda.get_device_properties("cuda")
    base_arch, _, _ = device_props.gcnArchName.partition(":")
    return base_arch
@cache @cache
def on_gfx942() -> bool: def on_gfx942() -> bool:
......
...@@ -310,6 +310,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -310,6 +310,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
causal=True, causal=True,
descale_q=layer._q_scale.reshape(1), descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1), descale_k=layer._k_scale.reshape(1),
kv_cache_dtype=self.kv_cache_dtype,
) )
else: else:
o, lse = flash_mla_with_kvcache( o, lse = flash_mla_with_kvcache(
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import get_gcn_arch_name
logger = init_logger(__name__) logger = init_logger(__name__)
if current_platform.is_cuda(): if current_platform.is_cuda():
...@@ -136,7 +136,7 @@ def get_mla_metadata_dense_fp8( ...@@ -136,7 +136,7 @@ def get_mla_metadata_dense_fp8(
cache_seqlens, cache_seqlens,
num_q_tokens_per_head_k, num_q_tokens_per_head_k,
num_heads_k, num_heads_k,
16, # 16,
) )
else: else:
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8( return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
...@@ -158,12 +158,14 @@ def flash_mla_with_kvcache_fp8( ...@@ -158,12 +158,14 @@ def flash_mla_with_kvcache_fp8(
causal: bool = False, causal: bool = False,
descale_q: torch.Tensor | None = None, descale_q: torch.Tensor | None = None,
descale_k: torch.Tensor | None = None, descale_k: torch.Tensor | None = None,
kv_cache_dtype: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if not _is_flashmla_available()[0]: if not _is_flashmla_available()[0]:
_raise_flashmla_unavailable() _raise_flashmla_unavailable()
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
if current_platform.is_rocm(): if current_platform.is_rocm():
if get_gcn_arch_name() == "gfx938":
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
q, q,
k_cache, k_cache,
...@@ -178,6 +180,21 @@ def flash_mla_with_kvcache_fp8( ...@@ -178,6 +180,21 @@ def flash_mla_with_kvcache_fp8(
descale_q, descale_q,
descale_k, descale_k,
) )
else:
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_k,
kv_cache_dtype,
)
else: else:
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8( out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q, q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment