Unverified commit ef4a8097, authored by Baizhou Zhang, committed by GitHub

Rename flashmla kernel options of nsa backend for better readability (#11876)

parent ebff4ee6
@@ -228,6 +228,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--sampling-backend` | Choose the kernels for sampling layers. | None |
 | `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
 | `--mm-attention-backend` | Set multimodal attention backend. | None |
+| `--nsa-prefill-backend` | Prefill attention implementation for nsa backend. | `flashmla_sparse` |
+| `--nsa-decode-backend` | Decode attention implementation for nsa backend. | `flashmla_kv` |
 ## Speculative decoding
......
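For launch scripts still using the pre-rename options, the mapping is purely mechanical: `--nsa-prefill` becomes `--nsa-prefill-backend`, `--nsa-decode` becomes `--nsa-decode-backend`, and the kernel names `flashmla_prefill`/`flashmla_decode` become `flashmla_sparse`/`flashmla_kv`. A hypothetical helper (not part of SGLang) that rewrites an argv list accordingly:

```python
# Hypothetical migration helper, not part of SGLang: rewrites the old NSA flag and
# kernel names from before this commit to the renamed ones documented above.
_FLAG_RENAMES = {"--nsa-prefill": "--nsa-prefill-backend", "--nsa-decode": "--nsa-decode-backend"}
_VALUE_RENAMES = {"flashmla_prefill": "flashmla_sparse", "flashmla_decode": "flashmla_kv"}

def migrate_nsa_args(argv: list[str]) -> list[str]:
    out = []
    for tok in argv:
        tok = _FLAG_RENAMES.get(tok, tok)   # rename the option itself
        tok = _VALUE_RENAMES.get(tok, tok)  # rename the kernel value
        out.append(tok)
    return out

print(migrate_nsa_args(["--nsa-prefill", "flashmla_prefill", "--nsa-decode", "flashmla_decode"]))
# ['--nsa-prefill-backend', 'flashmla_sparse', '--nsa-decode-backend', 'flashmla_kv']
```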
@@ -140,9 +140,7 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
 )
-_NSA_IMPL_T: TypeAlias = Literal[
-    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
-]
+_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
 NSA_PREFILL_IMPL: _NSA_IMPL_T
 NSA_DECODE_IMPL: _NSA_IMPL_T
@@ -181,8 +179,8 @@ class NativeSparseAttnBackend(AttentionBackend):
 self.req_to_token = model_runner.req_to_token_pool.req_to_token
 global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
-NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
-NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
+NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
+NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
 self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
@@ -336,7 +334,7 @@ class NativeSparseAttnBackend(AttentionBackend):
 cache_seqlens=nsa_cache_seqlens_int32,
 seq_len_q=1,
 )
-if NSA_DECODE_IMPL == "flashmla_decode"
+if NSA_DECODE_IMPL == "flashmla_kv"
 else None
 ),
 nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
@@ -383,7 +381,7 @@ class NativeSparseAttnBackend(AttentionBackend):
 ),
 seq_len_q=1,
 )
-if NSA_DECODE_IMPL == "flashmla_decode"
+if NSA_DECODE_IMPL == "flashmla_kv"
 else None
 ),
 }
@@ -421,7 +419,7 @@ class NativeSparseAttnBackend(AttentionBackend):
 seqlens_expanded = cache_seqlens_int32
 nsa_extend_seq_lens_list = [1] * num_tokens
-if NSA_DECODE_IMPL == "flashmla_decode":
+if NSA_DECODE_IMPL == "flashmla_kv":
 flashmla_metadata = self.decode_cuda_graph_metadata[
 "flashmla_metadata"
 ].slice(slice(0, num_tokens + 1))
@@ -478,7 +476,7 @@ class NativeSparseAttnBackend(AttentionBackend):
 )
 nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
-if NSA_DECODE_IMPL == "flashmla_decode":
+if NSA_DECODE_IMPL == "flashmla_kv":
 flashmla_metadata = self.decode_cuda_graph_metadata[
 "flashmla_metadata"
 ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -534,7 +532,7 @@ class NativeSparseAttnBackend(AttentionBackend):
 )
 nsa_extend_seq_lens_list = [1] * bs
-if NSA_DECODE_IMPL == "flashmla_decode":
+if NSA_DECODE_IMPL == "flashmla_kv":
 flashmla_metadata = self.decode_cuda_graph_metadata[
 "flashmla_metadata"
 ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -712,7 +710,7 @@ class NativeSparseAttnBackend(AttentionBackend):
 else:
 assert metadata.real_page_table is metadata.page_table_1
-if NSA_DECODE_IMPL == "flashmla_decode":
+if NSA_DECODE_IMPL == "flashmla_kv":
 flashmla_metadata = metadata.flashmla_metadata.slice(
 slice(0, seqlens_expanded_size + 1)
 )
@@ -803,20 +801,20 @@ class NativeSparseAttnBackend(AttentionBackend):
 sm_scale=layer.scaling,
 v_head_dim=layer.v_head_dim,
 )
-elif NSA_PREFILL_IMPL == "flashmla_prefill":
+elif NSA_PREFILL_IMPL == "flashmla_sparse":
 if q_rope is not None:
 q_all = torch.cat([q_nope, q_rope], dim=-1)
-return self._forward_flashmla_prefill(
+return self._forward_flashmla_sparse(
 q_all=q_all,
 kv_cache=kv_cache,
 page_table_1=page_table_1,
 sm_scale=layer.scaling,
 v_head_dim=layer.v_head_dim,
 )
-elif NSA_PREFILL_IMPL == "flashmla_decode":
+elif NSA_PREFILL_IMPL == "flashmla_kv":
 if q_rope is not None:
 q_all = torch.cat([q_nope, q_rope], dim=-1)
-return self._forward_flashmla_decode(
+return self._forward_flashmla_kv(
 q_all=q_all,
 kv_cache=kv_cache,
 sm_scale=layer.scaling,
@@ -897,20 +895,20 @@ class NativeSparseAttnBackend(AttentionBackend):
 page_size=1,
 )
-if NSA_DECODE_IMPL == "flashmla_prefill":
+if NSA_DECODE_IMPL == "flashmla_sparse":
 if q_rope is not None:
 q_all = torch.cat([q_nope, q_rope], dim=-1)
-return self._forward_flashmla_prefill(
+return self._forward_flashmla_sparse(
 q_all=q_all,
 kv_cache=kv_cache,
 page_table_1=page_table_1,
 sm_scale=layer.scaling,
 v_head_dim=layer.v_head_dim,
 )
-elif NSA_DECODE_IMPL == "flashmla_decode":
+elif NSA_DECODE_IMPL == "flashmla_kv":
 if q_rope is not None:
 q_all = torch.cat([q_nope, q_rope], dim=-1)
-return self._forward_flashmla_decode(
+return self._forward_flashmla_kv(
 q_all=q_all,
 kv_cache=kv_cache,
 sm_scale=layer.scaling,
@@ -998,7 +996,7 @@ class NativeSparseAttnBackend(AttentionBackend):
 )
 return o  # type: ignore
-def _forward_flashmla_prefill(
+def _forward_flashmla_sparse(
 self,
 q_all: torch.Tensor,
 kv_cache: torch.Tensor,
@@ -1017,7 +1015,7 @@ class NativeSparseAttnBackend(AttentionBackend):
 )
 return o
-def _forward_flashmla_decode(
+def _forward_flashmla_kv(
 self,
 q_all: torch.Tensor,
 kv_cache: torch.Tensor,
......
@@ -128,7 +128,7 @@ DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
 DEFAULT_LORA_EVICTION_POLICY = "lru"
-NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
+NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]
 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
@@ -324,8 +324,8 @@ class ServerArgs:
 sampling_backend: Optional[str] = None
 grammar_backend: Optional[str] = None
 mm_attention_backend: Optional[str] = None
-nsa_prefill: str = "flashmla_prefill"
-nsa_decode: str = "fa3"
+nsa_prefill_backend: str = "flashmla_sparse"
+nsa_decode_backend: str = "fa3"
 # Speculative decoding
 enable_beta_spec: bool = False
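Since the renamed options are plain dataclass fields on `ServerArgs`, their defaults can also be read programmatically; a small sketch, assuming `ServerArgs` is importable from `sglang.srt.server_args`:

```python
# A minimal sketch; the import path is assumed to be sglang.srt.server_args.
from sglang.srt.server_args import ServerArgs

# Dataclass defaults are stored on the class, which is what the CLI registration
# further down relies on via default=ServerArgs.nsa_prefill_backend.
print(ServerArgs.nsa_prefill_backend)  # "flashmla_sparse"
print(ServerArgs.nsa_decode_backend)   # "fa3"
```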
@@ -1024,10 +1024,10 @@ class ServerArgs:
 logger.warning("Setting KV cache dtype to fp8.")
 if self.kv_cache_dtype == "fp8_e4m3":
-self.nsa_prefill = "flashmla_decode"
-self.nsa_decode = "flashmla_decode"
+self.nsa_prefill_backend = "flashmla_kv"
+self.nsa_decode_backend = "flashmla_kv"
 logger.warning(
-"Setting NSA backend to flashmla_decode for FP8 KV Cache."
+"Setting NSA backend to flashmla_kv for FP8 KV Cache."
 )
 # Logging env vars for NSA
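The override above means that with an `fp8_e4m3` KV cache, both NSA options end up as `flashmla_kv` regardless of what was requested. A minimal sketch of that behaviour using an illustrative stand-in class (not SGLang's `ServerArgs`):

```python
from dataclasses import dataclass

@dataclass
class _NsaArgsSketch:  # illustrative stand-in for the relevant ServerArgs fields
    kv_cache_dtype: str = "auto"
    nsa_prefill_backend: str = "flashmla_sparse"
    nsa_decode_backend: str = "fa3"

    def apply_fp8_override(self) -> None:
        # Mirrors the check above: an fp8_e4m3 KV cache forces the flashmla_kv kernel.
        if self.kv_cache_dtype == "fp8_e4m3":
            self.nsa_prefill_backend = "flashmla_kv"
            self.nsa_decode_backend = "flashmla_kv"

args = _NsaArgsSketch(kv_cache_dtype="fp8_e4m3", nsa_decode_backend="fa3")
args.apply_fp8_override()
print(args.nsa_prefill_backend, args.nsa_decode_backend)  # flashmla_kv flashmla_kv
```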
@@ -2356,14 +2356,14 @@ class ServerArgs:
 help="Set multimodal attention backend.",
 )
 parser.add_argument(
-"--nsa-prefill",
-default=ServerArgs.nsa_prefill,
+"--nsa-prefill-backend",
+default=ServerArgs.nsa_prefill_backend,
 type=str,
 choices=NSA_CHOICES,
 )
 parser.add_argument(
-"--nsa-decode",
-default=ServerArgs.nsa_decode,
+"--nsa-decode-backend",
+default=ServerArgs.nsa_decode_backend,
 type=str,
 choices=NSA_CHOICES,
 )
......
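A self-contained sketch of how the renamed flags parse, mirroring the `add_argument` calls above with a local `argparse` parser; `NSA_CHOICES` and the defaults are copied from the diff so the snippet runs on its own:

```python
import argparse

NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]

parser = argparse.ArgumentParser()
parser.add_argument("--nsa-prefill-backend", default="flashmla_sparse", type=str, choices=NSA_CHOICES)
parser.add_argument("--nsa-decode-backend", default="fa3", type=str, choices=NSA_CHOICES)

# argparse turns the dashed flags into underscore attributes.
args = parser.parse_args(["--nsa-prefill-backend", "flashmla_sparse", "--nsa-decode-backend", "flashmla_kv"])
print(args.nsa_prefill_backend, args.nsa_decode_backend)  # flashmla_sparse flashmla_kv
```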