Unverified Commit ef4a8097 authored by Baizhou Zhang, committed by GitHub

Rename flashmla kernel options of the NSA backend for better readability (#11876)

parent ebff4ee6
@@ -228,6 +228,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--sampling-backend` | Choose the kernels for sampling layers. | None |
| `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
| `--mm-attention-backend` | Set multimodal attention backend. | None |
+| `--nsa-prefill-backend` | Prefill attention implementation for nsa backend. | `flashmla_sparse` |
+| `--nsa-decode-backend` | Decode attention implementation for nsa backend. | `flashmla_kv` |
## Speculative decoding
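For reference, a minimal sketch of selecting the renamed backends programmatically; it assumes the `ServerArgs` fields introduced in the server_args.py hunks below and uses a placeholder model path. On the command line the equivalent is `--nsa-prefill-backend flashmla_sparse --nsa-decode-backend flashmla_kv`.

```python
# Illustrative sketch only: the backend field names follow this commit; the model
# path is a placeholder and not part of the change.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/model",            # placeholder
    nsa_prefill_backend="flashmla_sparse",  # formerly --nsa-prefill flashmla_prefill
    nsa_decode_backend="flashmla_kv",       # formerly --nsa-decode flashmla_decode
)
```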
@@ -140,9 +140,7 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
)
-_NSA_IMPL_T: TypeAlias = Literal[
-    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
-]
+_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
NSA_PREFILL_IMPL: _NSA_IMPL_T
NSA_DECODE_IMPL: _NSA_IMPL_T
@@ -181,8 +179,8 @@ class NativeSparseAttnBackend(AttentionBackend):
self.req_to_token = model_runner.req_to_token_pool.req_to_token
global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
-NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
-NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
+NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
+NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
@@ -336,7 +334,7 @@ class NativeSparseAttnBackend(AttentionBackend):
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1,
)
-if NSA_DECODE_IMPL == "flashmla_decode"
+if NSA_DECODE_IMPL == "flashmla_kv"
else None
),
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
@@ -383,7 +381,7 @@ class NativeSparseAttnBackend(AttentionBackend):
),
seq_len_q=1,
)
-if NSA_DECODE_IMPL == "flashmla_decode"
+if NSA_DECODE_IMPL == "flashmla_kv"
else None
),
}
@@ -421,7 +419,7 @@ class NativeSparseAttnBackend(AttentionBackend):
seqlens_expanded = cache_seqlens_int32
nsa_extend_seq_lens_list = [1] * num_tokens
-if NSA_DECODE_IMPL == "flashmla_decode":
+if NSA_DECODE_IMPL == "flashmla_kv":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, num_tokens + 1))
@@ -478,7 +476,7 @@ class NativeSparseAttnBackend(AttentionBackend):
)
nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
-if NSA_DECODE_IMPL == "flashmla_decode":
+if NSA_DECODE_IMPL == "flashmla_kv":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -534,7 +532,7 @@ class NativeSparseAttnBackend(AttentionBackend):
)
nsa_extend_seq_lens_list = [1] * bs
-if NSA_DECODE_IMPL == "flashmla_decode":
+if NSA_DECODE_IMPL == "flashmla_kv":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -712,7 +710,7 @@ class NativeSparseAttnBackend(AttentionBackend):
else:
assert metadata.real_page_table is metadata.page_table_1
-if NSA_DECODE_IMPL == "flashmla_decode":
+if NSA_DECODE_IMPL == "flashmla_kv":
flashmla_metadata = metadata.flashmla_metadata.slice(
slice(0, seqlens_expanded_size + 1)
)
@@ -803,20 +801,20 @@ class NativeSparseAttnBackend(AttentionBackend):
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
-elif NSA_PREFILL_IMPL == "flashmla_prefill":
+elif NSA_PREFILL_IMPL == "flashmla_sparse":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
-return self._forward_flashmla_prefill(
+return self._forward_flashmla_sparse(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
-elif NSA_PREFILL_IMPL == "flashmla_decode":
+elif NSA_PREFILL_IMPL == "flashmla_kv":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
-return self._forward_flashmla_decode(
+return self._forward_flashmla_kv(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
@@ -897,20 +895,20 @@ class NativeSparseAttnBackend(AttentionBackend):
page_size=1,
)
-if NSA_DECODE_IMPL == "flashmla_prefill":
+if NSA_DECODE_IMPL == "flashmla_sparse":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
-return self._forward_flashmla_prefill(
+return self._forward_flashmla_sparse(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
-elif NSA_DECODE_IMPL == "flashmla_decode":
+elif NSA_DECODE_IMPL == "flashmla_kv":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
-return self._forward_flashmla_decode(
+return self._forward_flashmla_kv(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
@@ -998,7 +996,7 @@ class NativeSparseAttnBackend(AttentionBackend):
)
return o # type: ignore
-def _forward_flashmla_prefill(
+def _forward_flashmla_sparse(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
@@ -1017,7 +1015,7 @@ class NativeSparseAttnBackend(AttentionBackend):
)
return o
-def _forward_flashmla_decode(
+def _forward_flashmla_kv(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
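To summarize the change in this file: each implementation identifier keeps its pairing with a private forward method, only under new names. A condensed, illustrative mapping (not the actual dispatch code, which branches inside forward_extend/forward_decode with full argument lists):

```python
# Illustrative condensation of the renamed dispatch above; the real kernel methods
# take q/kv tensors, page tables, scaling, and head dims. Names below mirror the diff.
RENAMED_NSA_KERNELS = {
    "flashmla_sparse": "_forward_flashmla_sparse",  # was "flashmla_prefill" -> _forward_flashmla_prefill
    "flashmla_kv": "_forward_flashmla_kv",          # was "flashmla_decode"  -> _forward_flashmla_decode
}

def resolve_nsa_kernel(backend, impl: str):
    """Return the bound kernel method for a renamed implementation name (sketch only)."""
    return getattr(backend, RENAMED_NSA_KERNELS[impl])
```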
@@ -128,7 +128,7 @@ DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
DEFAULT_LORA_EVICTION_POLICY = "lru"
NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
@@ -324,8 +324,8 @@ class ServerArgs:
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = None
mm_attention_backend: Optional[str] = None
-nsa_prefill: str = "flashmla_prefill"
-nsa_decode: str = "fa3"
+nsa_prefill_backend: str = "flashmla_sparse"
+nsa_decode_backend: str = "fa3"
# Speculative decoding
enable_beta_spec: bool = False
@@ -1024,10 +1024,10 @@ class ServerArgs:
logger.warning("Setting KV cache dtype to fp8.")
if self.kv_cache_dtype == "fp8_e4m3":
-self.nsa_prefill = "flashmla_decode"
-self.nsa_decode = "flashmla_decode"
+self.nsa_prefill_backend = "flashmla_kv"
+self.nsa_decode_backend = "flashmla_kv"
logger.warning(
"Setting NSA backend to flashmla_decode for FP8 KV Cache."
"Setting NSA backend to flashmla_kv for FP8 KV Cache."
)
# Logging env vars for NSA
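For clarity, the branch above forces both NSA backends to flashmla_kv whenever the KV cache dtype is fp8_e4m3, regardless of what was requested. A standalone restatement of that rule (illustrative only; the real check lives inside ServerArgs as shown in the hunk above):

```python
# Standalone restatement of the FP8 override above; the function name and signature
# are illustrative and not part of the codebase.
def apply_fp8_nsa_override(kv_cache_dtype: str,
                           nsa_prefill_backend: str,
                           nsa_decode_backend: str) -> tuple:
    """Force both NSA backends to flashmla_kv when the KV cache is FP8 (e4m3)."""
    if kv_cache_dtype == "fp8_e4m3":
        return "flashmla_kv", "flashmla_kv"
    return nsa_prefill_backend, nsa_decode_backend

# Example: an FP8 KV cache overrides whatever was requested.
print(apply_fp8_nsa_override("fp8_e4m3", "flashmla_sparse", "fa3"))
# -> ('flashmla_kv', 'flashmla_kv')
```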
@@ -2356,14 +2356,14 @@ class ServerArgs:
help="Set multimodal attention backend.",
)
parser.add_argument(
"--nsa-prefill",
default=ServerArgs.nsa_prefill,
"--nsa-prefill-backend",
default=ServerArgs.nsa_prefill_backend,
type=str,
choices=NSA_CHOICES,
)
parser.add_argument(
"--nsa-decode",
default=ServerArgs.nsa_decode,
"--nsa-decode-backend",
default=ServerArgs.nsa_decode_backend,
type=str,
choices=NSA_CHOICES,
)