Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
     def support_triton(self):
        """Check if the current backend supports triton."""
        return True
+
+    def get_indexer_metadata(
+        self,
+        layer_id: int,
+        forward_batch: ForwardBatch,
+    ) -> Optional[BaseIndexerMetadata]:
+        """Get the indexer metadata. None means don't support indexer."""
+        return None
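The new get_indexer_metadata hook defaults to None, which callers can treat as "this backend has no NSA indexer support". A minimal caller-side sketch in Python; forward_batch.attn_backend is used elsewhere in this commit, while the helper name below is hypothetical:

# Hypothetical usage sketch: ask the active attention backend for indexer
# metadata and fall back to the dense path when it returns None.
def maybe_get_indexer_metadata(layer_id, forward_batch):
    metadata = forward_batch.attn_backend.get_indexer_metadata(layer_id, forward_batch)
    if metadata is None:
        return None  # backend does not support the NSA indexer
    return metadata  # hand to the indexer kernels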
@@ -692,13 +692,8 @@ class FlashAttentionBackend(AttentionBackend):
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
-        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
-        # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
-        if (
-            self.kv_cache_dtype_str != "auto"
-            and layer.head_dim <= 256
-            and self.fa_impl_ver != 4
-        ):
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
......
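After this change the descale path is gated only on fp8 KV cache being enabled and head_dim <= 256 (the fa_impl_ver check is dropped). A compact sketch of the resulting logic; the v_scale handling is an assumption based on the symmetric k/v naming, not shown in the hunk above:

# Sketch of the simplified gating (assumption: layer also carries v_scale).
k_descale, v_descale = None, None
if kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
    if layer.k_scale is not None:
        descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
        k_descale = layer.k_scale.expand(descale_shape)
        v_descale = layer.v_scale.expand(descale_shape)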
@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
 from sglang.srt.utils import (
     get_int_env_var,
     is_flashinfer_available,
@@ -344,7 +344,9 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -451,7 +453,9 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -669,7 +673,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -684,7 +690,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -710,7 +718,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -760,7 +770,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -794,7 +806,9 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
@@ -905,7 +919,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -921,7 +937,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
     ):
         if use_ragged:
@@ -959,7 +977,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
@@ -1006,7 +1026,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
        fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
@@ -1049,7 +1071,9 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
     ):
@@ -1078,7 +1102,7 @@ class FlashInferIndicesUpdaterPrefill:
             custom_mask = None
         else:
             assert isinstance(
-                spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
+                spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
             )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
......
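The same three-line Optional[Union[...]] annotation is now repeated in every signature touched above; a module-level alias would collapse each of those hunks to a one-line change (a suggestion only, not something this commit does):

from typing import Optional, Union

from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput

# Hypothetical alias: the three speculative-input types accepted by the
# FlashInfer updaters, matching the Union written out in each signature above.
SpecInput = Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
OptionalSpecInput = Optional[SpecInput]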
@@ -5,13 +5,20 @@ Support attention backend for FlashMLA.
 """
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import triton
 from flash_mla import flash_mla_with_kvcache, get_mla_metadata
 
+from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
+from sglang.srt.layers.attention.nsa.utils import (
+    NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+    NSA_KV_CACHE_STORE_FP8,
+    compute_nsa_seqlens,
+)
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -74,10 +81,17 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+        self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
+        self.nsa_index_topk = (
+            get_nsa_index_topk(model_runner.model_config.hf_config)
+            if self.use_nsa
+            else None
+        )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
@@ -100,10 +114,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 max_seqlen_pad,
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                forward_batch.seq_lens.to(torch.int32),
-                self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+                seq_len_q=1,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             self.forward_metadata = FlashMLADecodeMetadata(
                 mla_metadata,
@@ -130,10 +146,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 max_seqlen_pad,
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=self.num_draft_tokens,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             # Use FlashMLADecodeMetadata which has the attributes forward_extend expects
@@ -162,20 +180,28 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         cuda_graph_kv_indices = block_kv_indices
 
         if self.num_draft_tokens:
-            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
-                torch.ones(
-                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
-                ),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
+            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
+                _get_mla_metadata_wrapped(
+                    cache_seqlens=torch.ones(
+                        max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+                    ),
+                    seq_len_q=self.num_draft_tokens,
+                    num_heads_q=self.num_q_heads,
+                    num_heads_k=1,
+                    nsa_index_topk=self.nsa_index_topk,
+                )
             )
         else:
-            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
-                torch.ones(
-                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
-                ),
-                self.num_q_heads,
-                1,
+            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
+                _get_mla_metadata_wrapped(
+                    cache_seqlens=torch.ones(
+                        max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+                    ),
+                    seq_len_q=1,
+                    num_heads_q=self.num_q_heads,
+                    num_heads_k=1,
+                    nsa_index_topk=self.nsa_index_topk,
+                )
             )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
@@ -201,10 +227,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=1,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -226,10 +254,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=self.num_draft_tokens,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -275,10 +305,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=1,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -300,10 +332,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=self.num_draft_tokens,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
            )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -335,6 +369,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
@@ -349,13 +384,14 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        k_cache = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
 
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
-        if self.data_type == torch.float8_e4m3fn:
+        if (not self.use_nsa) and self.data_type == torch.float8_e4m3fn:
             reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
             o, _ = flash_mla_with_kvcache(
                 q=reshape_q_fp8,
-                k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                k_cache=k_cache,
                 block_table=self.forward_metadata.block_kv_indices[:bs],
                 cache_seqlens=forward_batch.seq_lens.to(torch.int32),
                 head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
@@ -369,17 +405,49 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
+            block_table = self.forward_metadata.block_kv_indices[:bs]
+            cache_seqlens = forward_batch.seq_lens.to(torch.int32)
+
+            extra_kwargs: Dict
+            if self.use_nsa:
+                assert topk_indices is not None
+                extra_kwargs = dict(
+                    indices=_compute_indices_in_kvcache(
+                        block_table=block_table,
+                        topk_indices=topk_indices.to(torch.int32),
+                        page_size=self.page_size,
+                    ),
+                    # doc says it is not used, but if pass in None then error
+                    block_table=block_table,
+                    is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+                )
+                cache_seqlens = compute_nsa_seqlens(
+                    cache_seqlens, nsa_index_topk=self.nsa_index_topk
+                )
+            else:
+                extra_kwargs = dict(
+                    block_table=block_table,
+                    causal=True,
+                )
+
+            if (
+                self.use_nsa
+                and NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8
+                and not NSA_KV_CACHE_STORE_FP8
+            ):
+                # inefficiently quantize the whole cache
+                k_cache = quantize_k_cache(k_cache)
+
             # todo: need check all causal True or False?
             o, _ = flash_mla_with_kvcache(
                 q=reshape_q,
-                k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
-                block_table=self.forward_metadata.block_kv_indices[:bs],
-                cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+                k_cache=k_cache,
+                cache_seqlens=cache_seqlens,
                 head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
                 tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
                 num_splits=self.forward_metadata.num_splits,
                 softmax_scale=layer.scaling,
-                causal=True,
+                **extra_kwargs,
             )
             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -539,3 +607,52 @@ class FlashMLAMultiStepDraftBackend:
             )
 
         self.common_template(forward_batch, call_fn)
+
+
+def _get_mla_metadata_wrapped(
+    *,
+    cache_seqlens: torch.Tensor,
+    seq_len_q: int,
+    num_heads_q: int,
+    num_heads_k: int,
+    nsa_index_topk: Optional[int],
+):
+    if nsa_index_topk is not None:
+        assert nsa_index_topk is not None
+        return get_mla_metadata(
+            cache_seqlens=cache_seqlens,
+            # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
+            # but the name looks like need seq_len_q?
+            num_q_tokens_per_head_k=seq_len_q * num_heads_q // num_heads_k,
+            num_heads_k=num_heads_k,
+            num_heads_q=num_heads_q,
+            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+            topk=nsa_index_topk,
+        )
+    else:
+        assert nsa_index_topk is None
+        return get_mla_metadata(
+            cache_seqlens=cache_seqlens,
+            num_heads_per_head_k=seq_len_q * num_heads_q // num_heads_k,
+            num_heads_k=num_heads_k,
+        )
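A quick way to see that the wrapper above is a drop-in for the old positional calls in the non-NSA branch: the decode sites used to pass self.num_q_heads and the draft sites self.num_draft_tokens * self.num_q_heads as the second argument to get_mla_metadata, and seq_len_q * num_heads_q // num_heads_k reproduces both (illustrative numbers below, not taken from the commit):

# Sanity check of the argument mapping used by _get_mla_metadata_wrapped.
num_q_heads, num_draft_tokens, num_heads_k = 16, 4, 1

# decode path: seq_len_q = 1 -> old positional value was num_q_heads
assert 1 * num_q_heads // num_heads_k == num_q_heads

# draft/verify path: seq_len_q = num_draft_tokens
# -> old positional value was num_draft_tokens * num_q_heads
assert num_draft_tokens * num_q_heads // num_heads_k == num_draft_tokens * num_q_heads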
+
+# TODO speedup
+def _compute_indices_in_kvcache(block_table, topk_indices, page_size):
+    topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
+    idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
+        1
+    )
+    block_idx = block_table[idx0, topk_indices_safe // page_size]
+    offset = topk_indices_safe % page_size
+    indices_in_kvcache = block_idx * page_size + offset
+
+    # the kernel requires invalid entry to be -1
+    assert indices_in_kvcache.shape == topk_indices.shape
+    indices_in_kvcache[topk_indices == -1] = -1
+
+    # return: (batch_size, seqlen_q_ori, topk)
+    indices_in_kvcache = indices_in_kvcache[:, None, :]
+    return indices_in_kvcache
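_compute_indices_in_kvcache maps each selected token rank to its physical slot in the paged KV cache: topk_indices // page_size picks the logical page, block_table translates it to a physical page, and the remainder is the in-page offset; -1 padding entries are carried through for the kernel. A small worked example with synthetic tensors (page_size=2 chosen only to keep the numbers readable):

import torch

block_table_row = torch.tensor([7, 3, 9], dtype=torch.int32)   # logical page -> physical page
topk_indices = torch.tensor([0, 3, 5, -1], dtype=torch.int32)  # token ranks, -1 = padding
page_size = 2

safe = topk_indices.clamp(min=0)
slots = block_table_row[(safe // page_size).long()] * page_size + safe % page_size
slots[topk_indices == -1] = -1
# token 0 -> page 7 slot 0 = 14; token 3 -> page 3 slot 1 = 7; token 5 -> page 9 slot 1 = 19
assert slots.tolist() == [14, 7, 19, -1]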
@@ -3,6 +3,7 @@ from typing import Optional, Union
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
         return backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
+
+    def get_indexer_metadata(
+        self, layer_id: int, forward_batch: ForwardBatch
+    ) -> Optional[BaseIndexerMetadata]:
+        backend = self._select_backend(forward_batch.forward_mode)
+        return backend.get_indexer_metadata(layer_id, forward_batch)
@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
         self.rotary_emb = rotary_emb
         self.layer_id = layer_id
         self.has_preprocess_weights = False
+        self.dtype = None
 
         self.q_lora_rank = self.q_b_proj.input_size  # 1536
         self.kv_lora_rank = self.kv_a_layernorm.hidden_size  # 512
         self.num_local_heads = num_local_heads  # tp
         self.qk_nope_head_dim = qk_nope_head_dim  # 128
         self.qk_rope_head_dim = qk_rope_head_dim  # 64
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
 
     def preprocess_weights(self, hidden_states):
         self.dummy = torch.empty(
@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
         slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
         return k_cache, v_cache, slot_mapping
 
-    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
+    def forward_absorb_prepare_npu_rms_norm_cache(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch,
+        zero_allocator,
+    ):
+        bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
+        self.dtype = hidden_states.dtype
+        self.cos, self.sin = self.get_sin_cos(positions)
+        self.kvCache, self.kvCacheRope, self.slotmapping = (
+            self.get_kv_cache_and_cache_idx(forward_batch)
+        )
+
+        if not self.has_preprocess_weights:
+            self.has_preprocess_weights = True
+
+        cos, sin = self.cos, self.sin
+
+        if self.q_lora_rank is not None:
+            fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
+            q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q = self.q_a_layernorm(q_lowrank)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )  # b*s,n,d
+        q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
+        q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
+
+        q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
+        cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
+        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
+        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin)  # (B,N,S,D)
+        q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
+
+        latent_cache = latent_cache.view(
+            -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+        )  # (B*S,N,1,D)
+        cache_mode = "PA_BNSD"
+        self.kvCache = self.kvCache.view(
+            -1,
+            forward_batch.attn_backend.page_size,
+            1,
+            forward_batch.attn_backend.kv_lora_rank,
+        )
+        self.kvCacheRope = self.kvCacheRope.view(
+            -1,
+            forward_batch.attn_backend.page_size,
+            1,
+            forward_batch.attn_backend.qk_rope_head_dim,
+        )
+        k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            latent_cache,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            self.slotmapping.to(torch.int64),
+            self.kvCacheRope,
+            self.kvCache,
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
+
+        return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
+
+    def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
         input_dtype = hidden_states.dtype
         if not self.has_preprocess_weights:
             self.preprocess_weights(hidden_states)
@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
             zero_allocator,
             positions,
         )
+
+    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
+        _is_w8a8 = (
+            hasattr(self.qkv_a_proj.quant_method, "quantization_config")
+            and self.qkv_a_proj.quant_method.quantization_config.get_name()
+            == "w8a8_int8"
+        )
+        if _is_w8a8:
+            return self.forward_mlapo(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        else:
+            return self.forward_absorb_prepare_npu_rms_norm_cache(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
from .topk import fast_topk, fast_topk_transform
__all__ = ["fast_topk", "fast_topk_transform"]
#include <ATen/core/TensorBase.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/python.h>
namespace {
constexpr int TopK = 2048;
constexpr int kThreadsPerBlock = 1024;
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t); // 128KB
struct FastTopKParams {
const float *__restrict__ input; // [B, input_stride]
int32_t *__restrict__ indices; // [B, TopK]
int32_t *__restrict__ lengths; // [B]
int64_t input_stride;
bool use_tilelang;
};
// when length <= TopK, we can directly write the indices
__device__ void naive_topk_cuda(const float *__restrict__ score,
int32_t *__restrict__ indice, int32_t length) {
const auto tid = threadIdx.x;
for (int i = tid; i < TopK; i += kThreadsPerBlock) {
indice[i] = (i < length) ? i : -1;
}
}
// keep the first `length` entries, set others to -1
__device__ void
naive_topk_transform(const float *__restrict__ score, int32_t length,
int32_t *__restrict__ dst_page_table,
const int32_t *__restrict__ src_page_table) {
const auto tid = threadIdx.x;
for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
}
}
__device__ __forceinline__ uint8_t convert_to_uint8(float x) {
__half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h);
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits & 0xFFFF)
: static_cast<uint16_t>(bits | 0x8000);
return static_cast<uint8_t>(key >> 8);
}
__device__ __forceinline__ uint32_t convert_to_uint32(float x) {
uint32_t bits = __float_as_uint(x);
return (bits & 0x80000000u) ? (~bits & 0xFFFFFFFFu) : (bits | 0x80000000u);
}
template <bool Is_Epilogue = false, typename Indexer, typename Loader,
int LENGTH, int MAX_REMAIN>
__device__ __forceinline__ auto
radix_topk(Indexer indexer, Loader loader, uint32_t length, int topk,
int *__restrict__ index, int &__restrict__ s_counter,
int (&__restrict__ s_histogram)[LENGTH],
int &__restrict__ s_remain_cnt,
int (&__restrict__ s_remain_idx)[MAX_REMAIN]) -> int {
constexpr auto RADIX = LENGTH - 1;
static_assert(RADIX > 1 && (RADIX & (RADIX - 1)) == 0,
"RADIX must be power of 2");
static_assert(RADIX <= kThreadsPerBlock);
__shared__ uint32_t s_threshold_bin_id;
const auto tx = threadIdx.x;
if (tx < RADIX + 1)
s_histogram[tx] = 0;
__syncthreads();
/// NOTE: Use uint32_t as the index
for (auto i = tx; i < length; i += kThreadsPerBlock) {
const auto idx = indexer(i);
const auto bin = loader(idx);
::atomicAdd(&s_histogram[bin], 1);
}
__syncthreads();
// cumsum (descending)
if (tx == 0) {
s_histogram[RADIX] = 0;
s_remain_cnt = 0;
for (int i = RADIX - 2; i >= 0; --i) {
s_histogram[i] += s_histogram[i + 1];
}
// threshold bin
for (int i = 0; i < RADIX; i++) {
if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
s_threshold_bin_id = i;
break;
}
}
}
__syncthreads();
const auto threshold_bin = s_threshold_bin_id;
const auto new_topk = topk - s_histogram[threshold_bin + 1];
for (auto i = tx; i < length; i += kThreadsPerBlock) {
const auto idx = indexer(i);
const auto bin_id = static_cast<uint32_t>(loader(idx));
if (bin_id > threshold_bin) {
index[::atomicAdd(&s_counter, 1)] = idx;
} else if (bin_id == threshold_bin && new_topk > 0) {
if constexpr (Is_Epilogue) {
index[::atomicAdd(&s_counter, 1)] = idx;
} else {
if (const auto cnt = ::atomicAdd(&s_remain_cnt, 1);
C10_LIKELY(cnt < MAX_REMAIN)) {
s_remain_idx[cnt] = idx;
}
}
}
}
__syncthreads();
return new_topk;
}
__device__ void fast_topk_cuda(const float *__restrict__ input,
int *__restrict__ index, int length,
int topk = TopK) {
constexpr auto RADIX = 256;
constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
__shared__ int s_histogram[RADIX + 1];
__shared__ int s_num_input[2];
__shared__ int s_counter;
// allocate for two rounds
extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
s_counter = 0;
// collect candidates
const auto indexer = [](int idx) { return idx; };
const auto loader = [&input](int idx) {
return convert_to_uint8(input[idx]);
};
int new_topk = radix_topk(indexer, loader, length, topk, index, s_counter,
s_histogram, s_num_input[0], s_input_idx[0]);
if (new_topk <= 0)
return;
// round 0
const auto indexer_0 = [](int idx) { return s_input_idx[0][idx]; };
const auto loader_0 = [&input](int idx) {
return (convert_to_uint32(input[idx]) >> 24) & 0xFF;
};
new_topk = radix_topk(indexer_0, loader_0, s_num_input[0], new_topk, index,
s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
if (new_topk <= 0)
return;
// round 1
const auto indexer_1 = [](int idx) { return s_input_idx[1][idx]; };
const auto loader_1 = [&input](int idx) {
return (convert_to_uint32(input[idx]) >> 16) & 0xFF;
};
new_topk = radix_topk(indexer_1, loader_1, s_num_input[1], new_topk, index,
s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
if (new_topk <= 0)
return;
// round 2
const auto loader_2 = [&input](int idx) {
return (convert_to_uint32(input[idx]) >> 8) & 0xFF;
};
new_topk = radix_topk(indexer_0, loader_2, s_num_input[0], new_topk, index,
s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
if (new_topk <= 0)
return;
// round 3
const auto loader_3 = [&input](int idx) {
return convert_to_uint32(input[idx]) & 0xFF;
};
// epilogue
radix_topk<true>(indexer_1, loader_3, s_num_input[1], new_topk, index,
s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
}
__device__ void fast_topk_cuda_tl(const float *__restrict__ input,
int *__restrict__ index, int length,
int topk = TopK) {
constexpr auto BLOCK_SIZE = 1024;
constexpr auto RADIX = 256;
constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
__shared__ int s_threshold_bin_id;
__shared__ int s_histogram[RADIX + 1];
__shared__ int s_num_input[2];
__shared__ int s_counter;
// allocate for two rounds
extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
int tx = threadIdx.x;
// stage 1: 8bit coarse histogram
if (tx < RADIX + 1)
s_histogram[tx] = 0;
__syncthreads();
for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
const auto bin = convert_to_uint8(input[idx]);
::atomicAdd(&s_histogram[bin], 1);
}
__syncthreads();
// cumsum (descending)
if (tx == 0) {
for (int i = RADIX - 2; i >= 0; --i) {
s_histogram[i] += s_histogram[i + 1];
}
// threshold bin
for (int i = 0; i < RADIX; i++) {
if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
s_threshold_bin_id = i;
break;
}
}
s_num_input[0] = 0;
s_counter = 0;
}
__syncthreads();
int threshold_bin = s_threshold_bin_id;
int new_topk = topk - s_histogram[threshold_bin + 1];
// collect candidates
for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
const auto bin_id = static_cast<int>(convert_to_uint8(input[idx]));
if (bin_id > threshold_bin) {
int pos = ::atomicAdd(&s_counter, 1);
index[pos] = idx;
} else if (bin_id == threshold_bin && new_topk > 0) {
int pos = ::atomicAdd(&s_num_input[0], 1);
if (pos < SMEM_INPUT_SIZE) {
[[likely]] s_input_idx[0][pos] = idx;
}
}
}
__syncthreads();
// stage 2: refine with 8bit radix passes
#pragma unroll 4
for (int round = 0; round < 4; ++round) {
if (new_topk <= 0)
break;
int r_idx = round % 2;
// reset
if (tx < RADIX + 1)
s_histogram[tx] = 0;
__syncthreads();
int num_input = s_num_input[r_idx];
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
int idx = s_input_idx[r_idx][i];
uint32_t bin32 =
(convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
::atomicAdd(&s_histogram[bin32], 1);
}
__syncthreads();
if (tx == 0) {
for (int i = RADIX - 2; i >= 0; --i)
s_histogram[i] += s_histogram[i + 1];
for (int i = 0; i < RADIX; i++) {
if (s_histogram[i] >= new_topk && s_histogram[i + 1] < new_topk) {
s_threshold_bin_id = i;
break;
}
}
s_num_input[r_idx ^ 1] = 0;
}
__syncthreads();
new_topk -= s_histogram[s_threshold_bin_id + 1];
int threshold_bin = s_threshold_bin_id;
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
int idx = s_input_idx[r_idx][i];
uint32_t bin32 =
(convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
if (bin32 > threshold_bin) {
int pos = ::atomicAdd(&s_counter, 1);
index[pos] = idx;
} else if (bin32 == threshold_bin && new_topk > 0) {
if (round == 3) {
int pos = ::atomicAdd(&s_counter, 1);
index[pos] = idx;
} else {
int pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
if (pos < SMEM_INPUT_SIZE)
s_input_idx[r_idx ^ 1][pos] = idx;
}
}
}
__syncthreads();
}
}
__global__ void topk_kernel(const FastTopKParams params) {
const auto &[input, indices, lengths, input_stride, use_tilelang] = params;
const auto bid = blockIdx.x;
const auto length = *(lengths + bid);
const auto indice = indices + bid * TopK;
const auto score = input + bid * input_stride;
if (length <= TopK) {
return naive_topk_cuda(score, indice, length);
} else {
if (use_tilelang) {
return fast_topk_cuda_tl(score, indice, length);
} else {
return fast_topk_cuda(score, indice, length);
}
}
}
__global__ void topk_kernel_transform_decode( // decode
const FastTopKParams params, int32_t *__restrict__ dst_page_table,
const int32_t *__restrict__ src_page_table, const int64_t src_stride) {
const auto &[input, _, lengths, input_stride, use_tilelang] = params;
const auto bid = blockIdx.x;
const auto tid = threadIdx.x;
const auto length = *(lengths + bid);
const auto src_page_entry = src_page_table + bid * src_stride;
const auto dst_page_entry = dst_page_table + bid * TopK;
const auto score = input + bid * input_stride;
if (length <= TopK) {
return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
} else {
__shared__ int s_indices[TopK];
if (use_tilelang) {
fast_topk_cuda_tl(score, s_indices, length);
} else {
fast_topk_cuda(score, s_indices, length);
}
// copy src[s_indices] to dst, we manually unroll here
static_assert(TopK % kThreadsPerBlock == 0);
static_assert(TopK / kThreadsPerBlock == 2);
const auto idx_0 = tid;
const auto pos_0 = s_indices[idx_0];
dst_page_entry[idx_0] = src_page_entry[pos_0];
const auto idx_1 = tid + kThreadsPerBlock;
const auto pos_1 = s_indices[idx_1];
dst_page_entry[idx_1] = src_page_entry[pos_1];
}
}
__global__ void topk_kernel_transform_prefill( // prefill
const FastTopKParams params, int32_t *__restrict__ dst_page_table,
const int32_t *__restrict__ src_page_table, const int64_t src_stride,
const int32_t *__restrict__ cu_seqlens, const int64_t prefill_bs) {
const auto &[input, _, lengths, input_stride, use_tilelang] = params;
const auto bid = blockIdx.x;
const auto tid = threadIdx.x;
const auto length = *(lengths + bid);
const auto dst_page_entry = dst_page_table + bid * TopK;
const auto score = input + bid * input_stride;
/// NOTE: prefill bs is usually small, we can just use a simple loop here
/// We ensure that last cu_seqlens is equal to number of blocks launched
assert(gridDim.x == cu_seqlens[prefill_bs] &&
"Invalid cu_seqlens in topk-transform-prefill");
__shared__ const int32_t *s_src_page_entry;
if (tid == 0) {
for (int64_t offset = 0; offset < prefill_bs; ++offset) {
if (bid < cu_seqlens[offset + 1]) {
s_src_page_entry = src_page_table + offset * src_stride;
break;
}
}
}
__syncthreads();
const auto src_page_entry = s_src_page_entry;
if (length <= TopK) {
return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
} else {
__shared__ int s_indices[TopK];
if (use_tilelang) {
fast_topk_cuda_tl(score, s_indices, length);
} else {
fast_topk_cuda(score, s_indices, length);
}
// copy src[s_indices] to dst, we manually unroll here
static_assert(TopK % kThreadsPerBlock == 0);
static_assert(TopK / kThreadsPerBlock == 2);
const auto idx_0 = tid;
const auto pos_0 = s_indices[idx_0];
dst_page_entry[idx_0] = src_page_entry[pos_0];
const auto idx_1 = tid + kThreadsPerBlock;
const auto pos_1 = s_indices[idx_1];
dst_page_entry[idx_1] = src_page_entry[pos_1];
}
}
auto get_params(at::Tensor score, at::Tensor lengths, bool use_tilelang,
std::optional<at::Tensor> indices_opt = std::nullopt)
-> FastTopKParams {
const auto B = score.size(0);
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
TORCH_CHECK(lengths.size(0) == B);
int32_t *indices_data_ptr = nullptr;
if (indices_opt.has_value()) {
const auto &indices = indices_opt.value();
TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
TORCH_CHECK(indices.size(0) == B);
TORCH_CHECK(indices.size(1) == TopK);
indices_data_ptr = indices.data_ptr<int32_t>();
}
return FastTopKParams{
.input = score.data_ptr<float>(),
.indices = indices_data_ptr,
.lengths = lengths.data_ptr<int32_t>(),
.input_stride = score.stride(0),
.use_tilelang = use_tilelang,
};
}
template <auto *f, size_t max_dynamic_smem>
auto setup_kernel_smem_once() -> void {
[[maybe_unused]]
static const auto result = [] {
return ::cudaFuncSetAttribute(
f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
}();
TORCH_CHECK(result == cudaSuccess,
"set_up_kernel_once failed:", ::cudaGetErrorString(result));
}
auto fast_topk_interface(at::Tensor score, at::Tensor indices,
at::Tensor lengths, bool use_tilelang) -> void {
const auto params = get_params(score, lengths, use_tilelang, indices);
const auto B = score.size(0);
const auto stream = at::cuda::getCurrentCUDAStream().stream();
const auto grid = dim3{static_cast<uint32_t>(B)};
const auto block = dim3{kThreadsPerBlock};
setup_kernel_smem_once<topk_kernel, kSmem>();
topk_kernel<<<grid, block, kSmem, stream>>>(params);
const auto result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess,
"topk kernel failed:", ::cudaGetErrorString(result));
}
auto fast_topk_transform_interface(at::Tensor score, at::Tensor lengths,
at::Tensor dst_page_table,
at::Tensor src_page_table,
at::Tensor cu_seqlens,
bool use_tilelang) -> void {
const auto params = get_params(score, lengths, use_tilelang);
const auto B = score.size(0);
TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
TORCH_CHECK(cu_seqlens.dim() == 1 && cu_seqlens.is_contiguous());
const auto prefill_bs = cu_seqlens.size(0) - 1;
TORCH_CHECK(dst_page_table.size(0) == B);
TORCH_CHECK(dst_page_table.size(1) == TopK);
TORCH_CHECK(src_page_table.size(0) == prefill_bs);
TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs
// launch kernel
const auto stream = at::cuda::getCurrentCUDAStream().stream();
const auto grid = dim3{static_cast<uint32_t>(B)};
const auto block = dim3{kThreadsPerBlock};
const auto src_stride = src_page_table.stride(0);
// dispatch to decode or prefill
const auto is_decode = (prefill_bs == B);
if (is_decode) {
setup_kernel_smem_once<topk_kernel_transform_decode, kSmem>();
topk_kernel_transform_decode<<<grid, block, kSmem, stream>>>(
params, dst_page_table.data_ptr<int32_t>(),
src_page_table.data_ptr<int32_t>(), src_stride);
} else {
setup_kernel_smem_once<topk_kernel_transform_prefill, kSmem>();
topk_kernel_transform_prefill<<<grid, block, kSmem, stream>>>(
params, dst_page_table.data_ptr<int32_t>(),
src_page_table.data_ptr<int32_t>(), src_stride,
cu_seqlens.data_ptr<int32_t>(), prefill_bs);
}
const auto result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess,
"topk kernel failed:", ::cudaGetErrorString(result));
}
} // namespace
PYBIND11_MODULE(topk_kernel, m) {
m.def("fast_topk", &fast_topk_interface);
m.def("fast_topk_transform", &fast_topk_transform_interface);
}
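A note on the key encoding used by the kernel above: convert_to_uint8 / convert_to_uint32 apply the standard order-preserving float-to-unsigned mapping (flip all bits of negative values, set the sign bit of non-negative ones), so comparing keys as unsigned integers, most-significant byte first, agrees with comparing the original float scores. A quick self-contained check in Python (illustrative only, not part of the extension):

import struct

def key_u32(x: float) -> int:
    # Same mapping as convert_to_uint32 in topk.cu: negative -> ~bits,
    # non-negative -> bits | 0x80000000; unsigned order then matches float order.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (~bits) & 0xFFFFFFFF if bits & 0x80000000 else bits | 0x80000000

vals = [-3.5, -0.0, 0.0, 1e-8, 2.0, 1000.0]
keys = [key_u32(v) for v in vals]
assert keys == sorted(keys)  # larger score -> larger key, as the radix passes assume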
from __future__ import annotations
from typing import Any
import torch
from .utils import load_kernel_module
def _load_topk_module() -> Any:
"""
Load the index manipulation module.
"""
return load_kernel_module("topk.cu", "topk_kernel")
# TODO(dark): figure out why my cuda impl is a little slower...
# I believe it has something to do with unrolling loops (?)
_USE_TL = True
def fast_topk(
score: torch.Tensor,
indices: torch.Tensor,
lengths: torch.Tensor,
) -> torch.Tensor:
return _load_topk_module().fast_topk(score, indices, lengths, _USE_TL)
def fast_topk_transform(
score: torch.Tensor,
lengths: torch.Tensor,
dst_page_table: torch.Tensor,
src_page_table: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
return _load_topk_module().fast_topk_transform(
score, lengths, dst_page_table, src_page_table, cu_seqlens, _USE_TL
)
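For reference, the kernels behind these wrappers select at most 2048 entries per row (TopK in topk.cu) and expect pre-allocated int32 outputs. A minimal call sketch; the tensor shapes follow the TORCH_CHECKs in topk.cu, while the import path and the batch/length numbers are assumptions:

import torch

from sglang.srt.layers.attention.nsa.topk import fast_topk  # assumed package location

B, max_len, TOPK = 4, 4096, 2048   # TopK is hard-coded to 2048 in topk.cu
score = torch.randn(B, max_len, dtype=torch.float32, device="cuda")
lengths = torch.full((B,), max_len, dtype=torch.int32, device="cuda")
indices = torch.empty(B, TOPK, dtype=torch.int32, device="cuda")

fast_topk(score, indices, lengths)  # fills `indices`; rows shorter than TopK are padded with -1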
from __future__ import annotations
import os
from functools import lru_cache
from typing import Any, Iterable
@lru_cache()
def _prepare_for_load() -> str:
import os
import warnings
warnings.filterwarnings(
"ignore", category=UserWarning, module="torch.utils.cpp_extension"
)
return os.path.dirname(os.path.abspath(__file__))
@lru_cache()
def load_kernel_module(
path: str | Iterable[str],
name: str,
*,
build: str = "build",
cflags: Iterable[str] | None = None,
cuda_flags: Iterable[str] | None = None,
ldflags: Iterable[str] | None = None,
) -> Any:
from torch.utils.cpp_extension import load
if isinstance(path, str):
path = (path,)
abs_path = _prepare_for_load()
build_dir = f"{abs_path}/{build}"
os.makedirs(build_dir, exist_ok=True)
return load(
name=name,
sources=[f"{abs_path}/csrc/{p}" for p in path],
extra_cflags=list(cflags or []) or ["-O3", "-std=c++17"],
extra_cuda_cflags=list(cuda_flags or []) or ["-O3", "-std=c++17"],
extra_ldflags=list(ldflags or []) or None,
build_directory=build_dir,
)
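load_kernel_module is a thin wrapper over torch.utils.cpp_extension.load that resolves sources under ./csrc and caches build artifacts under ./build next to this file. One caveat worth noting: because the function is memoized with lru_cache, any flag overrides have to be hashable, i.e. tuples rather than lists. A small usage sketch (the extra nvcc flag is only an example):

# Build the same topk extension as topk.py does, but with an extra nvcc flag.
# Tuples are required because load_kernel_module is wrapped in lru_cache.
module = load_kernel_module(
    "topk.cu",
    "topk_kernel",
    cuda_flags=("-O3", "-std=c++17", "--use_fast_math"),
)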
import torch
import triton
import triton.language as tl
from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST
def dequantize_k_cache(quant_k_cache):
if NSA_DEQUANT_K_CACHE_FAST:
return _dequantize_k_cache_fast_wrapped(quant_k_cache)
else:
return _dequantize_k_cache_slow(quant_k_cache)
def _dequantize_k_cache_slow(
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
dv: int = 512,
tile_size: int = 128,
d: int = 576,
) -> torch.Tensor:
"""
De-quantize the k-cache
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1
result = torch.empty(
(num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
)
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_nope * cur_scales
)
result = result.view(num_blocks, block_size, 1, d)
return result
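The slow path above documents the per-token layout of the quantized MLA k-cache: 512 fp8 "nope" values, then one fp32 scale per 128-value tile (4 tiles), then the 64-dim rope part kept in bf16. In bytes that is 512 + 4*4 + 64*2 = 656, which is the dim_quant asserted in the fast path below; a quick sketch of the split (pure arithmetic, no kernel code):

# Per-token byte layout of the quantized k-cache entry (dv=512, tile=128, rope=64):
DV, TILE, D_ROPE = 512, 128, 64
NUM_TILES = DV // TILE                 # 4 scale values
nope_bytes = DV * 1                    # fp8 -> 1 byte per element
scale_bytes = NUM_TILES * 4            # fp32 scales
rope_bytes = D_ROPE * 2                # bf16 rope part
assert nope_bytes + scale_bytes + rope_bytes == 656  # matches dim_quant in the fast path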
def _dequantize_k_cache_fast_wrapped(
quant_k_cache: torch.Tensor,
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
# TODO the final API may be 2D instead of 4D, thus we convert them here
num_blocks, block_size, _, dim_quant = quant_k_cache.shape
assert dv == 512
assert dim_quant == 656
assert tile_size == 128
quant_k_cache = quant_k_cache.view((-1, dim_quant))
output = _dequantize_k_cache_fast(quant_k_cache)
return output.view(num_blocks, block_size, 1, -1)
def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
num_tokens, dim_quant = quant_k_cache.shape
assert quant_k_cache.dtype == torch.float8_e4m3fn
dim_nope = 512
dim_rope = 64
num_tiles = dim_nope // group_size
assert dim_quant == 656
output = torch.empty(
(num_tokens, dim_nope + dim_rope),
dtype=torch.bfloat16,
device=quant_k_cache.device,
)
num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
assert num_blocks_per_token == 5
assert dim_nope % group_size == 0
NUM_NOPE_BLOCKS = dim_nope // group_size
input_nope_q = quant_k_cache[:, :dim_nope]
input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
torch.float32
)
input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)
_dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
output,
input_nope_q,
input_nope_s,
input_rope,
output.stride(0),
input_nope_q.stride(0),
input_nope_s.stride(0),
input_rope.stride(0),
NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
GROUP_SIZE=group_size,
DIM_NOPE=dim_nope,
DIM_ROPE=dim_rope,
)
return output
@triton.jit
def _dequantize_k_cache_fast_kernel(
output_ptr,
input_nope_q_ptr,
input_nope_s_ptr,
input_rope_ptr,
output_stride_0: int,
input_nope_q_stride_0: int,
input_nope_s_stride_0: int,
input_rope_stride_0: int,
NUM_NOPE_BLOCKS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
DIM_NOPE: tl.constexpr,
DIM_ROPE: tl.constexpr,
):
token_id = tl.program_id(0)
raw_block_id = tl.program_id(1)
if raw_block_id < NUM_NOPE_BLOCKS:
# a. dequant nope
effective_block_id = raw_block_id
offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs_q < DIM_NOPE
ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q
ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id
y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
y_s = tl.load(ptr_s)
y = (y_q * y_s).to(output_ptr.dtype.element_ty)
dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
tl.store(dst_ptr, y, mask=mask)
else:
# b. copy rope
effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_ROPE
src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs
dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs
data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
tl.store(dst_ptr, data, mask=mask)
if __name__ == "__main__":
raise Exception("UT is in quant_k_cache.py")
from typing import TYPE_CHECKING
import torch
import triton
import triton.language as tl
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
"""
k: data, 128 items per token, fp8
s: scale, 1 item per token, fp32
"""
class GetK:
@classmethod
def execute(cls, *args, **kwargs):
return cls.torch_fast(*args, **kwargs)
@classmethod
def slow(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
num_pages = (seq_len + pool.page_size - 1) // pool.page_size
seq_len_ = num_pages * pool.page_size
index_k_fp8 = torch.empty(
(seq_len_, pool.index_head_dim),
dtype=torch.uint8,
device=pool.device,
)
for i in range(num_pages):
page_index = page_indices[i]
index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
page_index
][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)
return index_k_fp8[:seq_len]
@classmethod
def torch_fast(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
"""
:param page_indices: (num_pages,), int32
:return: (seq_len, index_head_dim), uint8
"""
# can handle per 128B instead of per element
# page_indices: (num_pages,), element := a page index
buf_numel_per_page = buf.shape[1]
num_k_bytes_per_page = pool.page_size * pool.index_head_dim
num_k_bytes_per_token = pool.index_head_dim
# buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
# flat_buf: (whatever,), uint8
flat_buf = buf.flatten()
# flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access
flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
num_k_bytes_per_page, dtype=torch.int32, device="cuda"
)[None, :]
flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]
out = flat_buf[flat_indices]
return out.view(-1, 128)
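The torch_fast path above avoids a Python loop by building flat byte indices into buf: each selected page contributes a contiguous run of page_size * index_head_dim k-bytes starting at page_index * buf_numel_per_page, and the flattened result is trimmed to seq_len tokens. A toy-sized version of the same index construction (2 tokens per page and 4-byte heads, so the result can be checked by eye):

import torch

page_size, head_dim = 2, 4                                   # toy sizes (real: 64 and 128)
buf_numel_per_page = page_size * head_dim + page_size * 4    # k bytes + fp32 scale bytes = 16
page_indices = torch.tensor([2, 0], dtype=torch.int32)       # logical page order
seq_len = 3

flat = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
    page_size * head_dim, dtype=torch.int32
)[None, :]
flat = flat.flatten()[: seq_len * head_dim]
# page 2 starts at byte 32, page 0 at byte 0; only each page's k-byte region is touched
assert flat.tolist() == [32, 33, 34, 35, 36, 37, 38, 39, 0, 1, 2, 3]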
class GetS:
@classmethod
def execute(cls, *args, **kwargs):
return cls.torch_fast(*args, **kwargs)
@classmethod
def slow(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
num_pages = (seq_len + pool.page_size - 1) // pool.page_size
seq_len_ = num_pages * pool.page_size
assert pool.index_head_dim // pool.quant_block_size == 1
index_k_scale_fp8 = torch.empty(
(seq_len_, 4),
dtype=torch.uint8,
device=pool.device,
)
for i in range(num_pages):
page_index = page_indices[i]
index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
page_index
][pool.page_size * pool.index_head_dim :].view(-1, 4)
return index_k_scale_fp8[:seq_len]
@classmethod
def torch_fast(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
"""
:param page_indices: (num_pages,), int32
:return: (seq_len, index_head_dim // quant_block_size), uint8
"""
buf_numel_per_page = buf.shape[1]
num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
s_offset_in_page = pool.page_size * pool.index_head_dim
flat_buf = buf.flatten()
flat_indices = (
(page_indices * buf_numel_per_page)[:, None]
+ torch.arange(num_s_bytes_per_page, dtype=torch.int32, device="cuda")[
None, :
]
+ s_offset_in_page
)
flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]
out = flat_buf[flat_indices]
return out.view(-1, 4)
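GetS reads the trailing scale region of each page; together with GetK above, the accessors assume every page packs all fp8 k bytes first and then all fp32 scales for the same tokens. With the page_size of 64 and index_head_dim of 128 quoted in this file's comments, the per-page byte budget works out as:

# Per-page byte layout assumed by GetK/GetS/SetK/SetS
# (page_size=64 and index_head_dim=128 are the values noted in this file's comments).
page_size, index_head_dim = 64, 128
k_bytes_per_page = page_size * index_head_dim   # fp8 k data, 1 byte per element
s_bytes_per_page = page_size * 4                # one fp32 scale per token
buf_numel_per_page = k_bytes_per_page + s_bytes_per_page   # == buf.shape[1]
s_offset_in_page = k_bytes_per_page             # scales start right after the k bytes
assert (buf_numel_per_page, s_offset_in_page) == (8448, 8192)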
class SetK:
@classmethod
def execute(cls, *args, buf, **kwargs):
return cls.torch_fast(*args, **kwargs, buf=buf)
@classmethod
def slow(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k: torch.Tensor,
):
for i in range(len(loc)):
page_index = loc[i] // pool.page_size
offset = loc[i] % pool.page_size
buf[
page_index,
offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
] = index_k[i].view(torch.uint8)
@classmethod
def torch_fast(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k: torch.Tensor,
):
(num_tokens_to_write,) = loc.shape
buf_numel_per_page = buf.shape[1]
num_k_bytes_per_token = pool.index_head_dim
# loc: (num_tokens_to_write,), int32, element := the token index to write to
loc_page_index = loc // pool.page_size
loc_token_offset_in_page = loc % pool.page_size
flat_buf = buf.flatten()
flat_indices = (
(loc_page_index * buf_numel_per_page)[:, None]
+ (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
+ torch.arange(num_k_bytes_per_token, dtype=torch.int32, device="cuda")[
None, :
]
)
num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
flat_indices = flat_indices.flatten()[:num_k_bytes_total]
flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()
class SetS:
@classmethod
def execute(cls, *args, buf, **kwargs):
return cls.torch_fast(*args, **kwargs, buf=buf)
@classmethod
def slow(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k_scale: torch.Tensor,
):
for i in range(len(loc)):
page_index = loc[i] // pool.page_size
offset = loc[i] % pool.page_size
start = pool.page_size * pool.index_head_dim
buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
index_k_scale[i].view(torch.uint8)
)
@classmethod
def torch_fast(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k_scale: torch.Tensor,
):
(num_tokens_to_write,) = loc.shape
buf_numel_per_page = buf.shape[1]
num_s_bytes_per_token = 4
s_offset_in_page = pool.page_size * pool.index_head_dim
# loc: (num_tokens_to_write,), int32, element := the token index to write to
loc_page_index = loc // pool.page_size
loc_token_offset_in_page = loc % pool.page_size
flat_buf = buf.flatten()
flat_indices = (
(loc_page_index * buf_numel_per_page)[:, None]
+ s_offset_in_page
+ (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
+ torch.arange(num_s_bytes_per_token, dtype=torch.int32, device="cuda")[
None, :
]
)
number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
flat_indices = flat_indices.flatten()[:number_s_bytes_total]
flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()
class SetKAndS:
@classmethod
def execute(cls, *args, buf, **kwargs):
        if False:  # flip to True to cross-check the triton path against the vanilla path
# print("SetK, SetS comparison test")
buf_cloned = buf.clone()
cls.vanilla(*args, **kwargs, buf=buf)
cls.triton(*args, **kwargs, buf=buf_cloned)
def _clear_token_0(target):
target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0
_clear_token_0(buf)
_clear_token_0(buf_cloned)
assert torch.all(
buf == buf_cloned
            ), f"{buf=} {buf_cloned=} {kwargs['loc'].tolist()=}"
return
cls.triton(*args, **kwargs, buf=buf)
@classmethod
def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)
@classmethod
def triton(cls, pool, buf, loc, index_k, index_k_scale):
_set_k_and_s_triton(
buf=buf,
loc=loc,
index_k=index_k,
index_k_scale=index_k_scale,
page_size=pool.page_size,
)
def _set_k_and_s_triton(
buf: torch.Tensor,
loc: torch.Tensor,
index_k: torch.Tensor,
index_k_scale: torch.Tensor,
page_size: int,
):
"""
:param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
:param loc: (num_tokens_to_write,), int, element := the token index to write to
:param index_k: (num_tokens_to_write, 128 elem), fp8
:param index_k_scale: (num_tokens_to_write, 1 elem), fp32
:return:
"""
num_pages, buf_numel_per_page = buf.shape
(num_tokens_to_write,) = loc.shape
num_tokens_to_write_, index_head_dim = index_k.shape
num_tokens_to_write__, scale_dim = index_k_scale.shape
assert buf_numel_per_page == 64 * (128 + 4)
assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
assert index_head_dim == 128
assert scale_dim == 1
assert page_size == 64
assert buf.dtype == torch.uint8
assert loc.dtype == torch.int64, f"{loc.dtype=}" # can be int32
assert index_k.dtype == torch.float8_e4m3fn
assert index_k_scale.dtype == torch.float32
assert buf.is_contiguous()
assert loc.is_contiguous()
assert index_k.is_contiguous()
assert index_k_scale.is_contiguous()
buf_fp8 = buf.view(torch.float8_e4m3fn)
buf_fp32 = buf.view(torch.float32)
_set_k_and_s_triton_kernel[(num_tokens_to_write,)](
buf_fp8,
buf_fp32,
loc,
index_k,
index_k_scale,
index_k.stride(0),
PAGE_SIZE=page_size,
BUF_NUMEL_PER_PAGE=buf_numel_per_page,
NUM_K_ELEMS_PER_TOKEN=index_head_dim,
S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
)
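# The Triton kernel below launches one program per token to write: each program
# loads that token's 128 fp8 k elements plus its single fp32 scale, and stores
# them into the token's slot in the page (k bytes at the front of the page,
# scales packed after the 64 * 128 k region).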
@triton.jit
def _set_k_and_s_triton_kernel(
buf_fp8_ptr,
buf_fp32_ptr,
loc_ptr,
index_k_ptr,
index_k_scale_ptr,
index_k_ptr_stride_0,
PAGE_SIZE: tl.constexpr,
BUF_NUMEL_PER_PAGE: tl.constexpr,
NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
):
token_id = tl.program_id(0)
loc = tl.load(loc_ptr + token_id)
in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
    # no need for `mask`: we read exactly 128 B of k and 4 B of scale per token, both powers of 2
k = tl.load(index_k_ptr + in_k_offsets)
k_scale = tl.load(index_k_scale_ptr + token_id)
loc_page_index = loc // PAGE_SIZE
loc_token_offset_in_page = loc % PAGE_SIZE
out_k_offsets = (
loc_page_index * BUF_NUMEL_PER_PAGE
+ loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
+ tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
)
# "//4" b/c it is fp32 instead of uint8
out_s_offset = (
loc_page_index * BUF_NUMEL_PER_PAGE // 4
+ S_OFFSET_NBYTES_IN_PAGE // 4
+ loc_token_offset_in_page
)
tl.store(buf_fp8_ptr + out_k_offsets, k)
tl.store(buf_fp32_ptr + out_s_offset, k_scale)
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.utils import add_prefix, is_npu
if not is_npu():
from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
import deep_gemm
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import add_prefix, align, is_cuda
try:
    import deep_gemm_v32
except ImportError:
    print("Failed to import deep_gemm_v32, falling back to deep_gemm")
    try:
        import deep_gemm as deep_gemm_v32
    except ImportError:
        print("Failed to import deep_gemm as well; deep_gemm_v32 will be unavailable")
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0
class BaseIndexerMetadata(ABC):
@abstractmethod
def get_seqlens_int32(self) -> torch.Tensor:
"""
Return: (batch_size,) int32 tensor
"""
@abstractmethod
def get_page_table_64(self) -> torch.Tensor:
"""
Return: (batch_size, num_blocks) int32, page table.
The page size of the table is 64.
"""
@abstractmethod
def get_seqlens_expanded(self) -> torch.Tensor:
"""
Return: (sum_extend_seq_len,) int32 tensor
"""
@abstractmethod
def topk_transform(
self,
logits: torch.Tensor,
topk: int,
) -> torch.Tensor:
"""
Perform topk selection on the logits and possibly transform the result.
NOTE that attention backend may override this function to do some
transformation, which means the result of this topk_transform may not
be the topk indices of the input logits.
Return: Anything, since it will be passed to the attention backend
for further processing on sparse attention computation.
Don't assume it is the topk indices of the input logits.
"""
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.bfloat16
from fast_hadamard_transform import hadamard_transform
hidden_size = x.size(-1)
assert (
hidden_size & (hidden_size - 1)
) == 0, "Hidden size must be a power of 2 for Hadamard transform."
return hadamard_transform(x, scale=hidden_size**-0.5)
class V32LayerNorm(nn.Module):
"""
Layer Normalization.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor):
return F.layer_norm(
x.float(), (self.dim,), self.weight, self.bias, self.eps
).type_as(x)
class Indexer(CustomOp):
def __init__(
self,
hidden_size: int,
index_n_heads: int,
index_head_dim: int,
rope_head_dim: int,
index_topk: int,
q_lora_rank: int,
max_position_embeddings: int,
rope_theta: float,
layer_id: int,
scale_fmt: Optional[str],
block_size: int = 128,
rope_scaling: Optional[Dict[str, Any]] = None,
prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
self.hidden_size = hidden_size
self.n_heads = index_n_heads
self.head_dim = index_head_dim
self.rope_head_dim = rope_head_dim
self.index_topk = index_topk
self.q_lora_rank = q_lora_rank
self.layer_id = layer_id
self.alt_stream = alt_stream
if not is_npu():
self.sm_count = deep_gemm.get_num_sms()
self.half_device_sm_count = align(self.sm_count // 2, 8)
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("wq_b", prefix),
)
self.wk = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("wk", prefix),
)
self.k_norm = V32LayerNorm(self.head_dim)
# NOTE: weight_proj is not quantized
self.weights_proj = ReplicatedLinear(
self.hidden_size,
self.n_heads,
bias=False,
prefix=add_prefix("weights_proj", prefix),
)
self.rotary_emb = get_rope_wrapper(
rope_head_dim,
rotary_dim=rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta, # type: ignore
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
)
self.block_size = block_size
self.scale_fmt = scale_fmt
self.softmax_scale = self.head_dim**-0.5
def _forward_fake(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
):
bs = x.shape[0]
assert self.index_topk == 2048
ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
None, ...
].repeat(bs, 1)
if forward_batch.forward_mode.is_extend():
assert (
forward_batch.extend_seq_lens_cpu is not None
and forward_batch.seq_lens_cpu is not None
)
which = 0
for i, (kv_len, qo_len) in enumerate(
zip(
forward_batch.seq_lens_cpu.tolist(),
forward_batch.extend_seq_lens_cpu,
strict=True,
)
):
for j in range(kv_len - qo_len, kv_len):
ans[which, j + 1 :] = -1
which += 1
assert which == ans.shape[0]
else:
assert forward_batch.seq_lens_cpu is not None
for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
ans[i, seq_len:] = -1
return ans
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
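        # Combine the learned per-head gate with 1/sqrt(n_heads), the per-token q
        # quantization scale, and the softmax scale into a single
        # (num_tokens, n_heads, 1) multiplier consumed by the logits kernels.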
weights, _ = self.weights_proj(x)
weights = weights * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
def _get_q_k_bf16(
self,
q_lora: torch.Tensor,
x: torch.Tensor,
positions: torch.Tensor,
enable_dual_stream: bool,
):
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
self.half_device_sm_count
):
query, _ = self.wq_b(q_lora)
query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
q_rope, _ = torch.split(
query,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
)
with torch.cuda.stream(self.alt_stream):
key, _ = self.wk(x)
key = self.k_norm(key)
k_rope, _ = torch.split(
key,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
)
current_stream.wait_stream(self.alt_stream)
else:
query, _ = self.wq_b(q_lora)
if dumper._enable:
after_wq_b = query.clone()
query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
q_rope, _ = torch.split(
query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
)
key, _ = self.wk(x)
if dumper._enable:
after_wk = key.clone()
key = self.k_norm(key)
if dumper._enable:
after_k_norm = key.clone()
k_rope, _ = torch.split(
key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
)
q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
query[..., : self.rope_head_dim] = q_rope
key[..., : self.rope_head_dim] = k_rope
if dumper._enable:
q_before_hadamard = query.clone()
k_before_hadamard = key.clone()
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
query = rotate_activation(query)
with torch.cuda.stream(self.alt_stream):
key = rotate_activation(key)
current_stream.wait_stream(self.alt_stream)
else:
query = rotate_activation(query)
key = rotate_activation(key)
return query, key
def _get_topk_paged(
self,
forward_batch: ForwardBatch,
layer_id: int,
q_fp8: torch.Tensor,
weights: torch.Tensor,
metadata: BaseIndexerMetadata,
) -> torch.Tensor:
if TYPE_CHECKING:
assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
page_size = forward_batch.token_to_kv_pool.page_size
# NOTE(dark): blocksize = 64 is hardcoded in deep_gemm_v32
assert page_size == 64, "only support page size 64"
# NOTE(dark): this support extend/decode/decode+graph
block_tables = metadata.get_page_table_64()
max_seq_len = block_tables.shape[1] * page_size
kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
layer_id=layer_id
)
blocksize = page_size
seqlens_32 = metadata.get_seqlens_int32()
        # NOTE(dark): pass the device SM count (132 on H200/B200) rather than a magic number
schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
seqlens_32, blocksize, self.sm_count
)
assert len(q_fp8.shape) == 3
q_fp8 = q_fp8.unsqueeze(1) # the next_n dim is 1 now
assert len(kv_cache_fp8.shape) == 2
block_kv = 64
num_heads_kv = 1
head_dim_with_sf = 132
kv_cache_fp8 = kv_cache_fp8.view(
kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
)
assert len(weights.shape) == 3
weights = weights.squeeze(2)
logits = deep_gemm_v32.fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
weights,
seqlens_32,
block_tables,
schedule_metadata,
max_seq_len,
clean_logits=False,
)
# NOTE(dark): logits should be cleaned in topk_transform
topk_result = metadata.topk_transform(logits, self.index_topk)
return topk_result
def _get_topk_ragged(
self,
forward_batch: ForwardBatch,
layer_id: int,
q_fp8: torch.Tensor,
weights: torch.Tensor,
metadata: BaseIndexerMetadata,
) -> torch.Tensor:
if TYPE_CHECKING:
assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
page_size = forward_batch.token_to_kv_pool.page_size
assert page_size == 64, "only support page size 64"
assert len(weights.shape) == 3
weights = weights.squeeze(-1)
k_fp8_list = []
k_scale_list = []
ks_list = []
offset = 0
block_tables = metadata.get_page_table_64()
assert (
forward_batch.seq_lens_cpu is not None
and forward_batch.extend_seq_lens_cpu is not None
)
for i in range(forward_batch.batch_size):
seq_len = forward_batch.seq_lens_cpu[i].item()
assert isinstance(seq_len, int)
k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
layer_id,
seq_len,
block_tables[i],
)
k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
layer_id,
seq_len,
block_tables[i],
)
extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
k_fp8_list.append(k_fp8)
k_scale_list.append(k_scale)
ks_list.append(ks)
offset += extend_seq_len
k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
kv_fp8 = (k_fp8, k_scale)
ks = torch.cat(ks_list, dim=0)
seq_lens_expanded = metadata.get_seqlens_expanded()
ke = ks + seq_lens_expanded
logits = deep_gemm_v32.fp8_mqa_logits(
q_fp8,
kv_fp8,
weights,
ks,
ke,
clean_logits=False,
)
assert logits.shape[0] == len(seq_lens_expanded)
topk_result = metadata.topk_transform(logits, self.index_topk)
return topk_result
def _forward(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
) -> Optional[torch.Tensor]:
if TYPE_CHECKING:
assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
metadata = forward_batch.attn_backend.get_indexer_metadata(
layer_id, forward_batch
)
enable_dual_stream = (
NSA_DUAL_STREAM
and self.alt_stream is not None
and get_is_capture_mode()
and q_lora.shape[0] > 0
and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
)
        # skip NSA if the attention backend chooses to skip this batch
if metadata is None:
return None
if not NSA_USE_REAL_INDEXER: # temporary
return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
q_fp8 = query.to(torch.float32)
k_fp8 = key.to(torch.float32)
q_scale = torch.ones((query.shape[0], 1), dtype=torch.float32, device="cuda")
k_scale = torch.ones((key.shape[0], 1), dtype=torch.float32, device="cuda")
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
with torch.cuda.stream(self.alt_stream):
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
current_stream.wait_stream(self.alt_stream)
else:
q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
# k_fp8: (seq_len, head_dim) fp8_e4m3fn
# k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
        # k_scale: (seq_len, head_dim // block_size = 1) float32
        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) float32
forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
layer_id=layer_id,
loc=forward_batch.out_cache_loc,
index_k=k_fp8,
index_k_scale=k_scale,
)
weights = self._get_logits_head_gate(x, q_scale)
assert forward_batch.seq_lens_cpu is not None
if len(forward_batch.seq_lens_cpu) == 0:
            # this can happen with max-padding; returning an all-invalid result is fine
# if x.shape[0] != 0:
# print(
# "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
# )
return torch.full(
(x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
)
if forward_batch.forward_mode.is_decode_or_idle():
topk_result = self._get_topk_paged(
forward_batch, layer_id, q_fp8, weights, metadata
)
else:
topk_result = self._get_topk_ragged(
forward_batch, layer_id, q_fp8, weights, metadata
)
return topk_result
def forward_cuda(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
) -> Optional[torch.Tensor]:
return self._forward(x, q_lora, positions, forward_batch, layer_id)
def forward_npu(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
) -> torch.Tensor:
import custom_ops
import torch_npu
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.utils import get_bool_env_var
if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
else:
actual_seq_lengths_kv = (
forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
)
enable_index_cp = (
get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
)
is_prefill = forward_batch.forward_mode.is_extend()
attention_tp_rank = get_attention_tp_rank()
attention_tp_size = get_attention_tp_size()
cos_sin = self.rotary_emb.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
if is_prefill and enable_index_cp:
slice_length = cos.shape[0] // attention_tp_size
cos = cos[
slice_length
* attention_tp_rank : slice_length
* (attention_tp_rank + 1)
]
sin = sin[
slice_length
* attention_tp_rank : slice_length
* (attention_tp_rank + 1)
]
slot_mapping = forward_batch.out_cache_loc
block_table = forward_batch.attn_backend.forward_metadata.block_tables
bs = x.shape[0]
q = self.wq_b(q_lora)[0] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128]
q_pe, q_nope = torch.split(
q,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
) # [bs, 64, 64 + 64]
q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
bs, self.n_heads, self.rope_head_dim
) # [bs, n, d]
q = torch.cat([q_pe, q_nope], dim=-1)
k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128]
k = self.k_norm(k_proj)
k_pe, k_nope = torch.split(
k,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
) # [bs, 64 + 64]
k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
bs, 1, self.rope_head_dim
) # [bs, 1, d]
k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128]
if is_prefill and enable_index_cp:
k, local_k = (
torch.empty(
(k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
dtype=k.dtype,
device=k.device,
),
k,
)
get_attention_tp_group().all_gather_into_tensor(k, local_k)
forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)
indexer_input = {}
if is_prefill:
actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
device=q.device
)
if enable_index_cp:
actual_seq_lengths_q -= bs * attention_tp_rank
actual_seq_lengths_q = torch.max(
actual_seq_lengths_q,
torch.zeros_like(actual_seq_lengths_q).to(
device=actual_seq_lengths_q.device
),
)
actual_seq_lengths_q = torch.min(
actual_seq_lengths_q,
torch.full(actual_seq_lengths_q.shape, bs).to(
device=actual_seq_lengths_q.device
),
)
else:
if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
                actual_seq_lengths_q = torch.arange(
                    1, bs + 1, dtype=torch.int32, device=k.device
                )
else:
actual_seq_lengths_q = (
forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
)
past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
x = x.view(-1, self.hidden_size)
weights = self.weights_proj(x)[0]
block_table = (
block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
)
topk_indices = torch.ops.custom.npu_lightning_indexer(
query=q.view(-1, self.n_heads, self.head_dim),
key=past_key_states,
weights=weights,
actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
block_table=block_table,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=self.index_topk,
sparse_mode=3,
)
if is_prefill and enable_index_cp:
topk_indices, local_topk_indices = (
torch.empty(
(
topk_indices.shape[0] * attention_tp_size,
topk_indices.shape[1],
topk_indices.shape[2],
),
dtype=topk_indices.dtype,
device=topk_indices.device,
),
topk_indices,
)
get_attention_tp_group().all_gather_into_tensor(
topk_indices, local_topk_indices
)
return topk_indices
import torch
import triton
import triton.language as tl
from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST
def quantize_k_cache(cache_k):
# TODO upstream can skip concat([k_nope, k_pe]) since we split them here
if NSA_QUANT_K_CACHE_FAST:
return _quantize_k_cache_fast_wrapped(cache_k)
else:
return _quantize_k_cache_slow(cache_k)
# Reference implementation, copied from the original quantization code
def _quantize_k_cache_slow(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
"""
Quantize the k-cache
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, d = input_k_cache.shape
assert h_k == 1
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty(
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
dtype=torch.float8_e4m3fn,
device=input_k_cache.device,
)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = (
torch.abs(
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
)
.max(dim=-1)
.values
/ 448.0
) # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (
input_k_cache[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].float()
/ cur_scale_factors_inv.float()
).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_quantized_nope
)
result = result.view(num_blocks, block_size, 1, -1)
return result
def _quantize_k_cache_fast_wrapped(
input_k_cache: torch.Tensor,
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
# TODO the final API may be 2D instead of 4D, thus we convert them here
num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape
assert dv == 512
assert dim_nope_and_rope == 512 + 64
assert tile_size == 128
input_k_cache = input_k_cache.view((-1, dim_nope_and_rope))
# TODO deliberately split into two tensors, then upstream can provide the two tensors instead of concat into one
k_nope = input_k_cache[:, :dv]
k_rope = input_k_cache[:, dv:]
output = _quantize_k_cache_fast(k_nope=k_nope, k_rope=k_rope)
return output.view(num_blocks, block_size, 1, -1)
def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):
"""
:param k_nope: (num_tokens, dim_nope 512)
:param k_rope: (num_tokens, dim_rope 64)
"""
assert k_nope.dtype == torch.bfloat16
assert k_rope.dtype == torch.bfloat16
num_tokens, dim_nope = k_nope.shape
num_tokens_, dim_rope = k_rope.shape
assert num_tokens == num_tokens_
assert dim_nope == 512
assert dim_rope == 64
assert k_nope.dtype == k_rope.dtype
num_tiles = dim_nope // group_size
assert k_nope.stride(1) == 1
assert k_rope.stride(1) == 1
output = torch.empty(
(num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),
dtype=torch.float8_e4m3fn,
device=k_nope.device,
)
output_nope_q = output[..., :dim_nope]
output_nope_s = output[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
output_rope = output[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16)
num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
assert num_blocks_per_token == 5
assert dim_nope % group_size == 0
NUM_NOPE_BLOCKS = dim_nope // group_size
_quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
output_nope_q,
output_nope_s,
output_rope,
k_nope,
k_rope,
output_nope_q.stride(0),
output_nope_s.stride(0),
output_rope.stride(0),
k_nope.stride(0),
k_rope.stride(0),
NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
GROUP_SIZE=group_size,
DIM_NOPE=dim_nope,
DIM_ROPE=dim_rope,
FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
)
return output
@triton.jit
def _quantize_k_cache_fast_kernel(
output_nope_q_ptr,
output_nope_s_ptr,
output_rope_ptr,
k_nope_ptr,
k_rope_ptr,
output_nope_q_stride_0: int,
output_nope_s_stride_0: int,
output_rope_stride_0: int,
k_nope_stride_0: int,
k_rope_stride_0: int,
NUM_NOPE_BLOCKS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
DIM_NOPE: tl.constexpr,
DIM_ROPE: tl.constexpr,
FP8_MIN: tl.constexpr,
FP8_MAX: tl.constexpr,
):
token_id = tl.program_id(0)
raw_block_id = tl.program_id(1)
if raw_block_id < NUM_NOPE_BLOCKS:
# a. quant nope
effective_block_id = raw_block_id
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_NOPE
ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs
y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)
        # the reference impl does not apply a `tl.maximum(..., eps)` clamp, so we omit it here too
y_s = tl.max(tl.abs(y)) / FP8_MAX
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(
output_nope_q_ptr.dtype.element_ty
)
dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs
dst_s_ptr = (
output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id
)
tl.store(dst_q_ptr, y_q, mask=mask)
tl.store(dst_s_ptr, y_s)
else:
# b. copy rope
effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_ROPE
src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs
dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs
data = tl.load(src_ptr, mask=mask)
tl.store(dst_ptr, data, mask=mask)
if __name__ == "__main__":
for num_blocks, block_size in [
(1, 1),
(10, 64),
]:
dim_nope_and_rope = 512 + 64
input_k_cache = torch.randn(
(num_blocks, block_size, 1, dim_nope_and_rope),
dtype=torch.bfloat16,
device="cuda",
)
# temp debug
# input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)
ref_quant = _quantize_k_cache_slow(input_k_cache)
actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
# print(f"{input_k_cache=}")
# print(f"{ref_quant=}")
# print(f"{actual_quant=}")
# print(f"{ref_quant == actual_quant=}")
# print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
# print(f"{ref_quant.view(torch.bfloat16)=}")
# print(f"{actual_quant.view(torch.bfloat16)=}")
# assert torch.all(ref_quant == actual_quant)
import dequant_k_cache
ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(
actual_quant
)
print(f"{ref_ref_dequant=}")
print(f"{actual_actual_dequant=}")
print(f"{actual_actual_dequant - ref_ref_dequant=}")
print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}")
        # TODO: the tolerance is loose; investigate why the fast and slow paths differ this much
torch.testing.assert_close(
ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2
)
torch.testing.assert_close(
ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
)
print("Passed")
from typing import Optional, Tuple
import tilelang
import tilelang.language as T
import torch
tilelang.set_log_level("WARNING")
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
}
BF16 = "bfloat16"
FP8 = "float8_e4m3"
FP32 = "float32"
def fast_log2_ceil(x):
bits_x = T.reinterpret("uint32", x)
exp_x = (bits_x >> 23) & 0xFF
man_bits = bits_x & ((1 << 23) - 1)
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
bits_x = (x + 127) << 23
return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
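# fast_log2_ceil reads the float32 exponent bits and rounds up whenever any
# mantissa bit is set (i.e. ceil(log2(x))); fast_pow2 rebuilds 2**x by writing the
# biased exponent directly. fast_round_scale therefore rounds amax / fp8_max up to
# the next power of two, keeping the scale exactly representable (e8m0-style).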
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
):
M = T.symbolic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
num_stages = 0 if round_scale else 2
blk_m = 32
group_size = 128
@T.prim_func
def act_quant_kernel_(
X: T.Tensor[(M, N), in_dtype],
Y: T.Tensor[(M, N), out_dtype],
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
pid_m,
pid_n,
):
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
s_local = T.alloc_fragment((blk_m,), scale_dtype)
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1)
for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4)
if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
else:
s_local[i] = amax_local[i] * fp8_max_inv
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
)
for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = s_local[i]
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), "Input tensor must be contiguous"
assert (
x.size(-1) % block_size == 0
), f"Last dimension size must be divisible by block_size (block_size={block_size})"
N = x.size(-1)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
return y, s
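# Illustrative act_quant usage: for a contiguous bf16 tensor x of shape (M, 7168),
# act_quant(x) returns y of the same shape in float8_e4m3fn and s of shape
# (M, 7168 // 128) = (M, 56) in float32; with scale_fmt set, scales are rounded to
# powers of two.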
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int):
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
blk_n1 = 512
blk_n2 = 128
@T.prim_func
def fp8_index_kernel_(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
q_smem = T.alloc_shared((h, d), FP8)
T.copy(q[i_b, i_m, 0, 0], q_smem)
q_s_frag = T.alloc_fragment(h, FP32)
T.copy(q_s[i_b, i_m, 0], q_s_frag)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
logits = T.alloc_fragment((blk_n2, h), FP32)
T.gemm(
k_smem,
q_smem,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
logits_sum = T.alloc_fragment(blk_n2, FP32)
T.reduce_sum(logits, logits_sum, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
return fp8_index_kernel_
def fp8_index(
q: torch.Tensor,
q_s: torch.Tensor,
k: torch.Tensor,
k_s: torch.Tensor,
) -> torch.Tensor:
"""
Perform index score using FP8 precision.
Args:
q (torch.Tensor): The Q tensor, must be contiguous.
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
k (torch.Tensor): The K tensor, must be contiguous.
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
fp8 q @ fp8 k -> fp32 logits
relu(fp32 logits) * q_s (weights) -> fp32 logits
fp32 logits -> fp32 logits_sum
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
"""
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def sparse_attention_fwd_kernel_v1(
num_heads,
dim,
tail_dim,
topk,
*,
kv_group=1,
sm_scale=None,
is_causal=True,
block_I=64,
num_stages=2,
threads=256,
):
    assert dim == tilelang.math.next_power_of_2(
        dim
    ), f"haven't checked padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim
    ), f"haven't checked padding correctness yet, tail_dim={tail_dim}"
    assert is_causal == True, "non-causal attention is not supported"
assert (
topk % block_I == 0
), "otherwise will load some index=0 thus causing wrong kv to be loaded"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
else:
sm_scale = sm_scale * 1.44269504 # log2(e)
batch = T.symbolic("batch")
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
head_kv = num_heads // kv_group
q_shape = [batch, seq_len, num_heads, dim + tail_dim]
kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
o_shape = [batch, seq_len, num_heads, dim]
indices_shape = [batch, seq_len, kv_group, topk]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
assert kv_group == 1
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
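    # The kernel below gathers the top-k KV rows through `Indices` and runs a
    # flash-attention-style online softmax over them: for each block of BI indices
    # it computes Q @ K^T (nope + tail parts), rescales the running accumulator by
    # `alpha`, and accumulates S @ KV into `acc_o` (the gathered rows serve as both
    # K and V); invalid indices (< 0) are masked to -inf before the softmax.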
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
):
with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
bx,
by,
bz,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
O_shared = T.alloc_shared([H_per_block, D], dtype)
mask = T.alloc_fragment([BI], "bool")
acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(acc_o, 0)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
b_i, g_i = by, bz
s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[
b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i
]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[
b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i
]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(
mask[bi_i], 0, -T.infinity(acc_s.dtype)
)
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
)
                T.reduce_sum(acc_s, sumexp_i, dim=1)  # is this an accumulate operator?
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
# Rescale
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o, O_shared)
T.copy(acc_o, Output[b_i, s_i, H0:H1, :])
return main
@tilelang.jit(
out_idx=[-1],
compile_flags=[
"-O3",
"-Wno-deprecated-declarations",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--ptxas-options=-v,--register-usage-level=10",
"-DNDEBUG",
],
) # type: ignore
def sparse_attention_fwd_kernel_v2(
num_heads: int,
dim: int,
tail_dim: int,
topk: int,
*,
kv_group: int = 1,
sm_scale: Optional[float] = None,
block_I: int = 64,
):
    assert dim == tilelang.math.next_power_of_2(
        dim
    ), f"haven't checked padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim
    ), f"haven't checked padding correctness yet, tail_dim={tail_dim}"
assert (
topk % block_I == 0
), "otherwise will load some index=0 thus causing wrong kv to be loaded"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
else:
sm_scale = sm_scale * 1.44269504 # log2(e)
threads = 384
batch = T.symbolic("batch")
qo_len = T.symbolic("seq_len")
num_pages = T.symbolic("num_pages")
q_shape = [batch, qo_len, num_heads, dim + tail_dim]
kv_shape = [batch, num_pages, kv_group, dim + tail_dim]
o_shape = [batch, qo_len, num_heads, dim]
indices_shape = [batch, qo_len, kv_group, topk]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
H = num_heads
padded_H = max(tilelang.math.next_power_of_2(num_heads), 16)
if padded_H != H:
assert kv_group == 1
BI = block_I
NI = tilelang.cdiv(topk, block_I)
assert NI % 2 == 0, "NI should be a multiple of 2"
D = dim
D_tail = tail_dim
if num_heads > 64:
        assert num_heads % 64 == 0, "num_heads should be a multiple of 64"
REPLICATE_H = num_heads // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
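    # The kernel below is warp-specialized across 384 threads: tx < 128 runs the
    # softmax and accumulates the left half of the output dim, 128 <= tx < 256
    # accumulates the right half, and tx >= 256 is the producer that gathers the
    # indexed KV rows into double-buffered shared memory (buffers 0/1) via
    # cp.async, with barriers coordinating the hand-off between the groups.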
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
):
"""
Q: [b, qo_len, H, D + D_tail] (bfloat16)
KV: [b, num_pages, kv_group, D + D_tail] (bfloat16)
Indices: [b, qo_len, kv_group, topk] (int32)
"""
with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz): # type: ignore
Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)
KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)
KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)
KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)
K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)
K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)
O_shared_l = Q_shared_l
O_shared_r = Q_shared_r
is_kv_valid_0 = T.alloc_shared([BI], "bool", scope="shared")
is_kv_valid_1 = T.alloc_shared([BI], "bool", scope="shared")
acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared")
alpha_local = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
indices_local = T.alloc_local([1], indices_dtype)
indices_tmp = T.alloc_local([1], indices_dtype)
bar_q = T.alloc_barrier(arrive_count=384)
bar_k_0_ready = T.alloc_barrier(arrive_count=128)
bar_k_1_ready = T.alloc_barrier(arrive_count=128)
bar_k_0_free = T.alloc_barrier(arrive_count=256)
bar_k_1_free = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
bar_0_128 = T.alloc_barrier(arrive_count=128)
bar_1_128 = T.alloc_barrier(arrive_count=128)
bar_2_128 = T.alloc_barrier(arrive_count=128)
bar_final = T.alloc_barrier(arrive_count=128)
b_i, g_i = by, bz
s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
tx = T.get_thread_binding()
T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
T.barrier_arrive(bar_q)
if tx < 128:
T.set_max_nreg(240, 1)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
T.fill(acc_o_l, 0)
T.barrier_wait(bar_q, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
# with sync_at(bar_0_128, 0):
T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
T.barrier_arrive(bar_0_128)
T.barrier_wait(bar_0_128, 0)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(
is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype)
)
T.gemm(
Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1
)
T.gemm(
Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1
)
T.gemm(
Q_tail_shared,
K_tail_shared_0,
acc_s,
transpose_B=True,
wg_wait=-1,
)
T.wait_wgmma(0)
if i_i != 0:
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
)
T.reduce_sum(
acc_s, sumexp_i, dim=1
                    )  # is this an accumulate operator?
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_0_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_0_free[0])
# Buffer 1
T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
T.barrier_arrive(bar_0_128)
T.barrier_wait(bar_0_128, 1)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(
is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype)
)
T.gemm(
Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1
)
T.gemm(
Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1
)
T.gemm(
Q_tail_shared,
K_tail_shared_1,
acc_s,
transpose_B=True,
wg_wait=-1,
)
T.wait_wgmma(0)
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
)
T.reduce_sum(
acc_s, sumexp_i, dim=1
                    )  # is this an accumulate operator?
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_1_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_1_free[0])
# Rescale
for h_i in T.Parallel(H_per_block):
sum_exp_shared[h_i] = sumexp[h_i]
T.barrier_arrive(bar_final)
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_l[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])
elif tx >= 128 and tx < 256:
# T.set_max_nreg(168, 1)
T.fill(acc_o_r, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
T.barrier_arrive(bar_1_128)
T.barrier_wait(bar_1_128, 0)
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_0_r, acc_o_r)
T.barrier_arrive(bar_k_0_free[0])
T.barrier_arrive(bar_sScale_and_sS_free)
# Buffer 1
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
T.barrier_arrive(bar_1_128)
T.barrier_wait(bar_1_128, 1)
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_1_r, acc_o_r)
T.barrier_arrive(bar_k_1_free[0])
if i_i != T.ceildiv(NI, 2) - 1:
T.barrier_arrive(bar_sScale_and_sS_free)
# Rescale
T.barrier_wait(bar_final, 0)
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])
elif tx >= 256:
# producer
T.set_max_nreg(80, 0)
indices_local[0] = 0
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
T.barrier_arrive(bar_2_128)
T.barrier_wait(bar_2_128, 0)
for r in T.serial(4):
indices_tmp[0] = Indices[
b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8
]
is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
if is_kv_valid_0[r * 16 + (tx - 256) // 8]:
indices_local[0] = indices_tmp[0]
with T.attr("default", "async_scope", 1): # type: ignore
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_0_l[
r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 + v,
] = KV[
b_i,
indices_local[0],
g_i,
64 * u + (tx - 256) % 8 * 8 + v,
]
KV_shared_0_r[
r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 + v,
] = KV[
b_i,
indices_local[0],
g_i,
D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
]
with T.attr("default", "async_scope", 1): # type: ignore
for v in T.vectorized(8):
K_tail_shared_0[
r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
] = KV[
b_i,
indices_local[0],
g_i,
D + (tx - 256) % 8 * 8 + v,
]
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
T.barrier_arrive(bar_2_128)
T.barrier_wait(bar_2_128, 1)
for r in T.serial(4):
indices_tmp[0] = Indices[
b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8
]
is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
if is_kv_valid_1[r * 16 + (tx - 256) // 8]:
indices_local[0] = indices_tmp[0]
with T.attr("default", "async_scope", 1): # type: ignore
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_1_l[
r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 + v,
] = KV[
b_i,
indices_local[0],
g_i,
64 * u + (tx - 256) % 8 * 8 + v,
]
KV_shared_1_r[
r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 + v,
] = KV[
b_i,
indices_local[0],
g_i,
D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
]
with T.attr("default", "async_scope", 1): # type: ignore
for v in T.vectorized(8):
K_tail_shared_1[
r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
] = KV[
b_i,
indices_local[0],
g_i,
D + (tx - 256) % 8 * 8 + v,
]
T.cp_async_barrier_noinc(bar_k_1_ready[0])
return main
def tilelang_sparse_fwd(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> torch.Tensor:
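    # Expected shapes (inferred from the kernel's symbolic shapes; the batch dim is
    # added by the unsqueeze(0) calls below): q (qo_len, num_heads, d_v + tail) bf16,
    # kv (num_kv_tokens, kv_group, d_v + tail) bf16, indices (qo_len, kv_group, topk)
    # int32; the output is (1, qo_len, num_heads, d_v) bf16.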
assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3
num_heads = q.shape[1]
dim = q.shape[2]
tail_dim = dim - d_v
topk = indices.shape[-1]
assert topk == 2048
# NOTE(dark): v2 offers better performance than v1
kernel = sparse_attention_fwd_kernel_v2(
num_heads, d_v, tail_dim, topk, sm_scale=sm_scale
)
return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0)) # type: ignore
import torch
from sglang.srt.utils import align
# NOTE(dark): flashmla P requires `params.topk % (2*B_TOPK) == 0`,
# where `B_TOPK=64`. So we align to 128 by default.
_TOPK_ALIGNMENT = 128
# TODO(dark): maybe this torch_op can support torch.compile
def _fast_topk_torch(
input: torch.Tensor, seq_lens: torch.Tensor, topk: int, alignment: int
) -> torch.Tensor:
# Fallback to torch.topk
bs, max_seq_len = input.shape
assert len(seq_lens) == bs
# set those out-of-bound input to -inf
padded_max_seq_len = align(max_seq_len, alignment)
positions = torch.arange(
padded_max_seq_len, device=input.device, dtype=seq_lens.dtype
)
positions = positions.unsqueeze(0).expand(bs, -1)
mask = positions >= seq_lens.unsqueeze(1)
# NOTE(dark): just return all valid indices as an optimization
if padded_max_seq_len <= topk:
return positions.masked_fill(mask, -1)
assert topk % alignment == 0
# in-place operation: mask invalid inputs to -inf
input = input.masked_fill_(mask[:, :max_seq_len], float("-inf"))
result = input.topk(topk, dim=-1, sorted=True)
return result.indices.masked_fill_(mask[:, :topk], -1)
def fast_topk_impl(
input: torch.Tensor,
seq_lens: torch.Tensor,
topk: int,
alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
return _fast_topk_torch(input, seq_lens, topk, alignment)
def fast_topk_transform_fused_cuda(
input: torch.Tensor,
seq_lens: torch.Tensor,
topk: int,
dst_page_table: torch.Tensor,
src_page_table: torch.Tensor,
cu_seqlens_q: torch.Tensor,
alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
from sglang.srt.layers.attention.nsa.cuda import fast_topk_transform
assert topk == 2048 and topk % alignment == 0
return fast_topk_transform(
score=input,
lengths=seq_lens,
dst_page_table=dst_page_table,
src_page_table=src_page_table,
cu_seqlens=cu_seqlens_q,
)
from typing import List, Optional
import torch
import triton
import triton.language as tl
def transform_index_page_table_prefill(**kwargs):
return transform_index_page_table_prefill_ref(**kwargs)
def transform_index_page_table_decode(**kwargs):
return transform_index_page_table_decode_ref(**kwargs)
@triton.jit
def transform_index_page_table_decode_kernel(
page_table_ptr: torch.Tensor,
topk_indices_ptr: torch.Tensor,
result_ptr: torch.Tensor,
page_size: tl.constexpr,
max_seqlen_k: tl.constexpr,
):
TOPK: tl.constexpr = 2048
req_id = tl.program_id(0)
page_table_ptr = page_table_ptr + req_id * max_seqlen_k
topk_indices_ptr = topk_indices_ptr + req_id * TOPK
result_ptr = result_ptr + req_id * TOPK
offset = tl.arange(0, TOPK) # topk should be 2048
loaded_topk_indices = tl.load(topk_indices_ptr + offset)
mask = loaded_topk_indices >= 0
loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
tl.store(result_ptr + offset, -1, mask=~mask)
def transform_index_page_table_decode_fast(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
result: Optional[torch.Tensor] = None,
page_size: int = 1,
) -> torch.Tensor:
"""
Transform the page table according to topk indices for sparse topk attention.
Args:
page_table: [qo_len, max_seqlen_k], the original page table
topk_indices: [qo_len, topk], the topk indices for each query position
Returns:
transformed_page_table: [qo_len, topk], the transformed page table
For out-of-bound indices in topk_indices, this should be filled with -1.
"""
assert page_size == 1
assert page_table.shape[0] == topk_indices.shape[0]
assert topk_indices.shape[1] == 2048
qo_len = topk_indices.shape[0]
max_seqlen_k = page_table.shape[1]
if result is None:
result = torch.empty_like(topk_indices, dtype=torch.int32)
# Launch triton kernel
grid = (qo_len,)
transform_index_page_table_decode_kernel[grid](
page_table,
topk_indices,
result,
page_size,
max_seqlen_k=max_seqlen_k,
)
return result
def transform_index_page_table_prefill_fast(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
extend_lens_cpu: List[int],
page_size: int = 1,
) -> torch.Tensor:
# TODO(baizhou): can be implemented with another triton kernel
assert page_size == 1
result = torch.empty_like(topk_indices, dtype=torch.int32)
assert len(extend_lens_cpu) == page_table.shape[0]
offset = 0
for i, l in enumerate(extend_lens_cpu):
transform_index_page_table_decode_fast(
page_table[i].unsqueeze(0).expand(l, -1),
topk_indices[offset : offset + l],
result=result[offset : offset + l],
)
offset += l
assert offset == topk_indices.shape[0]
return result
def transform_index_page_table_decode_ref(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
result: Optional[torch.Tensor] = None,
page_size: int = 1,
) -> torch.Tensor:
assert page_size == 1
assert page_table.shape[0] == topk_indices.shape[0]
if result is None:
result = torch.empty_like(topk_indices, dtype=torch.int32)
assert result.shape == topk_indices.shape
torch.gather(
page_table,
dim=1,
index=topk_indices.clamp(min=0),
out=result,
)
result[topk_indices < 0] = -1
return result
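# Illustrative worked example (ours, with assumed values): the decode transform is a
# row-wise gather with -1 passed through for invalid (negative) top-k slots.
def _example_transform_decode_ref():
    pt = torch.tensor([[10, 20, 30]], dtype=torch.int32)  # one page-size-1 page table row
    ti = torch.tensor([[2, 0, -1]])                        # top-k token indices, -1 = padding
    return transform_index_page_table_decode_ref(pt, ti)   # tensor([[30, 10, -1]], dtype=torch.int32)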
def transform_index_page_table_prefill_ref(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
extend_lens_cpu: List[int],
page_size: int = 1,
) -> torch.Tensor:
assert page_size == 1
result = torch.empty_like(topk_indices, dtype=torch.int32)
assert len(extend_lens_cpu) == page_table.shape[0]
offset = 0
for i, l in enumerate(extend_lens_cpu):
transform_index_page_table_decode_ref(
page_table[i].unsqueeze(0).expand(l, -1),
topk_indices[offset : offset + l],
result=result[offset : offset + l],
)
offset += l
assert offset == topk_indices.shape[0]
return result
if __name__ == "__main__":
bs, topk, max_seqlen = 10, 2048, 3000
page_table = torch.randint(0, 100, (bs, max_seqlen), device="cuda")
topk_indices = torch.full((bs, topk), -1, device="cuda")
topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1)
ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)
result = transform_index_page_table_decode_fast(page_table, topk_indices)
assert torch.all(result == ref_result)
print("Passed")
import torch
import torch.nn as nn
class DummyModel(nn.Module):
def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
super().__init__()
self.weights_proj = nn.Linear(d_in, 1024)
self.n_heads = n_heads
self.softmax_scale = softmax_scale
def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
weights = self.weights_proj(x)
weights = weights * self.n_heads**-0.5
q_scale = q_scale.unsqueeze(1) # (B,1,1)
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
weights = self.weights_proj(x)
q_scale = q_scale.unsqueeze(1) # (B,1,1)
scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale # (B,1,1)
weights = weights.unsqueeze(-1) * scale_const # (B,1024,1)
return weights
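# Note (ours, not in the original benchmark): the two gate computations are algebraically
# identical, since elementwise multiplication is associative and commutative:
#   (weights * n_heads**-0.5).unsqueeze(-1) * q_scale * softmax_scale
#       == weights.unsqueeze(-1) * (n_heads**-0.5 * q_scale * softmax_scale)
# The optimized path folds the scalars and the (B,1,1) q_scale into one small constant,
# so the large (B,1024,1) tensor is multiplied only once; the remaining difference is
# floating-point rounding (~1e-8, as printed below).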
def main():
torch.manual_seed(0)
model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
x = torch.randn(128, 2048) # batch=128, d_in=2048
q_scale = torch.randn(128, 1)
import time
start = time.time()
for _ in range(1000):
out_orig = model._get_logits_head_gate_orig(x, q_scale)
print("Original version time:", time.time() - start)
start = time.time()
for _ in range(1000):
out_opt = model._get_logits_head_gate_opt(x, q_scale)
print("Optimized version time:", time.time() - start)
print("Difference:", (out_orig - out_opt).abs().max().item())
assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
if __name__ == "__main__":
main()
"""
Original version time: 0.49235057830810547
Optimized version time: 0.4087331295013428
Difference: 1.4901161193847656e-08
"""
# temp NSA debugging environ
from sglang.srt.utils import get_bool_env_var
NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var(
"SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true"
)
NSA_KV_CACHE_STORE_FP8 = get_bool_env_var("SGLANG_NSA_KV_CACHE_STORE_FP8", "false")
NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "false")
NSA_DEQUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "false")
def _print_bool_env_vars():
msg = ""
for k, v in globals().items():
if k.startswith("NSA_") and isinstance(v, bool):
msg += f"{k}={v} "
print(msg, flush=True)
_print_bool_env_vars()
if not NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8:
assert not NSA_KV_CACHE_STORE_FP8
def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):
return original_seq_lens.clamp(max=nsa_index_topk)
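# Illustrative sketch (ours, with assumed values): sequences longer than the indexer
# top-k are clamped to it, shorter ones are left unchanged.
def _example_compute_nsa_seqlens():
    import torch
    seq_lens = torch.tensor([100, 2048, 5000], dtype=torch.int32)
    return compute_nsa_seqlens(seq_lens, nsa_index_topk=2048)  # tensor([100, 2048, 2048], dtype=torch.int32)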
from __future__ import annotations
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Dict,
List,
Literal,
Optional,
Tuple,
TypeAlias,
Union,
override,
)
import torch
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.layers.attention.nsa.topk import (
fast_topk_impl,
fast_topk_transform_fused_cuda,
)
from sglang.srt.layers.attention.nsa.transform_index import (
transform_index_page_table_decode,
transform_index_page_table_prefill,
)
from sglang.srt.layers.attention.nsa.utils import (
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
NSA_FUSE_TOPK,
NSA_KV_CACHE_STORE_FP8,
compute_nsa_seqlens,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.two_batch_overlap import global_server_args_dict
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
@dataclass(frozen=True)
class NSAFlashMLAMetadata:
"""Metadata only needed by FlashMLA"""
flashmla_metadata: torch.Tensor
num_splits: torch.Tensor
def slice(self, sli):
return NSAFlashMLAMetadata(
flashmla_metadata=self.flashmla_metadata,
num_splits=self.num_splits[sli],
)
def copy_(self, other: "NSAFlashMLAMetadata"):
self.flashmla_metadata.copy_(other.flashmla_metadata)
self.num_splits.copy_(other.num_splits)
@dataclass(frozen=True)
class NSAMetadata:
page_size: int
# Sequence lengths for the forward batch
cache_seqlens_int32: torch.Tensor
# Maximum sequence length for query
max_seq_len_q: int
# Maximum sequence length for key
max_seq_len_k: int
# Cumulative sequence lengths for query
cu_seqlens_q: torch.Tensor
# Cumulative sequence lengths for key
cu_seqlens_k: torch.Tensor
# Page table, the index of KV Cache Tables/Blocks
    # this table always uses page_size = 1
page_table_1: torch.Tensor
    # NOTE(dark): This property will be used in:
    # 1. dense decode/prefill: paged flash attention needs real_page_table
    # 2. sparse decode/prefill: the indexer needs real_page_table to compute the score
real_page_table: torch.Tensor
    # NSA metadata (NSA prefill queries are expanded per token)
    nsa_cache_seqlens_int32: torch.Tensor  # these seqlens are clipped to `topk`
    nsa_cu_seqlens_q: torch.Tensor  # must be arange(0, len(nsa_cu_seqlens_k))
    nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`
    nsa_extend_seq_lens_list: List[int]
    nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`
    nsa_max_seqlen_q: Literal[1] = 1  # always 1: queries are expanded per token for both decode and extend
flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
@dataclass(frozen=True)
class NSAIndexerMetadata(BaseIndexerMetadata):
attn_metadata: NSAMetadata
@override
def get_seqlens_int32(self) -> torch.Tensor:
return self.attn_metadata.cache_seqlens_int32
@override
def get_page_table_64(self) -> torch.Tensor:
return self.attn_metadata.real_page_table
@override
def get_seqlens_expanded(self) -> torch.Tensor:
return self.attn_metadata.nsa_seqlens_expanded
@override
def topk_transform(
self,
logits: torch.Tensor,
topk: int,
) -> torch.Tensor:
if not NSA_FUSE_TOPK:
return fast_topk_impl(logits, self.get_seqlens_expanded(), topk)
# NOTE(dark): if fused, we return a transformed page table directly
dst_page_table = torch.empty(
(logits.shape[0], topk), dtype=torch.int32, device=logits.device
)
fast_topk_transform_fused_cuda(
input=logits,
seq_lens=self.get_seqlens_expanded(),
topk=topk,
dst_page_table=dst_page_table,
src_page_table=self.attn_metadata.page_table_1,
cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
)
return dst_page_table
def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
assert seqlens.dtype == torch.int32 and seqlens.is_cuda
return torch.nn.functional.pad(
torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)
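# Illustrative sketch (ours, with assumed values): cumulative sequence lengths with a
# leading zero, the layout expected by the varlen attention kernels.
def _example_compute_cu_seqlens():
    seqlens = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
    return compute_cu_seqlens(seqlens)  # tensor([0, 3, 4, 6], dtype=torch.int32, device='cuda:0')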
_NSA_IMPL_T: TypeAlias = Literal[
"flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
]
NSA_PREFILL_IMPL: _NSA_IMPL_T
NSA_DECODE_IMPL: _NSA_IMPL_T
class NativeSparseAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.forward_metadata: NSAMetadata
self.device = model_runner.device
assert isinstance(model_runner.page_size, int)
self.real_page_size = model_runner.page_size
self.num_splits = (
1 if model_runner.server_args.enable_deterministic_inference else 0
)
self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
self.max_context_len = model_runner.model_config.context_len
self.num_q_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
assert model_runner.req_to_token_pool is not None
self.req_to_token = model_runner.req_to_token_pool.req_to_token
global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
def get_device_int32_arange(self, l: int) -> torch.Tensor:
if l > len(self._arange_buf):
next_pow_of_2 = 1 << (l - 1).bit_length()
self._arange_buf = torch.arange(
next_pow_of_2, device=self.device, dtype=torch.int32
)
return self._arange_buf[:l]
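    # Note (ours): with a real page size > 1, the page-size-1 table is turned into a
    # block table by taking one entry per page and dividing by the page size, e.g. with
    # page_size=64 a row [0, 1, ..., 127] maps to entries [0, 64] -> blocks [0, 1].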
def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
page_size = self.real_page_size
if page_size == 1:
return page_table
max_seqlen_k = page_table.shape[1]
strided_indices = torch.arange(
0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
)
return page_table[:, strided_indices] // page_size
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
batch_size = forward_batch.batch_size
device = forward_batch.seq_lens.device
assert (
forward_batch.spec_info is None
), "Spec decoding is not supported for NSA backend now"
cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
assert forward_batch.seq_lens_cpu is not None
max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :max_seqlen_k
]
if forward_batch.forward_mode.is_decode_or_idle():
extend_seq_lens_cpu = [1] * batch_size
max_seqlen_q = 1
cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
seqlens_expanded = cache_seqlens_int32
elif forward_batch.forward_mode.is_extend():
assert (
forward_batch.extend_seq_lens_cpu is not None
and forward_batch.extend_seq_lens is not None
and forward_batch.extend_prefix_lens_cpu is not None
), "All of them must not be None"
extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
assert forward_batch.extend_seq_lens is not None
if any(forward_batch.extend_prefix_lens_cpu):
max_seqlen_q = max(extend_seq_lens_cpu)
cu_seqlens_q = compute_cu_seqlens(
forward_batch.extend_seq_lens.to(torch.int32)
)
else:
max_seqlen_q = max_seqlen_k
cu_seqlens_q = cu_seqlens_k
seqlens_expanded = torch.cat(
[
torch.arange(
kv_len - qo_len + 1,
kv_len + 1,
dtype=torch.int32,
device=device,
)
for qo_len, kv_len in zip(
forward_batch.extend_seq_lens_cpu,
forward_batch.seq_lens_cpu.tolist(),
strict=True,
)
]
)
else:
assert False, f"Unsupported {forward_batch.forward_mode = }"
        # 1D, expanded seqlens (1D tensors are cheap to compute, so we always compute them)
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
original_seq_lens=seqlens_expanded,
nsa_index_topk=self.nsa_index_topk,
)
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
metadata = NSAMetadata(
page_size=self.real_page_size,
cache_seqlens_int32=cache_seqlens_int32,
max_seq_len_q=max_seqlen_q,
max_seq_len_k=max_seqlen_k,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
page_table_1=page_table,
flashmla_metadata=(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1, # TODO handle MTP which is not 1
)
if NSA_DECODE_IMPL == "flashmla_decode"
else None
),
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
nsa_seqlens_expanded=seqlens_expanded,
nsa_extend_seq_lens_list=extend_seq_lens_cpu,
real_page_table=self._transform_table_1_to_real(page_table),
)
self.forward_metadata = metadata
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
"""Initialize CUDA graph state for the attention backend.
Args:
max_bs (int): Maximum batch size to support in CUDA graphs
This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations.
"""
self.decode_cuda_graph_metadata: Dict = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.arange(
0, max_bs + 1, dtype=torch.int32, device=self.device
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
# fake page_table for sparse_prefill
"page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
"flashmla_metadata": (
self._compute_flashmla_metadata(
cache_seqlens=torch.ones(
max_bs, dtype=torch.int32, device=self.device
),
seq_len_q=1, # TODO handle MTP which is not 1
)
if NSA_DECODE_IMPL == "flashmla_decode"
else None
),
}
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
"""Initialize forward metadata for capturing CUDA graph."""
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
assert (
spec_info is None
), "Speculative decoding is not supported for NSA backend now"
# Normal Decode
# Get sequence information
cache_seqlens_int32 = seq_lens.to(torch.int32)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
# Use max context length for seq_len_k
page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
max_seq_len_k = page_table_1.shape[1]
# Precompute page table
# Precompute cumulative sequence lengths
# NOTE(dark): this is always arange, since we are decoding
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
)
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
real_page_table = self._transform_table_1_to_real(page_table_1)
if NSA_DECODE_IMPL == "flashmla_decode":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs + 1))
flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1, # TODO handle MTP which is not 1
)
)
else:
flashmla_metadata = None
metadata = NSAMetadata(
page_size=self.real_page_size,
cache_seqlens_int32=cache_seqlens_int32,
max_seq_len_q=1,
max_seq_len_k=max_seq_len_k,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
page_table_1=page_table_1,
flashmla_metadata=flashmla_metadata,
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
nsa_seqlens_expanded=cache_seqlens_int32,
real_page_table=real_page_table,
nsa_extend_seq_lens_list=[1] * bs,
)
self.decode_cuda_graph_metadata[bs] = metadata
self.forward_metadata = metadata
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: Optional[torch.Tensor] = None,
):
"""Initialize forward metadata for replaying CUDA graph."""
assert seq_lens_cpu is not None
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
assert (
spec_info is None
), "Speculative decoding is not supported for NSA backend now"
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
# Normal Decode
metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
max_len = int(seq_lens_cpu.max().item())
cache_seqlens = seq_lens.to(torch.int32)
metadata.cache_seqlens_int32.copy_(cache_seqlens)
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
)
page_indices = self.req_to_token[req_pool_indices, :max_len]
metadata.page_table_1[:, :max_len].copy_(page_indices)
assert (
metadata.nsa_cache_seqlens_int32 is not None
and metadata.nsa_cu_seqlens_k is not None
and self.nsa_index_topk is not None
)
nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
metadata.nsa_cu_seqlens_k[1:].copy_(
torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
)
# NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
assert self.real_page_size == metadata.page_size
if self.real_page_size > 1:
real_table = self._transform_table_1_to_real(page_indices)
new_len = real_table.shape[1]
metadata.real_page_table[:, :new_len].copy_(real_table)
else:
assert metadata.real_page_table is metadata.page_table_1
if NSA_DECODE_IMPL == "flashmla_decode":
metadata.flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens,
seq_len_q=1, # TODO handle MTP which is not 1
)
)
self.forward_metadata = metadata
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert (
not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
), "NSA backend doesn't support speculative decoding"
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
layer,
cache_loc,
k,
k_rope,
)
metadata = self.forward_metadata
causal = not layer.is_cross_attention
assert causal, "NSA is causal only"
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
# Do absorbed multi-latent attention
assert q_rope is not None
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        # when the KV cache is stored in fp8 and computed in fp8, no dtype conversion is needed
if not (NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and NSA_KV_CACHE_STORE_FP8):
kv_cache = kv_cache.to(q.dtype)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
# NOTE(dark): here, we use page size = 1
if NSA_FUSE_TOPK:
page_table_1 = topk_indices
else:
assert metadata.nsa_extend_seq_lens_list is not None
page_table_1 = transform_index_page_table_prefill(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
page_size=1,
)
if NSA_PREFILL_IMPL == "tilelang":
from sglang.srt.layers.attention.nsa.tilelang_kernel import (
tilelang_sparse_fwd,
)
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_tilelang(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_PREFILL_IMPL == "flashmla_prefill":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_prefill(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_PREFILL_IMPL == "flashmla_decode":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_decode(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
# TODO optimize args
layer=layer,
forward_batch=forward_batch,
metadata=metadata,
topk_indices=topk_indices,
block_table=metadata.real_page_table,
)
elif NSA_PREFILL_IMPL == "fa3":
return self._forward_fa3(
q_rope=q_rope,
kv_cache=kv_cache,
v_head_dim=layer.v_head_dim,
q_nope=q_nope,
page_table=page_table_1,
cache_seqlens=metadata.nsa_cache_seqlens_int32,
cu_seqlens_q=metadata.nsa_cu_seqlens_q,
cu_seqlens_k=metadata.nsa_cu_seqlens_k,
max_seqlen_q=metadata.nsa_max_seqlen_q,
sm_scale=layer.scaling,
logit_cap=layer.logit_cap,
page_size=1,
)
else:
raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
layer,
cache_loc,
k,
k_rope,
)
metadata = self.forward_metadata
causal = not layer.is_cross_attention
assert causal, "NSA is causal only"
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
if NSA_FUSE_TOPK:
page_table_1 = topk_indices
else:
page_table_1 = transform_index_page_table_decode(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
page_size=1,
)
if NSA_DECODE_IMPL == "flashmla_prefill":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_prefill(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_DECODE_IMPL == "flashmla_decode":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_decode(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
# TODO optimize args
layer=layer,
forward_batch=forward_batch,
metadata=metadata,
topk_indices=topk_indices,
block_table=metadata.real_page_table,
)
elif NSA_DECODE_IMPL == "tilelang":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_tilelang(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_DECODE_IMPL == "fa3":
return self._forward_fa3(
q_rope=q_rope,
kv_cache=kv_cache,
v_head_dim=layer.v_head_dim,
q_nope=q_nope,
page_table=page_table_1,
cache_seqlens=metadata.nsa_cache_seqlens_int32,
cu_seqlens_q=metadata.nsa_cu_seqlens_q,
cu_seqlens_k=metadata.nsa_cu_seqlens_k,
max_seqlen_q=metadata.nsa_max_seqlen_q,
sm_scale=layer.scaling,
logit_cap=layer.logit_cap,
page_size=1,
)
else:
assert False, f"Unsupported {NSA_DECODE_IMPL = }"
def _forward_fa3(
self,
q_rope: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
q_nope: torch.Tensor,
page_table: torch.Tensor,
cache_seqlens: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
sm_scale: float,
logit_cap: float,
page_size: int,
) -> torch.Tensor:
k_rope_cache = kv_cache[:, :, v_head_dim:]
c_kv_cache = kv_cache[:, :, :v_head_dim]
qk_rope_dim = k_rope_cache.shape[-1]
k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=sm_scale,
causal=True,
softcap=logit_cap,
return_softmax_lse=False,
num_splits=self.num_splits,
)
return o # type: ignore
def _forward_flashmla_prefill(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
page_table_1: torch.Tensor,
sm_scale: float,
) -> torch.Tensor:
from flash_mla import flash_mla_sparse_fwd
o, _, _ = flash_mla_sparse_fwd(
q=q_all,
kv=kv_cache,
indices=page_table_1.unsqueeze(1),
sm_scale=sm_scale,
d_v=v_head_dim,
)
return o
def _forward_flashmla_decode(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
sm_scale: float,
layer,
forward_batch: ForwardBatch,
metadata: NSAMetadata,
topk_indices,
block_table,
) -> torch.Tensor:
from flash_mla import flash_mla_with_kvcache
cache_seqlens = metadata.nsa_cache_seqlens_int32
        # TODO the 2nd dim is seq_len_q; it needs to be > 1 for MTP
q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
assert self.real_page_size == 64, "only page size 64 is supported"
if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not NSA_KV_CACHE_STORE_FP8:
# inefficiently quantize the whole cache
kv_cache = quantize_k_cache(kv_cache)
o, _ = flash_mla_with_kvcache(
q=q_all,
k_cache=kv_cache,
cache_seqlens=cache_seqlens,
head_dim_v=v_head_dim,
tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
num_splits=metadata.flashmla_metadata.num_splits,
softmax_scale=sm_scale,
# TODO improve
indices=_compute_indices_in_kvcache(
block_table=block_table,
topk_indices=topk_indices.to(torch.int32),
page_size=self.real_page_size,
nsa_index_topk=self.nsa_index_topk,
),
            # the doc says block_table is not used, but passing None raises an error
block_table=torch.empty(
(q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
),
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
)
# TODO shape correct?
return o
def _forward_tilelang(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
page_table_1: torch.Tensor,
sm_scale: float,
) -> torch.Tensor:
from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd
return tilelang_sparse_fwd(
q=q_all,
kv=kv_cache,
indices=page_table_1.unsqueeze(1),
sm_scale=sm_scale,
d_v=v_head_dim,
)
def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for sequence length in CUDA graph."""
return 1
def get_indexer_metadata(
self, layer_id: int, forward_batch: ForwardBatch
) -> NSAIndexerMetadata:
return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
from flash_mla import get_mla_metadata
flashmla_metadata, num_splits = get_mla_metadata(
cache_seqlens=cache_seqlens,
            # TODO the doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`,
            # but the parameter name suggests it expects seq_len_q?
num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
num_heads_k=1,
num_heads_q=self.num_q_heads,
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
topk=self.nsa_index_topk,
)
return NSAFlashMLAMetadata(
flashmla_metadata=flashmla_metadata,
num_splits=num_splits,
)
# TODO speedup
def _compute_indices_in_kvcache(block_table, topk_indices, page_size, nsa_index_topk):
topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
1
)
block_idx = block_table[idx0, topk_indices_safe // page_size]
offset = topk_indices_safe % page_size
indices_in_kvcache = block_idx * page_size + offset
# the kernel requires invalid entry to be -1
assert indices_in_kvcache.shape == topk_indices.shape
indices_in_kvcache[topk_indices == -1] = -1
# return: (batch_size, seqlen_q_ori, topk)
indices_in_kvcache = indices_in_kvcache[:, None, :]
indices_in_kvcache = torch.nn.functional.pad(
indices_in_kvcache,
(0, nsa_index_topk - indices_in_kvcache.shape[-1]),
"constant",
-1,
)
assert indices_in_kvcache.shape[-1] == nsa_index_topk
return indices_in_kvcache
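# Illustrative worked example (ours, with assumed values) of the index translation above:
# logical token index -> block id from the block table * page_size + in-page offset,
# with -1 preserved for invalid entries and the result padded up to nsa_index_topk.
def _example_compute_indices_in_kvcache():
    block_table = torch.tensor([[7, 3]])        # one request, two pages of size 64
    topk_indices = torch.tensor([[0, 65, -1]])  # token 0 (page 0), token 65 (page 1), padding
    # expected: tensor([[[448, 193, -1, -1]]])  i.e. 7*64+0, 3*64+1, then -1 padding to topk=4
    return _compute_indices_in_kvcache(block_table, topk_indices, page_size=64, nsa_index_topk=4)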