Unverified Commit f4a0921c authored by danisereb's avatar danisereb Committed by GitHub
Browse files

[Performance] Tune Mamba selective scan kernel for B200 (#32873)


Signed-off-by: default avatarDaniel Serebrenik <daserebrenik@nvidia.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 208c5625
...@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader, sharded_weight_loader,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
...@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1, dim=-1,
) )
# Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native( def forward_native(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp):
state_batch_indices=state_indices_tensor_d_input, state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output, dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
is_blackwell=self.is_blackwell,
) )
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
......
...@@ -286,6 +286,7 @@ def selective_state_update( ...@@ -286,6 +286,7 @@ def selective_state_update(
out=None, out=None,
num_accepted_tokens=None, num_accepted_tokens=None,
cu_seqlens=None, cu_seqlens=None,
is_blackwell=False,
): ):
""" """
Argument: Argument:
...@@ -391,17 +392,26 @@ def selective_state_update( ...@@ -391,17 +392,26 @@ def selective_state_update(
if dst_state_batch_indices is not None if dst_state_batch_indices is not None
else (0, 0) else (0, 0)
) )
# We don't want autotune since it will overwrite the state # We don't want autotune since it will overwrite the state.
# We instead tune by hand. # We instead tune by hand based on dstate.
BLOCK_SIZE_M, num_warps = (
(32, 4) # Default
if dstate <= 16 BLOCK_SIZE_M, num_warps = 4, 8
else (
(16, 4) if dstate <= 16:
if dstate <= 32 BLOCK_SIZE_M, num_warps = 32, 4
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) elif dstate <= 32:
) BLOCK_SIZE_M, num_warps = 16, 4
) elif dstate <= 64:
BLOCK_SIZE_M, num_warps = 8, 4
else:
# dstate > 64
if is_blackwell:
# Optimized for B200 with dstate>64
BLOCK_SIZE_M, num_warps = 32, 8
elif dstate <= 128:
BLOCK_SIZE_M, num_warps = 4, 4
tie_hdim = ( tie_hdim = (
A.stride(-1) == 0 A.stride(-1) == 0
and A.stride(-2) == 0 and A.stride(-2) == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment