Unverified Commit efbc687c authored by fzyzcjy, committed by GitHub
parent 292a867a
......@@ -49,6 +49,30 @@ class ModelImpl(str, Enum):
TRANSFORMERS = "transformers"
def is_deepseek_nsa(config: PretrainedConfig) -> bool:
return (
config.architectures is not None
and config.architectures[0]
in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
and getattr(config, "index_topk", None) is not None
)
def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_head_dim
def get_nsa_index_topk(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_topk
def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_n_heads
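# Illustrative sketch (not part of this commit): how the NSA helpers above behave
# on a hypothetical DeepSeek V3.2-style config that carries the sparse-indexer
# fields. The field values below are made up for illustration.
#
#   from transformers import PretrainedConfig
#
#   cfg = PretrainedConfig(
#       architectures=["DeepseekV32ForCausalLM"],
#       index_topk=2048,
#       index_head_dim=128,
#       index_n_heads=64,
#   )
#   assert is_deepseek_nsa(cfg)
#   assert get_nsa_index_topk(cfg) == 2048
#   assert get_nsa_index_head_dim(cfg) == 128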
class ModelConfig:
def __init__(
self,
......@@ -271,6 +295,7 @@ class ModelConfig:
# FIXME: temporary special judge for MLA architecture
if (
"DeepseekV2ForCausalLM" in self.hf_config.architectures
or "DeepseekV32ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
or "LongcatFlashForCausalLM" in self.hf_config.architectures
......@@ -283,6 +308,11 @@ class ModelConfig:
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
self.index_head_dim = (
get_nsa_index_head_dim(self.hf_config)
if is_deepseek_nsa(self.hf_config)
else None
)
# Handle rope scaling with yarn
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
......
......@@ -2,9 +2,19 @@ import logging
import os
from typing import List, Optional
import torch
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
try:
from mf_adapter import TransferEngine
import_error = None
except ImportError as e:
import_error = e
pass
logger = logging.getLogger(__name__)
......@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
def __init__(
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
):
        if import_error is not None:
            logger.warning(
                "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
            )
            raise import_error
self.engine = TransferEngine()
self.hostname = hostname
......@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
self.initialize()
def initialize(self) -> None:
from sglang.srt.layers.dp_attention import (
get_tensor_model_parallel_world_size,
get_tp_group,
)
transfer_protocol = self._get_transfer_protocol()
if transfer_protocol is None or transfer_protocol == "sdma":
trans_op_type = TransferEngine.TransDataOpType.SDMA
        else:
            # Use device RDMA for PD transfer.
            trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
tmp_tensor = torch.zeros(1, device="npu")
output_tensor_list = [
torch.empty_like(tmp_tensor)
for _ in range(get_tensor_model_parallel_world_size())
]
        # Initialize HCCL in advance via an all_gather to avoid conflicts with RDMA initialization.
torch.distributed.all_gather(
output_tensor_list, tmp_tensor, group=get_tp_group().device_group
)
"""Initialize the ascend transfer instance."""
ret_value = self.engine.initialize(
            self.store_url, self.session_id, self.role, self.npu_id, trans_op_type
)
if ret_value != 0:
logger.error("Ascend Transfer Engine initialization failed.")
......@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
ret_value = -1
if ret_value != 0:
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
@staticmethod
def _get_transfer_protocol():
protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
allowed_protocols = {"device_rdma", "sdma"}
if protocol and protocol.lower() in allowed_protocols:
return protocol.lower()
else:
logger.warning(
"Invalid or no transfer protocol specified, using default protocol."
)
return None
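    # Usage sketch (illustrative, not part of this commit): the transfer protocol
    # is selected through an environment variable before engine initialization,
    # e.g. `export ASCEND_MF_TRANSFER_PROTOCOL=device_rdma`; unset or unrecognized
    # values fall back to the SDMA path in initialize() above.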
......@@ -36,6 +36,8 @@ class ForwardMetadata:
seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_list: Optional[List[int]] = None
seq_lens_list_cumsum: Optional[List[int]] = None
seq_lens: Optional[torch.Tensor] = None
actual_seq_lengths_q: Optional[torch.Tensor] = None
class AscendAttnBackend(AttentionBackend):
......@@ -67,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
if self.use_mla:
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.q_head_dim = (
self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
)
self.native_attn = TorchNativeAttnBackend(model_runner)
self.graph_metadata = {}
self.max_context_len = model_runner.model_config.context_len
......@@ -102,10 +107,6 @@ class AscendAttnBackend(AttentionBackend):
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
if forward_batch.is_extend_in_batch:
seq_lens_list_cumsum[-1] = (
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
) * tp_size
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
self.graph_mode = False
......@@ -133,6 +134,10 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
metadata.seq_lens = seq_lens
        metadata.actual_seq_lengths_q = torch.arange(
            1, bs + 1, dtype=torch.int32, device=seq_lens.device
        )
self.graph_metadata[bs] = metadata
self.forward_metadata = metadata
......@@ -161,6 +166,8 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
metadata.block_tables[bs:, :].fill_(0)
metadata.seq_lens[:bs].copy_(seq_lens[:bs])
self.forward_metadata = metadata
self.graph_mode = True
......@@ -168,6 +175,64 @@ class AscendAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_sparse(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi_head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
):
is_prefill = forward_batch.forward_mode.is_extend()
if save_kv_cache:
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
q_nope, q_pe = q, q_rope
k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
block_table = self.forward_metadata.block_tables
if is_prefill:
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
else:
if self.forward_metadata.actual_seq_lengths_q is None:
actual_seq_qlen = (
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
)
else:
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_lengths_kv = self.forward_metadata.seq_lens
else:
actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
attn_out = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
query_rope=q_pe,
key_rope=k_pe,
sparse_indices=topk_indices,
scale_value=layer.scaling,
actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
block_table=block_table,
sparse_block_size=1,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
return attn_out
def forward_extend(
self,
q,
......@@ -176,7 +241,23 @@ class AscendAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi_head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
......@@ -437,10 +518,23 @@ class AscendAttnBackend(AttentionBackend):
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
if is_mla_preprocess_enabled():
            # MLAPO already saves the kv_cache.
save_kv_cache = False
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if self.graph_mode:
return self.forward_decode_graph(
......
......@@ -66,6 +66,13 @@ def create_ascend_backend(runner):
return AscendAttnBackend(runner)
@register_attention_backend("nsa")
def create_nsa_backend(runner):
from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
return NativeSparseAttnBackend(runner)
@register_attention_backend("triton")
def create_triton_backend(runner):
assert not runner.model_config.is_encoder_decoder, (
......
......@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
if TYPE_CHECKING:
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInput
......@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
def support_triton(self):
"""Check if the current backend supports triton."""
return True
def get_indexer_metadata(
self,
layer_id: int,
forward_batch: ForwardBatch,
) -> Optional[BaseIndexerMetadata]:
"""Get the indexer metadata. None means don't support indexer."""
return None
......@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
return backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
def get_indexer_metadata(
self, layer_id: int, forward_batch: ForwardBatch
) -> Optional[BaseIndexerMetadata]:
backend = self._select_backend(forward_batch.forward_mode)
return backend.get_indexer_metadata(layer_id, forward_batch)
......@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
self.rotary_emb = rotary_emb
self.layer_id = layer_id
self.has_preprocess_weights = False
self.dtype = None
self.q_lora_rank = self.q_b_proj.input_size # 1536
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
self.num_local_heads = num_local_heads # tp
self.qk_nope_head_dim = qk_nope_head_dim # 128
self.qk_rope_head_dim = qk_rope_head_dim # 64
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
def preprocess_weights(self, hidden_states):
self.dummy = torch.empty(
......@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
return k_cache, v_cache, slot_mapping
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
def forward_absorb_prepare_npu_rms_norm_cache(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch,
zero_allocator,
):
bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
self.dtype = hidden_states.dtype
self.cos, self.sin = self.get_sin_cos(positions)
self.kvCache, self.kvCacheRope, self.slotmapping = (
self.get_kv_cache_and_cache_idx(forward_batch)
)
if not self.has_preprocess_weights:
self.has_preprocess_weights = True
cos, sin = self.cos, self.sin
if self.q_lora_rank is not None:
fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q = self.q_a_layernorm(q_lowrank)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
) # b*s,n,d
q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D)
q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
latent_cache = latent_cache.view(
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
) # (B*S,N,1,D)
cache_mode = "PA_BNSD"
self.kvCache = self.kvCache.view(
-1,
forward_batch.attn_backend.page_size,
1,
forward_batch.attn_backend.kv_lora_rank,
)
self.kvCacheRope = self.kvCacheRope.view(
-1,
forward_batch.attn_backend.page_size,
1,
forward_batch.attn_backend.qk_rope_head_dim,
)
k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
latent_cache,
self.kv_a_layernorm.weight,
cos,
sin,
self.slotmapping.to(torch.int64),
self.kvCacheRope,
self.kvCache,
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
input_dtype = hidden_states.dtype
if not self.has_preprocess_weights:
self.preprocess_weights(hidden_states)
......@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
zero_allocator,
positions,
)
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
_is_w8a8 = (
hasattr(self.qkv_a_proj.quant_method, "quantization_config")
and self.qkv_a_proj.quant_method.quantization_config.get_name()
== "w8a8_int8"
)
if _is_w8a8:
return self.forward_mlapo(
positions, hidden_states, forward_batch, zero_allocator
)
else:
return self.forward_absorb_prepare_npu_rms_norm_cache(
positions, hidden_states, forward_batch, zero_allocator
)
import torch
import triton
import triton.language as tl
from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST
def dequantize_k_cache(quant_k_cache):
if NSA_DEQUANT_K_CACHE_FAST:
return _dequantize_k_cache_fast_wrapped(quant_k_cache)
else:
return _dequantize_k_cache_slow(quant_k_cache)
def _dequantize_k_cache_slow(
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
dv: int = 512,
tile_size: int = 128,
d: int = 576,
) -> torch.Tensor:
"""
De-quantize the k-cache
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1
result = torch.empty(
(num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
)
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_nope * cur_scales
)
result = result.view(num_blocks, block_size, 1, d)
return result
def _dequantize_k_cache_fast_wrapped(
quant_k_cache: torch.Tensor,
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
# TODO the final API may be 2D instead of 4D, thus we convert them here
num_blocks, block_size, _, dim_quant = quant_k_cache.shape
assert dv == 512
assert dim_quant == 656
assert tile_size == 128
quant_k_cache = quant_k_cache.view((-1, dim_quant))
output = _dequantize_k_cache_fast(quant_k_cache)
return output.view(num_blocks, block_size, 1, -1)
def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
num_tokens, dim_quant = quant_k_cache.shape
assert quant_k_cache.dtype == torch.float8_e4m3fn
dim_nope = 512
dim_rope = 64
num_tiles = dim_nope // group_size
assert dim_quant == 656
output = torch.empty(
(num_tokens, dim_nope + dim_rope),
dtype=torch.bfloat16,
device=quant_k_cache.device,
)
num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
assert num_blocks_per_token == 5
assert dim_nope % group_size == 0
NUM_NOPE_BLOCKS = dim_nope // group_size
input_nope_q = quant_k_cache[:, :dim_nope]
input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
torch.float32
)
input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)
_dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
output,
input_nope_q,
input_nope_s,
input_rope,
output.stride(0),
input_nope_q.stride(0),
input_nope_s.stride(0),
input_rope.stride(0),
NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
GROUP_SIZE=group_size,
DIM_NOPE=dim_nope,
DIM_ROPE=dim_rope,
)
return output
@triton.jit
def _dequantize_k_cache_fast_kernel(
output_ptr,
input_nope_q_ptr,
input_nope_s_ptr,
input_rope_ptr,
output_stride_0: int,
input_nope_q_stride_0: int,
input_nope_s_stride_0: int,
input_rope_stride_0: int,
NUM_NOPE_BLOCKS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
DIM_NOPE: tl.constexpr,
DIM_ROPE: tl.constexpr,
):
token_id = tl.program_id(0)
raw_block_id = tl.program_id(1)
if raw_block_id < NUM_NOPE_BLOCKS:
# a. dequant nope
effective_block_id = raw_block_id
offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs_q < DIM_NOPE
ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q
ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id
y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
y_s = tl.load(ptr_s)
y = (y_q * y_s).to(output_ptr.dtype.element_ty)
dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
tl.store(dst_ptr, y, mask=mask)
else:
# b. copy rope
effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_ROPE
src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs
dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs
data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
tl.store(dst_ptr, data, mask=mask)
if __name__ == "__main__":
raise Exception("UT is in quant_k_cache.py")
from typing import TYPE_CHECKING
import torch
import triton
import triton.language as tl
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
"""
k: data, 128 items per token, fp8
s: scale, 1 item per token, fp32
"""
class GetK:
@classmethod
def execute(cls, *args, **kwargs):
return cls.torch_fast(*args, **kwargs)
@classmethod
def slow(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
num_pages = (seq_len + pool.page_size - 1) // pool.page_size
seq_len_ = num_pages * pool.page_size
index_k_fp8 = torch.empty(
(seq_len_, pool.index_head_dim),
dtype=torch.uint8,
device=pool.device,
)
for i in range(num_pages):
page_index = page_indices[i]
index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
page_index
][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)
return index_k_fp8[:seq_len]
@classmethod
def torch_fast(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
"""
:param page_indices: (num_pages,), int32
:return: (seq_len, index_head_dim), uint8
"""
        # NOTE: this could be handled per 128-byte chunk instead of per element.
# page_indices: (num_pages,), element := a page index
buf_numel_per_page = buf.shape[1]
num_k_bytes_per_page = pool.page_size * pool.index_head_dim
num_k_bytes_per_token = pool.index_head_dim
# buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
# flat_buf: (whatever,), uint8
flat_buf = buf.flatten()
# flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access
flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
num_k_bytes_per_page, dtype=torch.int32, device="cuda"
)[None, :]
flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]
out = flat_buf[flat_indices]
return out.view(-1, 128)
class GetS:
@classmethod
def execute(cls, *args, **kwargs):
return cls.torch_fast(*args, **kwargs)
@classmethod
def slow(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
num_pages = (seq_len + pool.page_size - 1) // pool.page_size
seq_len_ = num_pages * pool.page_size
assert pool.index_head_dim // pool.quant_block_size == 1
index_k_scale_fp8 = torch.empty(
(seq_len_, 4),
dtype=torch.uint8,
device=pool.device,
)
for i in range(num_pages):
page_index = page_indices[i]
index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
page_index
][pool.page_size * pool.index_head_dim :].view(-1, 4)
return index_k_scale_fp8[:seq_len]
@classmethod
def torch_fast(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
"""
:param page_indices: (num_pages,), int32
:return: (seq_len, index_head_dim // quant_block_size), uint8
"""
buf_numel_per_page = buf.shape[1]
num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
s_offset_in_page = pool.page_size * pool.index_head_dim
flat_buf = buf.flatten()
flat_indices = (
(page_indices * buf_numel_per_page)[:, None]
+ torch.arange(num_s_bytes_per_page, dtype=torch.int32, device="cuda")[
None, :
]
+ s_offset_in_page
)
flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]
out = flat_buf[flat_indices]
return out.view(-1, 4)
class SetK:
@classmethod
def execute(cls, *args, buf, **kwargs):
return cls.torch_fast(*args, **kwargs, buf=buf)
@classmethod
def slow(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k: torch.Tensor,
):
for i in range(len(loc)):
page_index = loc[i] // pool.page_size
offset = loc[i] % pool.page_size
buf[
page_index,
offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
] = index_k[i].view(torch.uint8)
@classmethod
def torch_fast(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k: torch.Tensor,
):
(num_tokens_to_write,) = loc.shape
buf_numel_per_page = buf.shape[1]
num_k_bytes_per_token = pool.index_head_dim
# loc: (num_tokens_to_write,), int32, element := the token index to write to
loc_page_index = loc // pool.page_size
loc_token_offset_in_page = loc % pool.page_size
flat_buf = buf.flatten()
flat_indices = (
(loc_page_index * buf_numel_per_page)[:, None]
+ (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
+ torch.arange(num_k_bytes_per_token, dtype=torch.int32, device="cuda")[
None, :
]
)
num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
flat_indices = flat_indices.flatten()[:num_k_bytes_total]
flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()
class SetS:
@classmethod
def execute(cls, *args, buf, **kwargs):
return cls.torch_fast(*args, **kwargs, buf=buf)
@classmethod
def slow(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k_scale: torch.Tensor,
):
for i in range(len(loc)):
page_index = loc[i] // pool.page_size
offset = loc[i] % pool.page_size
start = pool.page_size * pool.index_head_dim
buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
index_k_scale[i].view(torch.uint8)
)
@classmethod
def torch_fast(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k_scale: torch.Tensor,
):
(num_tokens_to_write,) = loc.shape
buf_numel_per_page = buf.shape[1]
num_s_bytes_per_token = 4
s_offset_in_page = pool.page_size * pool.index_head_dim
# loc: (num_tokens_to_write,), int32, element := the token index to write to
loc_page_index = loc // pool.page_size
loc_token_offset_in_page = loc % pool.page_size
flat_buf = buf.flatten()
flat_indices = (
(loc_page_index * buf_numel_per_page)[:, None]
+ s_offset_in_page
+ (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
+ torch.arange(num_s_bytes_per_token, dtype=torch.int32, device="cuda")[
None, :
]
)
number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
flat_indices = flat_indices.flatten()[:number_s_bytes_total]
flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()
class SetKAndS:
@classmethod
def execute(cls, *args, buf, **kwargs):
        if False:
# print("SetK, SetS comparison test")
buf_cloned = buf.clone()
cls.vanilla(*args, **kwargs, buf=buf)
cls.triton(*args, **kwargs, buf=buf_cloned)
def _clear_token_0(target):
target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0
_clear_token_0(buf)
_clear_token_0(buf_cloned)
assert torch.all(
buf == buf_cloned
), f"{buf=} {buf_cloned=} {kwargs['loc'].to_list()=}"
return
cls.triton(*args, **kwargs, buf=buf)
@classmethod
def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)
@classmethod
def triton(cls, pool, buf, loc, index_k, index_k_scale):
_set_k_and_s_triton(
buf=buf,
loc=loc,
index_k=index_k,
index_k_scale=index_k_scale,
page_size=pool.page_size,
)
def _set_k_and_s_triton(
buf: torch.Tensor,
loc: torch.Tensor,
index_k: torch.Tensor,
index_k_scale: torch.Tensor,
page_size: int,
):
"""
:param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
:param loc: (num_tokens_to_write,), int, element := the token index to write to
:param index_k: (num_tokens_to_write, 128 elem), fp8
:param index_k_scale: (num_tokens_to_write, 1 elem), fp32
:return:
"""
num_pages, buf_numel_per_page = buf.shape
(num_tokens_to_write,) = loc.shape
num_tokens_to_write_, index_head_dim = index_k.shape
num_tokens_to_write__, scale_dim = index_k_scale.shape
assert buf_numel_per_page == 64 * (128 + 4)
assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
assert index_head_dim == 128
assert scale_dim == 1
assert page_size == 64
assert buf.dtype == torch.uint8
assert loc.dtype == torch.int64, f"{loc.dtype=}" # can be int32
assert index_k.dtype == torch.float8_e4m3fn
assert index_k_scale.dtype == torch.float32
assert buf.is_contiguous()
assert loc.is_contiguous()
assert index_k.is_contiguous()
assert index_k_scale.is_contiguous()
buf_fp8 = buf.view(torch.float8_e4m3fn)
buf_fp32 = buf.view(torch.float32)
_set_k_and_s_triton_kernel[(num_tokens_to_write,)](
buf_fp8,
buf_fp32,
loc,
index_k,
index_k_scale,
index_k.stride(0),
PAGE_SIZE=page_size,
BUF_NUMEL_PER_PAGE=buf_numel_per_page,
NUM_K_ELEMS_PER_TOKEN=index_head_dim,
S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
)
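# Usage sketch (illustrative, not part of this commit): writing two tokens into a
# freshly allocated index-k buffer via the Triton path. Shapes and dtypes follow
# the asserts in _set_k_and_s_triton; requires a CUDA device with Triton available.
#
#   num_pages, page_size = 4, 64
#   buf = torch.zeros((num_pages, page_size * (128 + 4)), dtype=torch.uint8, device="cuda")
#   loc = torch.tensor([3, 70], dtype=torch.int64, device="cuda")           # target token slots
#   index_k = torch.randn(2, 128, device="cuda").to(torch.float8_e4m3fn)    # fp8 index-k rows
#   index_k_scale = torch.rand(2, 1, dtype=torch.float32, device="cuda")    # fp32 scales
#   _set_k_and_s_triton(buf=buf, loc=loc, index_k=index_k,
#                       index_k_scale=index_k_scale, page_size=page_size)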
@triton.jit
def _set_k_and_s_triton_kernel(
buf_fp8_ptr,
buf_fp32_ptr,
loc_ptr,
index_k_ptr,
index_k_scale_ptr,
index_k_ptr_stride_0,
PAGE_SIZE: tl.constexpr,
BUF_NUMEL_PER_PAGE: tl.constexpr,
NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
):
token_id = tl.program_id(0)
loc = tl.load(loc_ptr + token_id)
in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
# no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2
k = tl.load(index_k_ptr + in_k_offsets)
k_scale = tl.load(index_k_scale_ptr + token_id)
loc_page_index = loc // PAGE_SIZE
loc_token_offset_in_page = loc % PAGE_SIZE
out_k_offsets = (
loc_page_index * BUF_NUMEL_PER_PAGE
+ loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
+ tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
)
# "//4" b/c it is fp32 instead of uint8
out_s_offset = (
loc_page_index * BUF_NUMEL_PER_PAGE // 4
+ S_OFFSET_NBYTES_IN_PAGE // 4
+ loc_token_offset_in_page
)
tl.store(buf_fp8_ptr + out_k_offsets, k)
tl.store(buf_fp32_ptr + out_s_offset, k_scale)
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
if is_cuda():
import deep_gemm
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0
class BaseIndexerMetadata(ABC):
@abstractmethod
def get_seqlens_int32(self) -> torch.Tensor:
"""
Return: (batch_size,) int32 tensor
"""
@abstractmethod
def get_page_table_64(self) -> torch.Tensor:
"""
Return: (batch_size, num_blocks) int32, page table.
The page size of the table is 64.
"""
@abstractmethod
def get_seqlens_expanded(self) -> torch.Tensor:
"""
Return: (sum_extend_seq_len,) int32 tensor
"""
@abstractmethod
def topk_transform(
self,
logits: torch.Tensor,
topk: int,
) -> torch.Tensor:
"""
        Perform top-k selection on the logits and possibly transform the result.

        Note that an attention backend may override this method, so the result
        is not necessarily the top-k indices of the input logits.

        Return: any value; it is passed on to the attention backend for further
        sparse-attention processing. Do not assume it contains the top-k indices
        of the input logits.
        """
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.bfloat16
from fast_hadamard_transform import hadamard_transform
hidden_size = x.size(-1)
assert (
hidden_size & (hidden_size - 1)
) == 0, "Hidden size must be a power of 2 for Hadamard transform."
return hadamard_transform(x, scale=hidden_size**-0.5)
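# Note: with scale = hidden_size ** -0.5 the Hadamard transform is orthonormal
# (H @ H.T == n * I), so rotate_activation preserves vector norms; the
# power-of-two check above matches the standard Sylvester construction of H.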
class V32LayerNorm(nn.Module):
"""
Layer Normalization.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor):
return F.layer_norm(
x.float(), (self.dim,), self.weight, self.bias, self.eps
).type_as(x)
class Indexer(CustomOp):
def __init__(
self,
hidden_size: int,
index_n_heads: int,
index_head_dim: int,
rope_head_dim: int,
index_topk: int,
q_lora_rank: int,
max_position_embeddings: int,
rope_theta: float,
layer_id: int,
scale_fmt: Optional[str],
block_size: int = 128,
rope_scaling: Optional[Dict[str, Any]] = None,
prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
self.hidden_size = hidden_size
self.n_heads = index_n_heads
self.head_dim = index_head_dim
self.rope_head_dim = rope_head_dim
self.index_topk = index_topk
self.q_lora_rank = q_lora_rank
self.layer_id = layer_id
self.alt_stream = alt_stream
if is_cuda():
self.sm_count = deep_gemm.get_num_sms()
self.half_device_sm_count = align(self.sm_count // 2, 8)
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("wq_b", prefix),
)
self.wk = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("wk", prefix),
)
self.k_norm = V32LayerNorm(self.head_dim)
# NOTE: weight_proj is not quantized
self.weights_proj = ReplicatedLinear(
self.hidden_size,
self.n_heads,
bias=False,
prefix=add_prefix("weights_proj", prefix),
)
self.rotary_emb = get_rope_wrapper(
rope_head_dim,
rotary_dim=rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta, # type: ignore
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
)
self.block_size = block_size
self.scale_fmt = scale_fmt
self.softmax_scale = self.head_dim**-0.5
def _forward_fake(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
):
bs = x.shape[0]
assert self.index_topk == 2048
ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
None, ...
].repeat(bs, 1)
if forward_batch.forward_mode.is_extend():
assert (
forward_batch.extend_seq_lens_cpu is not None
and forward_batch.seq_lens_cpu is not None
)
which = 0
for i, (kv_len, qo_len) in enumerate(
zip(
forward_batch.seq_lens_cpu.tolist(),
forward_batch.extend_seq_lens_cpu,
strict=True,
)
):
for j in range(kv_len - qo_len, kv_len):
ans[which, j + 1 :] = -1
which += 1
assert which == ans.shape[0]
else:
assert forward_batch.seq_lens_cpu is not None
for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
ans[i, seq_len:] = -1
return ans
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x)
weights = weights * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
def _get_q_k_bf16(
self,
q_lora: torch.Tensor,
x: torch.Tensor,
positions: torch.Tensor,
enable_dual_stream: bool,
):
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
self.half_device_sm_count
):
query, _ = self.wq_b(q_lora)
query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
q_rope, _ = torch.split(
query,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
)
with torch.cuda.stream(self.alt_stream):
                # TODO: should we also apply the DeepGEMM half-SM configuration here?
key, _ = self.wk(x)
key = self.k_norm(key)
k_rope, _ = torch.split(
key,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
)
current_stream.wait_stream(self.alt_stream)
else:
query, _ = self.wq_b(q_lora)
query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
q_rope, _ = torch.split(
query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
)
key, _ = self.wk(x)
key = self.k_norm(key)
k_rope, _ = torch.split(
key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
)
q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
query[..., : self.rope_head_dim] = q_rope
key[..., : self.rope_head_dim] = k_rope
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
query = rotate_activation(query)
with torch.cuda.stream(self.alt_stream):
key = rotate_activation(key)
current_stream.wait_stream(self.alt_stream)
else:
query = rotate_activation(query)
key = rotate_activation(key)
return query, key
def _get_topk_paged(
self,
forward_batch: ForwardBatch,
layer_id: int,
q_fp8: torch.Tensor,
weights: torch.Tensor,
metadata: BaseIndexerMetadata,
) -> torch.Tensor:
if TYPE_CHECKING:
assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
page_size = forward_batch.token_to_kv_pool.page_size
# NOTE(dark): blocksize = 64 is hardcoded in deep_gemm
assert page_size == 64, "only support page size 64"
# NOTE(dark): this support extend/decode/decode+graph
block_tables = metadata.get_page_table_64()
max_seq_len = block_tables.shape[1] * page_size
kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
layer_id=layer_id
)
blocksize = page_size
seqlens_32 = metadata.get_seqlens_int32()
# NOTE(dark): 132 is SM count on H200/B200, not magic number
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
seqlens_32, blocksize, self.sm_count
)
assert len(q_fp8.shape) == 3
q_fp8 = q_fp8.unsqueeze(1) # the next_n dim is 1 now
assert len(kv_cache_fp8.shape) == 2
block_kv = 64
num_heads_kv = 1
head_dim_with_sf = 132
kv_cache_fp8 = kv_cache_fp8.view(
kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
)
assert len(weights.shape) == 3
weights = weights.squeeze(2)
logits = deep_gemm.fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
weights,
seqlens_32,
block_tables,
schedule_metadata,
max_seq_len,
clean_logits=False,
)
# NOTE(dark): logits should be cleaned in topk_transform
topk_result = metadata.topk_transform(logits, self.index_topk)
return topk_result
def _get_topk_ragged(
self,
forward_batch: ForwardBatch,
layer_id: int,
q_fp8: torch.Tensor,
weights: torch.Tensor,
metadata: BaseIndexerMetadata,
) -> torch.Tensor:
if TYPE_CHECKING:
assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
page_size = forward_batch.token_to_kv_pool.page_size
assert page_size == 64, "only support page size 64"
assert len(weights.shape) == 3
weights = weights.squeeze(-1)
k_fp8_list = []
k_scale_list = []
ks_list = []
offset = 0
block_tables = metadata.get_page_table_64()
assert (
forward_batch.seq_lens_cpu is not None
and forward_batch.extend_seq_lens_cpu is not None
)
for i in range(forward_batch.batch_size):
seq_len = forward_batch.seq_lens_cpu[i].item()
assert isinstance(seq_len, int)
k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
layer_id,
seq_len,
block_tables[i],
)
k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
layer_id,
seq_len,
block_tables[i],
)
extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
k_fp8_list.append(k_fp8)
k_scale_list.append(k_scale)
ks_list.append(ks)
offset += extend_seq_len
k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
kv_fp8 = (k_fp8, k_scale)
ks = torch.cat(ks_list, dim=0)
seq_lens_expanded = metadata.get_seqlens_expanded()
ke = ks + seq_lens_expanded
logits = deep_gemm.fp8_mqa_logits(
q_fp8,
kv_fp8,
weights,
ks,
ke,
clean_logits=False,
)
assert logits.shape[0] == len(seq_lens_expanded)
topk_result = metadata.topk_transform(logits, self.index_topk)
return topk_result
def forward_indexer_bs_1(
self,
q_fp8: torch.Tensor,
weights: torch.Tensor,
forward_batch: ForwardBatch,
topk: int,
layer_id: int,
) -> Optional[torch.Tensor]:
if not is_npu():
from sglang.srt.layers.attention.nsa.tilelang_kernel import fp8_index
page_size = forward_batch.token_to_kv_pool.page_size
assert page_size == 64, "only support page size 64"
assert len(weights.shape) == 3
weights = weights.squeeze(-1)
# logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
k_fp8_list = []
k_scale_list = []
topk_indices_list = []
block_tables = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :
]
strided_indices = torch.arange(
0, block_tables.shape[-1], page_size, device="cuda"
)
block_tables = block_tables[:, strided_indices] // page_size
q_len_start = 0
for i in range(forward_batch.batch_size):
seq_len = forward_batch.seq_lens[i].item()
q_len = (
forward_batch.extend_seq_lens_cpu[i]
if forward_batch.forward_mode.is_extend()
else 1
)
q_len_end = q_len_start + q_len
q_fp8_partial = q_fp8[q_len_start:q_len_end]
q_fp8_partial = q_fp8_partial.unsqueeze(0).contiguous()
weights_partial = weights[q_len_start:q_len_end]
weights_partial = weights_partial.squeeze(-1).unsqueeze(0).contiguous()
k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
layer_id,
seq_len,
block_tables[i],
)
k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
layer_id,
seq_len,
block_tables[i],
)
k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous()
k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous()
index_score = fp8_index(
q_fp8_partial,
weights_partial,
k_fp8,
k_scale,
)
end_pos = seq_len
topk_indices = index_score.topk(min(topk, end_pos), dim=-1)[1].squeeze(0)
pad_len = align(topk_indices.shape[-1], 2048) - topk_indices.shape[-1]
topk_indices = torch.nn.functional.pad(
topk_indices, (0, pad_len), "constant", -1
)
topk_indices_list.append(topk_indices)
q_len_start = q_len_end
topk_indices = torch.cat(topk_indices_list, dim=0)
return topk_indices
def forward_indexer(
self,
q_fp8: torch.Tensor,
weights: torch.Tensor,
forward_batch: ForwardBatch,
topk: int,
layer_id: int,
) -> Optional[torch.Tensor]:
return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id)
def _forward(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
) -> Optional[torch.Tensor]:
if not is_npu():
from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
if TYPE_CHECKING:
assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
metadata = forward_batch.attn_backend.get_indexer_metadata(
layer_id, forward_batch
)
enable_dual_stream = (
NSA_DUAL_STREAM
and self.alt_stream is not None
and get_is_capture_mode()
and q_lora.shape[0] > 0
and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
)
        # Skip NSA if the attention backend chooses to skip this batch.
if metadata is None:
return None
if not NSA_USE_REAL_INDEXER: # temporary
return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
with torch.cuda.stream(self.alt_stream):
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
current_stream.wait_stream(self.alt_stream)
else:
q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
# k_fp8: (seq_len, head_dim) fp8_e4m3fn
# k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
        # k_scale: (seq_len, head_dim // block_size = 1) float32
        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) float32
forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
layer_id=layer_id,
loc=forward_batch.out_cache_loc,
index_k=k_fp8,
index_k_scale=k_scale,
)
weights = self._get_logits_head_gate(x, q_scale)
if is_cuda():
assert forward_batch.seq_lens_cpu is not None
if len(forward_batch.seq_lens_cpu) == 0:
                # This can happen because of max padding; return an all-invalid topk_result.
# if x.shape[0] != 0:
# print(
# "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
# )
return torch.full(
(x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
)
if forward_batch.forward_mode.is_decode_or_idle():
topk_result = self._get_topk_paged(
forward_batch, layer_id, q_fp8, weights, metadata
)
else:
topk_result = self._get_topk_ragged(
forward_batch, layer_id, q_fp8, weights, metadata
)
else:
topk_result = self.forward_indexer(
q_fp8.contiguous(),
weights,
forward_batch,
topk=self.index_topk,
layer_id=layer_id,
)
return topk_result
def forward_cuda(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
) -> Optional[torch.Tensor]:
return self._forward(x, q_lora, positions, forward_batch, layer_id)
def forward_npu(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
) -> torch.Tensor:
import custom_ops
import torch_npu
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.utils import get_bool_env_var
if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
else:
actual_seq_lengths_kv = (
forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
)
enable_index_cp = (
get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
)
is_prefill = forward_batch.forward_mode.is_extend()
attention_tp_rank = get_attention_tp_rank()
attention_tp_size = get_attention_tp_size()
cos_sin = self.rotary_emb.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
if is_prefill and enable_index_cp:
slice_length = cos.shape[0] // attention_tp_size
cos = cos[
slice_length
* attention_tp_rank : slice_length
* (attention_tp_rank + 1)
]
sin = sin[
slice_length
* attention_tp_rank : slice_length
* (attention_tp_rank + 1)
]
slot_mapping = forward_batch.out_cache_loc
block_table = forward_batch.attn_backend.forward_metadata.block_tables
bs = x.shape[0]
q = self.wq_b(q_lora)[0] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128]
q_pe, q_nope = torch.split(
q,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
) # [bs, 64, 64 + 64]
q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
bs, self.n_heads, self.rope_head_dim
) # [bs, n, d]
q = torch.cat([q_pe, q_nope], dim=-1)
k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128]
k = self.k_norm(k_proj)
k_pe, k_nope = torch.split(
k,
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
dim=-1,
) # [bs, 64 + 64]
k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
bs, 1, self.rope_head_dim
) # [bs, 1, d]
k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128]
if is_prefill and enable_index_cp:
k, local_k = (
torch.empty(
(k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
dtype=k.dtype,
device=k.device,
),
k,
)
get_attention_tp_group().all_gather_into_tensor(k, local_k)
forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)
indexer_input = {}
if is_prefill:
actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
device=q.device
)
if enable_index_cp:
actual_seq_lengths_q -= bs * attention_tp_rank
actual_seq_lengths_q = torch.max(
actual_seq_lengths_q,
torch.zeros_like(actual_seq_lengths_q).to(
device=actual_seq_lengths_q.device
),
)
actual_seq_lengths_q = torch.min(
actual_seq_lengths_q,
torch.full(actual_seq_lengths_q.shape, bs).to(
device=actual_seq_lengths_q.device
),
)
else:
if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
                actual_seq_lengths_q = torch.arange(
                    1, bs + 1, dtype=torch.int32, device=k.device
                )
else:
actual_seq_lengths_q = (
forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
)
past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
x = x.view(-1, self.hidden_size)
weights = self.weights_proj(x)[0]
block_table = (
block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
)
topk_indices = torch.ops.custom.npu_lightning_indexer(
query=q.view(-1, self.n_heads, self.head_dim),
key=past_key_states,
weights=weights,
actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
block_table=block_table,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=self.index_topk,
sparse_mode=3,
)
if is_prefill and enable_index_cp:
topk_indices, local_topk_indices = (
torch.empty(
(
topk_indices.shape[0] * attention_tp_size,
topk_indices.shape[1],
topk_indices.shape[2],
),
dtype=topk_indices.dtype,
device=topk_indices.device,
),
topk_indices,
)
get_attention_tp_group().all_gather_into_tensor(
topk_indices, local_topk_indices
)
return topk_indices
import torch
import triton
import triton.language as tl
from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST
def quantize_k_cache(cache_k):
# TODO upstream can skip concat([k_nope, k_pe]) since we split them here
if NSA_QUANT_K_CACHE_FAST:
return _quantize_k_cache_fast_wrapped(cache_k)
else:
return _quantize_k_cache_slow(cache_k)
# Copied from original
def _quantize_k_cache_slow(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
"""
Quantize the k-cache
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, d = input_k_cache.shape
assert h_k == 1
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty(
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
dtype=torch.float8_e4m3fn,
device=input_k_cache.device,
)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = (
torch.abs(
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
)
.max(dim=-1)
.values
/ 448.0
) # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (
input_k_cache[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].float()
/ cur_scale_factors_inv.float()
).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_quantized_nope
)
result = result.view(num_blocks, block_size, 1, -1)
return result
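# Worked size check (from the docstring above, using the defaults in this file):
# dv = 512, tile_size = 128, d = 576, bf16 input (t = 2 bytes), so bytes per token
#     = dv + 4 * (dv / tile_size) + t * (d - dv)
#     = 512 + 16 + 128 = 656,
# which matches the `dim_quant == 656` asserts in the fast quant/dequant paths.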
def _quantize_k_cache_fast_wrapped(
input_k_cache: torch.Tensor,
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
# TODO the final API may be 2D instead of 4D, thus we convert them here
num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape
assert dv == 512
assert dim_nope_and_rope == 512 + 64
assert tile_size == 128
input_k_cache = input_k_cache.view((-1, dim_nope_and_rope))
    # TODO: deliberately split into two tensors so that upstream can provide them separately instead of concatenating into one
k_nope = input_k_cache[:, :dv]
k_rope = input_k_cache[:, dv:]
output = _quantize_k_cache_fast(k_nope=k_nope, k_rope=k_rope)
return output.view(num_blocks, block_size, 1, -1)
def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):
"""
:param k_nope: (num_tokens, dim_nope 512)
:param k_rope: (num_tokens, dim_rope 64)
"""
assert k_nope.dtype == torch.bfloat16
assert k_rope.dtype == torch.bfloat16
num_tokens, dim_nope = k_nope.shape
num_tokens_, dim_rope = k_rope.shape
assert num_tokens == num_tokens_
assert dim_nope == 512
assert dim_rope == 64
assert k_nope.dtype == k_rope.dtype
num_tiles = dim_nope // group_size
assert k_nope.stride(1) == 1
assert k_rope.stride(1) == 1
output = torch.empty(
(num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),
dtype=torch.float8_e4m3fn,
device=k_nope.device,
)
output_nope_q = output[..., :dim_nope]
output_nope_s = output[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
output_rope = output[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16)
num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
assert num_blocks_per_token == 5
assert dim_nope % group_size == 0
NUM_NOPE_BLOCKS = dim_nope // group_size
_quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
output_nope_q,
output_nope_s,
output_rope,
k_nope,
k_rope,
output_nope_q.stride(0),
output_nope_s.stride(0),
output_rope.stride(0),
k_nope.stride(0),
k_rope.stride(0),
NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
GROUP_SIZE=group_size,
DIM_NOPE=dim_nope,
DIM_ROPE=dim_rope,
FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
)
return output
@triton.jit
def _quantize_k_cache_fast_kernel(
output_nope_q_ptr,
output_nope_s_ptr,
output_rope_ptr,
k_nope_ptr,
k_rope_ptr,
output_nope_q_stride_0: int,
output_nope_s_stride_0: int,
output_rope_stride_0: int,
k_nope_stride_0: int,
k_rope_stride_0: int,
NUM_NOPE_BLOCKS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
DIM_NOPE: tl.constexpr,
DIM_ROPE: tl.constexpr,
FP8_MIN: tl.constexpr,
FP8_MAX: tl.constexpr,
):
token_id = tl.program_id(0)
raw_block_id = tl.program_id(1)
if raw_block_id < NUM_NOPE_BLOCKS:
# a. quant nope
effective_block_id = raw_block_id
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_NOPE
ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs
y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)
        # The reference impl does not apply a tl.maximum(..., eps) floor, so we omit it here.
y_s = tl.max(tl.abs(y)) / FP8_MAX
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(
output_nope_q_ptr.dtype.element_ty
)
dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs
dst_s_ptr = (
output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id
)
tl.store(dst_q_ptr, y_q, mask=mask)
tl.store(dst_s_ptr, y_s)
else:
# b. copy rope
effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_ROPE
src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs
dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs
data = tl.load(src_ptr, mask=mask)
tl.store(dst_ptr, data, mask=mask)
if __name__ == "__main__":
for num_blocks, block_size in [
(1, 1),
(10, 64),
]:
dim_nope_and_rope = 512 + 64
input_k_cache = torch.randn(
(num_blocks, block_size, 1, dim_nope_and_rope),
dtype=torch.bfloat16,
device="cuda",
)
# temp debug
# input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)
ref_quant = _quantize_k_cache_slow(input_k_cache)
actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
# print(f"{input_k_cache=}")
# print(f"{ref_quant=}")
# print(f"{actual_quant=}")
# print(f"{ref_quant == actual_quant=}")
# print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
# print(f"{ref_quant.view(torch.bfloat16)=}")
# print(f"{actual_quant.view(torch.bfloat16)=}")
# assert torch.all(ref_quant == actual_quant)
import dequant_k_cache
ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(
actual_quant
)
print(f"{ref_ref_dequant=}")
print(f"{actual_actual_dequant=}")
print(f"{actual_actual_dequant - ref_ref_dequant=}")
print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}")
# TODO too different?
torch.testing.assert_close(
ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2
)
torch.testing.assert_close(
ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
)
print("Passed")
from typing import Optional, Tuple
import tilelang
import tilelang.language as T
import torch
from sglang.srt.utils import is_hip
tilelang.set_log_level("WARNING")
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
}
BF16 = "bfloat16"
FP8 = "float8_e4m3"
FP32 = "float32"
_is_hip = is_hip()
def fast_log2_ceil(x):
bits_x = T.reinterpret("uint32", x)
exp_x = (bits_x >> 23) & 0xFF
man_bits = bits_x & ((1 << 23) - 1)
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
bits_x = (x + 127) << 23
return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
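# Sketch of the bit tricks above (valid for positive, finite float32 inputs):
# fast_log2_ceil reads the IEEE-754 exponent field (bias 127) and adds 1 when any
# mantissa bit is set, i.e. ceil(log2(x)); fast_pow2 builds 2**x by writing x + 127
# into the exponent field; fast_round_scale therefore rounds amax / fp8_max up to
# the next power of two, e.g. amax = 100, fp8_max = 448:
#     100 / 448 ~= 0.223  ->  2 ** ceil(log2(0.223)) = 2 ** -2 = 0.25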
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
):
M = T.symbolic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
num_stages = 0 if round_scale else 2
blk_m = 32
group_size = 128
@T.prim_func
def act_quant_kernel_(
X: T.Tensor[(M, N), in_dtype],
Y: T.Tensor[(M, N), out_dtype],
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
pid_m,
pid_n,
):
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
s_local = T.alloc_fragment((blk_m,), scale_dtype)
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1)
for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4)
if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
else:
s_local[i] = amax_local[i] * fp8_max_inv
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
)
for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = s_local[i]
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), "Input tensor must be contiguous"
assert (
x.size(-1) % block_size == 0
), f"Last dimension size must be divisible by block_size (block_size={block_size})"
N = x.size(-1)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
return y, s
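# Minimal usage sketch for act_quant (illustrative only and never called here; assumes
# a CUDA device on which tilelang can JIT-compile the kernel above, and the shapes are
# made up for the example):
def _example_act_quant() -> None:
    x = torch.randn(64, 1024, dtype=torch.bfloat16, device="cuda")
    y, s = act_quant(x, block_size=128)
    # y mirrors x in float8_e4m3fn; s holds one float32 scale per 128-wide block
    assert y.shape == (64, 1024) and y.dtype == torch.float8_e4m3fn
    assert s.shape == (64, 1024 // 128) and s.dtype == torch.float32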
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int, clear_accum=True):
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
blk_n1 = 512
blk_n2 = 128
@T.prim_func
def fp8_index_kernel_(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
q_smem = T.alloc_shared((h, d), FP8)
T.copy(q[i_b, i_m, 0, 0], q_smem)
q_s_frag = T.alloc_fragment(h, FP32)
T.copy(q_s[i_b, i_m, 0], q_s_frag)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
logits = T.alloc_fragment((blk_n2, h), FP32)
T.gemm(
k_smem,
q_smem,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=clear_accum,
)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
logits_sum = T.alloc_fragment(blk_n2, FP32)
T.reduce_sum(logits, logits_sum, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
return fp8_index_kernel_
def fp8_index(
q: torch.Tensor,
q_s: torch.Tensor,
k: torch.Tensor,
k_s: torch.Tensor,
) -> torch.Tensor:
"""
Compute the index score using FP8 precision.
Args:
q (torch.Tensor): The Q tensor, must be contiguous.
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
k (torch.Tensor): The K tensor, must be contiguous.
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
Computation:
fp8 q @ fp8 k -> fp32 logits
relu(fp32 logits) * q_s (weights) -> fp32 logits
sum over heads: fp32 logits -> fp32 logits_sum
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
"""
if _is_hip:
return fp8_index_kernel(q.shape[2], q.shape[3], False)(q, q_s, k, k_s)
else:
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
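# A plain-PyTorch reference of the pipeline documented in fp8_index (a slow sketch for
# validation intuition only, not the kernel implementation; the helper name
# `_fp8_index_ref` is hypothetical). It upcasts to fp32 since matmul is not defined on
# fp8 tensors.
def _fp8_index_ref(
    q: torch.Tensor, q_s: torch.Tensor, k: torch.Tensor, k_s: torch.Tensor
) -> torch.Tensor:
    # q: (b, m, h, d) fp8, q_s: (b, m, h) fp32, k: (b, n, d) fp8, k_s: (b, n) fp32
    logits = torch.einsum("bmhd,bnd->bmhn", q.to(torch.float32), k.to(torch.float32))
    logits = torch.relu(logits) * q_s.unsqueeze(-1)  # weight each head by its q scale
    return logits.sum(dim=2) * k_s.unsqueeze(1)  # (b, m, n) index scores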
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def sparse_attention_fwd_kernel_v1(
num_heads,
dim,
tail_dim,
topk,
*,
kv_group=1,
sm_scale=None,
is_causal=True,
block_I=64,
num_stages=2,
threads=256,
):
assert dim == tilelang.math.next_power_of_2(
dim
), f"haven't checked padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(
tail_dim
), f"haven't checked padding correctness yet, tail_dim={tail_dim}"
assert is_causal, "non-causal attention is not supported"
assert (
topk % block_I == 0
), "otherwise trailing index slots would default to 0 and load the wrong KV"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
else:
sm_scale = sm_scale * 1.44269504 # log2(e)
batch = T.symbolic("batch")
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
head_kv = num_heads // kv_group
q_shape = [batch, seq_len, num_heads, dim + tail_dim]
kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
o_shape = [batch, seq_len, num_heads, dim]
indices_shape = [batch, seq_len, kv_group, topk]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
assert kv_group == 1
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
):
with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
bx,
by,
bz,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
O_shared = T.alloc_shared([H_per_block, D], dtype)
mask = T.alloc_fragment([BI], "bool")
acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(acc_o, 0)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
b_i, g_i = by, bz
s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[
b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i
]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[
b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i
]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(
mask[bi_i], 0, -T.infinity(acc_s.dtype)
)
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
)
T.reduce_sum(acc_s, sumexp_i, dim=1)  # is this an accumulating operator?
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
# Rescale
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o, O_shared)
T.copy(acc_o, Output[b_i, s_i, H0:H1, :])
return main
@tilelang.jit(
out_idx=[-1],
compile_flags=[
"-O3",
"-Wno-deprecated-declarations",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--ptxas-options=-v,--register-usage-level=10",
"-DNDEBUG",
],
) # type: ignore
def sparse_attention_fwd_kernel_v2(
num_heads: int,
dim: int,
tail_dim: int,
topk: int,
*,
kv_group: int = 1,
sm_scale: Optional[float] = None,
block_I: int = 64,
):
assert dim == tilelang.math.next_power_of_2(
dim
), f"haven't checked padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(
tail_dim
), f"haven't checked padding correctness yet, tail_dim={tail_dim}"
assert (
topk % block_I == 0
), "otherwise trailing index slots would default to 0 and load the wrong KV"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
else:
sm_scale = sm_scale * 1.44269504 # log2(e)
threads = 384
batch = T.symbolic("batch")
qo_len = T.symbolic("seq_len")
num_pages = T.symbolic("num_pages")
q_shape = [batch, qo_len, num_heads, dim + tail_dim]
kv_shape = [batch, num_pages, kv_group, dim + tail_dim]
o_shape = [batch, qo_len, num_heads, dim]
indices_shape = [batch, qo_len, kv_group, topk]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
H = num_heads
padded_H = max(tilelang.math.next_power_of_2(num_heads), 16)
if padded_H != H:
assert kv_group == 1
BI = block_I
NI = tilelang.cdiv(topk, block_I)
assert NI % 2 == 0, "NI should be a multiple of 2"
D = dim
D_tail = tail_dim
if num_heads > 64:
assert num_heads % 64 == 0, "num_heads should be a multiple of 64"
REPLICATE_H = num_heads // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
):
"""
Q: [b, qo_len, H, D + D_tail] (bfloat16)
KV: [b, num_pages, kv_group, D + D_tail] (bfloat16)
Indices: [b, qo_len, kv_group, topk] (int32)
"""
with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz): # type: ignore
Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)
KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)
KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)
KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)
K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)
K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)
O_shared_l = Q_shared_l
O_shared_r = Q_shared_r
is_kv_valid_0 = T.alloc_shared([BI], "bool", scope="shared")
is_kv_valid_1 = T.alloc_shared([BI], "bool", scope="shared")
acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared")
alpha_local = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
indices_local = T.alloc_local([1], indices_dtype)
indices_tmp = T.alloc_local([1], indices_dtype)
bar_q = T.alloc_barrier(arrive_count=384)
bar_k_0_ready = T.alloc_barrier(arrive_count=128)
bar_k_1_ready = T.alloc_barrier(arrive_count=128)
bar_k_0_free = T.alloc_barrier(arrive_count=256)
bar_k_1_free = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
bar_0_128 = T.alloc_barrier(arrive_count=128)
bar_1_128 = T.alloc_barrier(arrive_count=128)
bar_2_128 = T.alloc_barrier(arrive_count=128)
bar_final = T.alloc_barrier(arrive_count=128)
b_i, g_i = by, bz
s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
tx = T.get_thread_binding()
T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
T.barrier_arrive(bar_q)
if tx < 128:
T.set_max_nreg(240, 1)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
T.fill(acc_o_l, 0)
T.barrier_wait(bar_q, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
# with sync_at(bar_0_128, 0):
T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
T.barrier_arrive(bar_0_128)
T.barrier_wait(bar_0_128, 0)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(
is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype)
)
T.gemm(
Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1
)
T.gemm(
Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1
)
T.gemm(
Q_tail_shared,
K_tail_shared_0,
acc_s,
transpose_B=True,
wg_wait=-1,
)
T.wait_wgmma(0)
if i_i != 0:
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
)
T.reduce_sum(
acc_s, sumexp_i, dim=1
)  # is this an accumulating operator?
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_0_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_0_free[0])
# Buffer 1
T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
T.barrier_arrive(bar_0_128)
T.barrier_wait(bar_0_128, 1)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(
is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype)
)
T.gemm(
Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1
)
T.gemm(
Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1
)
T.gemm(
Q_tail_shared,
K_tail_shared_1,
acc_s,
transpose_B=True,
wg_wait=-1,
)
T.wait_wgmma(0)
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
)
T.reduce_sum(
acc_s, sumexp_i, dim=1
)  # is this an accumulating operator?
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_1_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_1_free[0])
# Rescale
for h_i in T.Parallel(H_per_block):
sum_exp_shared[h_i] = sumexp[h_i]
T.barrier_arrive(bar_final)
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_l[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])
elif tx >= 128 and tx < 256:
# T.set_max_nreg(168, 1)
T.fill(acc_o_r, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
T.barrier_arrive(bar_1_128)
T.barrier_wait(bar_1_128, 0)
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_0_r, acc_o_r)
T.barrier_arrive(bar_k_0_free[0])
T.barrier_arrive(bar_sScale_and_sS_free)
# Buffer 1
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
T.barrier_arrive(bar_1_128)
T.barrier_wait(bar_1_128, 1)
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_1_r, acc_o_r)
T.barrier_arrive(bar_k_1_free[0])
if i_i != T.ceildiv(NI, 2) - 1:
T.barrier_arrive(bar_sScale_and_sS_free)
# Rescale
T.barrier_wait(bar_final, 0)
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])
elif tx >= 256:
# producer
T.set_max_nreg(80, 0)
indices_local[0] = 0
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
T.barrier_arrive(bar_2_128)
T.barrier_wait(bar_2_128, 0)
for r in T.serial(4):
indices_tmp[0] = Indices[
b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8
]
is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
if is_kv_valid_0[r * 16 + (tx - 256) // 8]:
indices_local[0] = indices_tmp[0]
with T.attr("default", "async_scope", 1): # type: ignore
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_0_l[
r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 + v,
] = KV[
b_i,
indices_local[0],
g_i,
64 * u + (tx - 256) % 8 * 8 + v,
]
KV_shared_0_r[
r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 + v,
] = KV[
b_i,
indices_local[0],
g_i,
D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
]
with T.attr("default", "async_scope", 1): # type: ignore
for v in T.vectorized(8):
K_tail_shared_0[
r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
] = KV[
b_i,
indices_local[0],
g_i,
D + (tx - 256) % 8 * 8 + v,
]
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
T.barrier_arrive(bar_2_128)
T.barrier_wait(bar_2_128, 1)
for r in T.serial(4):
indices_tmp[0] = Indices[
b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8
]
is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
if is_kv_valid_1[r * 16 + (tx - 256) // 8]:
indices_local[0] = indices_tmp[0]
with T.attr("default", "async_scope", 1): # type: ignore
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_1_l[
r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 + v,
] = KV[
b_i,
indices_local[0],
g_i,
64 * u + (tx - 256) % 8 * 8 + v,
]
KV_shared_1_r[
r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 + v,
] = KV[
b_i,
indices_local[0],
g_i,
D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
]
with T.attr("default", "async_scope", 1): # type: ignore
for v in T.vectorized(8):
K_tail_shared_1[
r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
] = KV[
b_i,
indices_local[0],
g_i,
D + (tx - 256) % 8 * 8 + v,
]
T.cp_async_barrier_noinc(bar_k_1_ready[0])
return main
def tilelang_sparse_fwd(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> torch.Tensor:
assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3
num_heads = q.shape[1]
dim = q.shape[2]
tail_dim = dim - d_v
topk = indices.shape[-1]
assert topk == 2048
if _is_hip:
kernel = sparse_attention_fwd_kernel_v1(
num_heads, d_v, tail_dim, topk, sm_scale=sm_scale, num_stages=1
)
else:
kernel = sparse_attention_fwd_kernel_v2(
num_heads, d_v, tail_dim, topk, sm_scale=sm_scale
)
return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0)) # type: ignore
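# Shape sketch for tilelang_sparse_fwd (illustrative, assuming DeepSeek-style MLA dims):
# q is (seq_len, num_heads, 576) bf16, kv is (num_kv_tokens, 1, 576) bf16, and indices
# is (seq_len, 1, 2048) int32 padded with -1 for invalid slots; the wrapper unsqueezes a
# leading batch dimension of 1 before calling the kernel.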
from typing import List, Optional
import torch
import triton
import triton.language as tl
def transform_index_page_table_prefill(**kwargs):
return transform_index_page_table_prefill_ref(**kwargs)
def transform_index_page_table_decode(**kwargs):
return transform_index_page_table_decode_ref(**kwargs)
@triton.jit
def transform_index_page_table_decode_kernel(
page_table_ptr: torch.Tensor,
topk_indices_ptr: torch.Tensor,
result_ptr: torch.Tensor,
page_size: tl.constexpr,
max_seqlen_k: tl.constexpr,
):
TOPK: tl.constexpr = 2048
req_id = tl.program_id(0)
page_table_ptr = page_table_ptr + req_id * max_seqlen_k
topk_indices_ptr = topk_indices_ptr + req_id * TOPK
result_ptr = result_ptr + req_id * TOPK
offset = tl.arange(0, TOPK) # topk should be 2048
loaded_topk_indices = tl.load(topk_indices_ptr + offset)
mask = loaded_topk_indices >= 0
loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
tl.store(result_ptr + offset, -1, mask=~mask)
def transform_index_page_table_decode_fast(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
result: Optional[torch.Tensor] = None,
page_size: int = 1,
) -> torch.Tensor:
"""
Transform the page table according to topk indices for sparse topk attention.
Args:
page_table: [qo_len, max_seqlen_k], the original page table
topk_indices: [qo_len, topk], the topk indices for each query position
Returns:
transformed_page_table: [qo_len, topk], the transformed page table
Out-of-bound (negative) entries in topk_indices are filled with -1 in the result.
"""
assert page_size == 1
assert page_table.shape[0] == topk_indices.shape[0]
assert topk_indices.shape[1] == 2048
qo_len = topk_indices.shape[0]
max_seqlen_k = page_table.shape[1]
if result is None:
result = torch.empty_like(topk_indices, dtype=torch.int32)
# Launch triton kernel
grid = (qo_len,)
transform_index_page_table_decode_kernel[grid](
page_table,
topk_indices,
result,
page_size,
max_seqlen_k=max_seqlen_k,
)
return result
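# Worked toy example of the transform (hypothetical numbers): given a page_table row
# [7, 3, 9, 5, ...] and a topk_indices row [2, 0, -1, ...], the result row is
# [9, 7, -1, ...]: valid top-k entries are gathered through the page table, and
# negative (padding) entries stay -1.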
def transform_index_page_table_prefill_fast(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
extend_lens_cpu: List[int],
page_size: int = 1,
) -> torch.Tensor:
# TODO(baizhou): can be implemented with another triton kernel
assert page_size == 1
result = torch.empty_like(topk_indices, dtype=torch.int32)
assert len(extend_lens_cpu) == page_table.shape[0]
offset = 0
for i, l in enumerate(extend_lens_cpu):
transform_index_page_table_decode_fast(
page_table[i].unsqueeze(0).expand(l, -1),
topk_indices[offset : offset + l],
result=result[offset : offset + l],
)
offset += l
assert offset == topk_indices.shape[0]
return result
def transform_index_page_table_decode_ref(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
result: Optional[torch.Tensor] = None,
page_size: int = 1,
) -> torch.Tensor:
assert page_size == 1
assert page_table.shape[0] == topk_indices.shape[0]
if result is None:
result = torch.empty_like(topk_indices, dtype=torch.int32)
assert result.shape == topk_indices.shape
torch.gather(
page_table,
dim=1,
index=topk_indices.clamp(min=0),
out=result,
)
result[topk_indices < 0] = -1
return result
def transform_index_page_table_prefill_ref(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
extend_lens_cpu: List[int],
page_size: int = 1,
) -> torch.Tensor:
assert page_size == 1
result = torch.empty_like(topk_indices, dtype=torch.int32)
assert len(extend_lens_cpu) == page_table.shape[0]
offset = 0
for i, l in enumerate(extend_lens_cpu):
transform_index_page_table_decode_ref(
page_table[i].unsqueeze(0).expand(l, -1),
topk_indices[offset : offset + l],
result=result[offset : offset + l],
)
offset += l
assert offset == topk_indices.shape[0]
return result
if __name__ == "__main__":
bs, topk, max_seqlen = 10, 2048, 3000
page_table = torch.randint(0, 100, (bs, max_seqlen), device="cuda")
topk_indices = torch.full((bs, topk), -1, device="cuda")
topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1)
ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)
result = transform_index_page_table_decode_fast(page_table, topk_indices)
assert torch.all(result == ref_result)
print("Passed")
# Temporary NSA debugging environment variables
from sglang.srt.utils import get_bool_env_var
NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var(
"SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true"
)
NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "true")
NSA_DEQUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "true")
def print_nsa_bool_env_vars():
msg = ""
for k, v in globals().items():
if k.startswith("NSA_") and isinstance(v, bool):
msg += f"{k}={v} "
print(msg, flush=True)
def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):
return original_seq_lens.clamp(max=nsa_index_topk)
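# Example: with nsa_index_topk=2048, sequence lengths [100, 2048, 5000] become
# [100, 2048, 2048], i.e. each request attends to at most index_topk selected tokens.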
from __future__ import annotations
import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
import torch
from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.layers.attention.nsa.transform_index import (
transform_index_page_table_decode,
transform_index_page_table_prefill,
)
from sglang.srt.layers.attention.nsa.utils import (
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
NSA_FUSE_TOPK,
compute_nsa_seqlens,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_hip
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInput
_is_hip = is_hip()
if _is_hip:
try:
from aiter import (
flash_attn_varlen_func,
mha_batch_prefill_func,
paged_attention_ragged,
)
from aiter.mla import mla_decode_fwd, mla_prefill_fwd
except ImportError:
print(
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)
else:
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass(frozen=True)
class NSAFlashMLAMetadata:
"""Metadata only needed by FlashMLA"""
flashmla_metadata: torch.Tensor
num_splits: torch.Tensor
def slice(self, sli):
return NSAFlashMLAMetadata(
flashmla_metadata=self.flashmla_metadata,
num_splits=self.num_splits[sli],
)
def copy_(self, other: "NSAFlashMLAMetadata"):
self.flashmla_metadata.copy_(other.flashmla_metadata)
self.num_splits.copy_(other.num_splits)
@dataclass(frozen=True)
class NSAMetadata:
page_size: int
# Sequence lengths for the forward batch
cache_seqlens_int32: torch.Tensor
# Maximum sequence length for query
max_seq_len_q: int
# Maximum sequence length for key
max_seq_len_k: int
# Cumulative sequence lengths for query
cu_seqlens_q: torch.Tensor
# Cumulative sequence lengths for key
cu_seqlens_k: torch.Tensor
# Page table, the index of KV Cache Tables/Blocks
# this table is always with page_size = 1
page_table_1: torch.Tensor
# NOTE(dark): This property will be used in:
# 1. dense decode/prefill, we use paged flash attention, need real_page_table
# 2. sparse decode/prefill, indexer need real_page_table to compute the score
real_page_table: torch.Tensor
# NSA metadata (nsa prefill are expanded)
nsa_cache_seqlens_int32: torch.Tensor # this seqlens is clipped to `topk`
nsa_cu_seqlens_q: torch.Tensor # must be arange(0, len(nsa_cu_seqlens_k))
nsa_cu_seqlens_k: torch.Tensor # cumsum of `nsa_cache_seqlens_int32`
nsa_extend_seq_lens_list: List[int]
nsa_seqlens_expanded: torch.Tensor # expanded, unclipped `seqlens`
nsa_max_seqlen_q: Literal[1] = 1 # always 1 for decode, variable for extend
flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
@dataclass(frozen=True)
class NSAIndexerMetadata(BaseIndexerMetadata):
attn_metadata: NSAMetadata
def get_seqlens_int32(self) -> torch.Tensor:
return self.attn_metadata.cache_seqlens_int32
def get_page_table_64(self) -> torch.Tensor:
return self.attn_metadata.real_page_table
def get_seqlens_expanded(self) -> torch.Tensor:
return self.attn_metadata.nsa_seqlens_expanded
def topk_transform(
self,
logits: torch.Tensor,
topk: int,
) -> torch.Tensor:
from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
if not NSA_FUSE_TOPK:
return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)
# NOTE(dark): if fused, we return a transformed page table directly
return fast_topk_transform_fused(
score=logits,
lengths=self.get_seqlens_expanded(),
page_table_size_1=self.attn_metadata.page_table_1,
cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
topk=topk,
)
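# Note on the two topk_transform paths above (as used elsewhere in this backend): with
# NSA_FUSE_TOPK the returned tensor already contains page-table entries and is consumed
# directly as page_table_1, whereas the unfused path returns raw top-k token positions
# that are later mapped through the page table by
# transform_index_page_table_{prefill,decode}.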
def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
assert seqlens.dtype == torch.int32 and seqlens.is_cuda
return torch.nn.functional.pad(
torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)
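# Example: seqlens [3, 5, 2] (int32, CUDA) -> cu_seqlens [0, 3, 8, 10].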
_NSA_IMPL_T: TypeAlias = Literal[
"flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
]
NSA_PREFILL_IMPL: _NSA_IMPL_T
NSA_DECODE_IMPL: _NSA_IMPL_T
class NativeSparseAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.forward_metadata: NSAMetadata
self.device = model_runner.device
assert isinstance(model_runner.page_size, int)
self.real_page_size = model_runner.page_size
self.num_splits = (
1 if model_runner.server_args.enable_deterministic_inference else 0
)
self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
self.nsa_kv_cache_store_fp8 = (
model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8
)
self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
self.max_context_len = model_runner.model_config.context_len
self.num_q_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
assert model_runner.req_to_token_pool is not None
self.req_to_token = model_runner.req_to_token_pool.req_to_token
global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
if _is_hip:
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
def get_device_int32_arange(self, l: int) -> torch.Tensor:
if l > len(self._arange_buf):
next_pow_of_2 = 1 << (l - 1).bit_length()
self._arange_buf = torch.arange(
next_pow_of_2, device=self.device, dtype=torch.int32
)
return self._arange_buf[:l]
def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
page_size = self.real_page_size
if page_size == 1:
return page_table
max_seqlen_k = page_table.shape[1]
strided_indices = torch.arange(
0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
)
return page_table[:, strided_indices] // page_size
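# Worked example for page_size=64 (illustrative): a page_table_1 row [0, 1, ..., 127]
# is sampled at strided positions [0, 64] and divided by the page size, giving block
# ids [0, 1], i.e. one entry per 64-token page.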
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
batch_size = forward_batch.batch_size
device = forward_batch.seq_lens.device
assert (
forward_batch.spec_info is None
), "Spec decoding is not supported for NSA backend now"
cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
assert forward_batch.seq_lens_cpu is not None
max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :max_seqlen_k
]
if forward_batch.forward_mode.is_decode_or_idle():
extend_seq_lens_cpu = [1] * batch_size
max_seqlen_q = 1
cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
seqlens_expanded = cache_seqlens_int32
elif forward_batch.forward_mode.is_extend():
assert (
forward_batch.extend_seq_lens_cpu is not None
and forward_batch.extend_seq_lens is not None
and forward_batch.extend_prefix_lens_cpu is not None
), "All of them must not be None"
extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
assert forward_batch.extend_seq_lens is not None
if any(forward_batch.extend_prefix_lens_cpu):
max_seqlen_q = max(extend_seq_lens_cpu)
cu_seqlens_q = compute_cu_seqlens(
forward_batch.extend_seq_lens.to(torch.int32)
)
else:
max_seqlen_q = max_seqlen_k
cu_seqlens_q = cu_seqlens_k
seqlens_expanded = torch.cat(
[
torch.arange(
kv_len - qo_len + 1,
kv_len + 1,
dtype=torch.int32,
device=device,
)
for qo_len, kv_len in zip(
forward_batch.extend_seq_lens_cpu,
forward_batch.seq_lens_cpu.tolist(),
strict=True,
)
]
)
else:
assert False, f"Unsupported {forward_batch.forward_mode = }"
# 1D expanded seqlens (1D tensors are cheap to compute, so we always compute them)
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
original_seq_lens=seqlens_expanded,
nsa_index_topk=self.nsa_index_topk,
)
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
metadata = NSAMetadata(
page_size=self.real_page_size,
cache_seqlens_int32=cache_seqlens_int32,
max_seq_len_q=max_seqlen_q,
max_seq_len_k=max_seqlen_k,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
page_table_1=page_table,
flashmla_metadata=(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1, # TODO handle MTP which is not 1
)
if NSA_DECODE_IMPL == "flashmla_decode"
else None
),
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
nsa_seqlens_expanded=seqlens_expanded,
nsa_extend_seq_lens_list=extend_seq_lens_cpu,
real_page_table=self._transform_table_1_to_real(page_table),
)
self.forward_metadata = metadata
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
"""Initialize CUDA graph state for the attention backend.
Args:
max_bs (int): Maximum batch size to support in CUDA graphs
This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations.
"""
self.decode_cuda_graph_metadata: Dict = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.arange(
0, max_bs + 1, dtype=torch.int32, device=self.device
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
# fake page_table for sparse_prefill
"page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
"flashmla_metadata": (
self._compute_flashmla_metadata(
cache_seqlens=torch.ones(
max_bs, dtype=torch.int32, device=self.device
),
seq_len_q=1, # TODO handle MTP which is not 1
)
if NSA_DECODE_IMPL == "flashmla_decode"
else None
),
}
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
):
"""Initialize forward metadata for capturing CUDA graph."""
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
assert (
spec_info is None
), "Speculative decoding is not supported for NSA backend now"
# Normal Decode
# Get sequence information
cache_seqlens_int32 = seq_lens.to(torch.int32)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
# Use max context length for seq_len_k
page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
max_seq_len_k = page_table_1.shape[1]
# Precompute page table
# Precompute cumulative sequence lengths
# NOTE(dark): this is always arange, since we are decoding
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
)
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
real_page_table = self._transform_table_1_to_real(page_table_1)
if NSA_DECODE_IMPL == "flashmla_decode":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs + 1))
flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1, # TODO handle MTP which is not 1
)
)
else:
flashmla_metadata = None
metadata = NSAMetadata(
page_size=self.real_page_size,
cache_seqlens_int32=cache_seqlens_int32,
max_seq_len_q=1,
max_seq_len_k=max_seq_len_k,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
page_table_1=page_table_1,
flashmla_metadata=flashmla_metadata,
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
nsa_seqlens_expanded=cache_seqlens_int32,
real_page_table=real_page_table,
nsa_extend_seq_lens_list=[1] * bs,
)
self.decode_cuda_graph_metadata[bs] = metadata
self.forward_metadata = metadata
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: Optional[torch.Tensor] = None,
):
"""Initialize forward metadata for replaying CUDA graph."""
assert seq_lens_cpu is not None
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
assert (
spec_info is None
), "Speculative decoding is not supported for NSA backend now"
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
# Normal Decode
metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
max_len = int(seq_lens_cpu.max().item())
cache_seqlens = seq_lens.to(torch.int32)
metadata.cache_seqlens_int32.copy_(cache_seqlens)
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
)
page_indices = self.req_to_token[req_pool_indices, :max_len]
metadata.page_table_1[:, :max_len].copy_(page_indices)
assert (
metadata.nsa_cache_seqlens_int32 is not None
and metadata.nsa_cu_seqlens_k is not None
and self.nsa_index_topk is not None
)
nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
metadata.nsa_cu_seqlens_k[1:].copy_(
torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
)
# NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
assert self.real_page_size == metadata.page_size
if self.real_page_size > 1:
real_table = self._transform_table_1_to_real(page_indices)
new_len = real_table.shape[1]
metadata.real_page_table[:, :new_len].copy_(real_table)
else:
assert metadata.real_page_table is metadata.page_table_1
if NSA_DECODE_IMPL == "flashmla_decode":
metadata.flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens,
seq_len_q=1, # TODO handle MTP which is not 1
)
)
self.forward_metadata = metadata
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert (
not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
), "NSA backend doesn't support speculative decoding"
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
layer,
cache_loc,
k,
k_rope,
)
metadata = self.forward_metadata
causal = not layer.is_cross_attention
assert causal, "NSA is causal only"
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
# Do absorbed multi-latent attention
assert q_rope is not None
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
# when the cache is both stored and computed in fp8, no dtype conversion is needed
if not (
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and self.nsa_kv_cache_store_fp8
):
kv_cache = kv_cache.to(q.dtype)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
# NOTE(dark): here, we use page size = 1
if NSA_FUSE_TOPK:
page_table_1 = topk_indices
else:
assert metadata.nsa_extend_seq_lens_list is not None
page_table_1 = transform_index_page_table_prefill(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
page_size=1,
)
if NSA_PREFILL_IMPL == "tilelang":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_tilelang(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_PREFILL_IMPL == "flashmla_prefill":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_prefill(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_PREFILL_IMPL == "flashmla_decode":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_decode(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
# TODO optimize args
layer=layer,
metadata=metadata,
page_table_1=page_table_1,
)
elif NSA_PREFILL_IMPL == "fa3":
return self._forward_fa3(
q_rope=q_rope,
kv_cache=kv_cache,
v_head_dim=layer.v_head_dim,
q_nope=q_nope,
page_table=page_table_1,
cache_seqlens=metadata.nsa_cache_seqlens_int32,
cu_seqlens_q=metadata.nsa_cu_seqlens_q,
cu_seqlens_k=metadata.nsa_cu_seqlens_k,
max_seqlen_q=metadata.nsa_max_seqlen_q,
sm_scale=layer.scaling,
logit_cap=layer.logit_cap,
page_size=1,
)
else:
raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
layer,
cache_loc,
k,
k_rope,
)
metadata = self.forward_metadata
causal = not layer.is_cross_attention
assert causal, "NSA is causal only"
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
if NSA_FUSE_TOPK:
page_table_1 = topk_indices
else:
page_table_1 = transform_index_page_table_decode(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
page_size=1,
)
if NSA_DECODE_IMPL == "flashmla_prefill":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_prefill(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_DECODE_IMPL == "flashmla_decode":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_decode(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
# TODO optimize args
layer=layer,
metadata=metadata,
page_table_1=page_table_1,
)
elif NSA_DECODE_IMPL == "tilelang":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_tilelang(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_DECODE_IMPL == "fa3":
return self._forward_fa3(
q_rope=q_rope,
kv_cache=kv_cache,
v_head_dim=layer.v_head_dim,
q_nope=q_nope,
page_table=page_table_1,
cache_seqlens=metadata.nsa_cache_seqlens_int32,
cu_seqlens_q=metadata.nsa_cu_seqlens_q,
cu_seqlens_k=metadata.nsa_cu_seqlens_k,
max_seqlen_q=metadata.nsa_max_seqlen_q,
sm_scale=layer.scaling,
logit_cap=layer.logit_cap,
page_size=1,
)
elif NSA_DECODE_IMPL == "aiter":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_aiter(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
layer=layer,
metadata=metadata,
bs=forward_batch.batch_size,
)
else:
assert False, f"Unsupported {NSA_DECODE_IMPL = }"
def _forward_fa3(
self,
q_rope: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
q_nope: torch.Tensor,
page_table: torch.Tensor,
cache_seqlens: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
sm_scale: float,
logit_cap: float,
page_size: int,
) -> torch.Tensor:
k_rope_cache = kv_cache[:, :, v_head_dim:]
c_kv_cache = kv_cache[:, :, :v_head_dim]
qk_rope_dim = k_rope_cache.shape[-1]
k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=sm_scale,
causal=True,
softcap=logit_cap,
return_softmax_lse=False,
num_splits=self.num_splits,
)
return o # type: ignore
def _forward_flashmla_prefill(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
page_table_1: torch.Tensor,
sm_scale: float,
) -> torch.Tensor:
from flash_mla import flash_mla_sparse_fwd
o, _, _ = flash_mla_sparse_fwd(
q=q_all,
kv=kv_cache,
indices=page_table_1.unsqueeze(1),
sm_scale=sm_scale,
d_v=v_head_dim,
)
return o
def _forward_flashmla_decode(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
sm_scale: float,
layer,
metadata: NSAMetadata,
page_table_1,
) -> torch.Tensor:
from flash_mla import flash_mla_with_kvcache
cache_seqlens = metadata.nsa_cache_seqlens_int32
# TODO: the 2nd dim is seq_len_q; it needs to be > 1 for MTP
q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
assert self.real_page_size == 64, "only page size 64 is supported"
if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not self.nsa_kv_cache_store_fp8:
# inefficiently quantize the whole cache
kv_cache = quantize_k_cache(kv_cache)
indices = page_table_1.unsqueeze(1)
assert (
indices.shape[-1] == self.nsa_index_topk
) # requirement of FlashMLA decode kernel
o, _ = flash_mla_with_kvcache(
q=q_all,
k_cache=kv_cache,
cache_seqlens=cache_seqlens,
head_dim_v=v_head_dim,
tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
num_splits=metadata.flashmla_metadata.num_splits,
softmax_scale=sm_scale,
indices=indices,
# the docs say block_table is unused, but passing None raises an error
block_table=torch.empty(
(q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
),
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
)
return o
def _forward_tilelang(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
v_head_dim: int,
page_table_1: torch.Tensor,
sm_scale: float,
) -> torch.Tensor:
from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd
return tilelang_sparse_fwd(
q=q_all,
kv=kv_cache,
indices=page_table_1.unsqueeze(1),
sm_scale=sm_scale,
d_v=v_head_dim,
)
def _forward_aiter(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
page_table_1: torch.Tensor,
layer: RadixAttention,
metadata: NSAMetadata,
bs: int,
) -> torch.Tensor:
q = q_all.reshape(-1, layer.tp_q_head_num * layer.head_dim)
if layer.head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
kv_indptr = self.kv_indptr
non_minus1_mask = page_table_1 != -1
non_minus1_counts = non_minus1_mask.sum(dim=1)
kv_indptr[1 : bs + 1] = torch.cumsum(non_minus1_counts, dim=0)
kv_indices = page_table_1[page_table_1 != -1]
mla_decode_fwd(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
kv_cache.view(-1, 1, 1, layer.head_dim),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
metadata.cu_seqlens_q,
kv_indptr,
kv_indices,
metadata.cu_seqlens_q,
metadata.max_seq_len_q,
layer.scaling,
layer.logit_cap,
)
# kv_cache = kv_cache.view(-1, 1, layer.head_dim)
return o
def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for sequence length in CUDA graph."""
return 1
def get_indexer_metadata(
self, layer_id: int, forward_batch: ForwardBatch
) -> NSAIndexerMetadata:
return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
from flash_mla import get_mla_metadata
flashmla_metadata, num_splits = get_mla_metadata(
cache_seqlens=cache_seqlens,
# TODO: the docs say `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`,
# but the parameter name suggests it expects seq_len_q
num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
num_heads_k=1,
num_heads_q=self.num_q_heads,
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
topk=self.nsa_index_topk,
)
return NSAFlashMLAMetadata(
flashmla_metadata=flashmla_metadata,
num_splits=num_splits,
)
......@@ -813,45 +813,69 @@ class DeepEPMoE(EPMoE):
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
else:
# dynamic quant
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
hidden_states.device
)
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
......@@ -860,47 +884,72 @@ class DeepEPMoE(EPMoE):
assert isinstance(dispatch_output, DeepEPLLOutput)
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list = group_list.to(torch.int64)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
......
......@@ -112,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
]
# Put some global args for easy access
......
......@@ -76,35 +76,49 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu,
page_size=self.page_size,
prefix_lens=prefix_lens_cpu,
)
if self.need_sort and num_new_pages > len(self.free_pages):
num_new_pages = (
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
).sum()
num_new_pages_item = num_new_pages.item()
if self.need_sort and num_new_pages_item > len(self.free_pages):
self.merge_and_sort_free()
if num_new_pages > len(self.free_pages):
if num_new_pages_item > len(self.free_pages):
return None
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
self.device,
)
if num_new_pages_item < 200:
import sgl_kernel_npu
torch.ops.npu.alloc_extend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
self.page_size,
out_indices,
num_new_pages,
)
else:
alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
self.device,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
self.free_pages = self.free_pages[num_new_pages:]
self.free_pages = self.free_pages[num_new_pages_item:]
return out_indices
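A quick worked example of the page accounting above (illustrative numbers only; any page_size works the same way):
import torch
page_size = 128
prefix_lens = torch.tensor([200, 0])
seq_lens = torch.tensor([300, 50])
# Same ceil-division formula as alloc_extend: pages needed for the full
# sequence minus pages already covered by the cached prefix.
num_new_pages = (
    (seq_lens + page_size - 1) // page_size
    - (prefix_lens + page_size - 1) // page_size
).sum()
print(num_new_pages.item())  # (3 - 2) + (1 - 0) = 2
If this demand still exceeds len(self.free_pages) after merge_and_sort_free, alloc_extend gives up and returns None.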
def alloc_decode(
......
......@@ -15,6 +15,8 @@ limitations under the License.
from __future__ import annotations
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
......@@ -1030,6 +1032,8 @@ class MLATokenToKVPool(KVCache):
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
use_nsa: bool = False,
override_kv_cache_dim: Optional[int] = None,
):
super().__init__(
size,
......@@ -1044,6 +1048,14 @@ class MLATokenToKVPool(KVCache):
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.use_nsa = use_nsa
self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
# TODO do not hardcode
self.kv_cache_dim = (
656
if self.use_nsa and self.nsa_kv_cache_store_fp8
else (kv_lora_rank + qk_rope_head_dim)
)
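# A plausible breakdown of the hardcoded 656 bytes per token (an inference,
# not spelled out in this diff), assuming kv_lora_rank=512, qk_rope_head_dim=64,
# and 128-element fp8 quantization blocks:
#   512 * 1 byte (fp8 nope) + 512 // 128 * 4 bytes (fp32 scales)
#   + 64 * 2 bytes (bf16 rope) = 512 + 16 + 128 = 656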
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
......@@ -1067,7 +1079,7 @@ class MLATokenToKVPool(KVCache):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
(size + page_size, 1, self.kv_cache_dim),
dtype=self.store_dtype,
device=device,
)
......@@ -1130,6 +1142,7 @@ class MLATokenToKVPool(KVCache):
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
......@@ -1147,16 +1160,28 @@ class MLATokenToKVPool(KVCache):
cache_k_rope: torch.Tensor,
):
layer_id = layer.layer_id
if cache_k_nope.dtype != self.dtype:
cache_k_nope = cache_k_nope.to(self.dtype)
cache_k_rope = cache_k_rope.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k_nope = cache_k_nope.view(self.store_dtype)
cache_k_rope = cache_k_rope.view(self.store_dtype)
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
)
if self.use_nsa and self.nsa_kv_cache_store_fp8:
# original cache_k shape: (num_tokens, num_heads=1, hidden=576); unsqueeze a page_size=1 dim before quantizing
# TODO no need to cat
cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
cache_k = cache_k.view(self.store_dtype)
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
else:
if cache_k_nope.dtype != self.dtype:
cache_k_nope = cache_k_nope.to(self.dtype)
cache_k_rope = cache_k_rope.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k_nope = cache_k_nope.view(self.store_dtype)
cache_k_rope = cache_k_rope.view(self.store_dtype)
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer],
loc,
cache_k_nope,
cache_k_rope,
)
def get_cpu_copy(self, indices):
torch.cuda.synchronize()
......@@ -1186,6 +1211,103 @@ class MLATokenToKVPool(KVCache):
torch.cuda.synchronize()
class NSATokenToKVPool(MLATokenToKVPool):
def __init__(
self,
size: int,
page_size: int,
kv_lora_rank: int,
dtype: torch.dtype,
qk_rope_head_dim: int,
layer_num: int,
device: str,
index_head_dim: int,
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
):
super().__init__(
size,
page_size,
dtype,
kv_lora_rank,
qk_rope_head_dim,
layer_num,
device,
enable_memory_saver,
start_layer,
end_layer,
use_nsa=True,
)
# self.index_k_dtype = torch.float8_e4m3fn
# self.index_k_scale_dtype = torch.float32
self.index_head_dim = index_head_dim
# num head == 1 and head dim == 128 for index_k in NSA
assert index_head_dim == 128
self.quant_block_size = 128
assert self.page_size == 64
self.index_k_with_scale_buffer = [
torch.zeros(
# Layout:
# ref: test_attention.py :: kv_cache_cast_to_fp8
# shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
# data: for page i,
# * buf[i, :page_size * head_dim] for fp8 data
# * buf[i, page_size * head_dim:].view(float32) for scale
(
(size + page_size + 1) // self.page_size,
self.page_size
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
),
dtype=torch.uint8,
device=device,
)
for _ in range(layer_num)
]
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return self.index_k_with_scale_buffer[layer_id - self.start_layer]
def get_index_k_continuous(
self,
layer_id: int,
seq_len: int,
page_indices: torch.Tensor,
):
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
return index_buf_accessor.GetK.execute(
self, buf, seq_len=seq_len, page_indices=page_indices
)
def get_index_k_scale_continuous(
self,
layer_id: int,
seq_len: int,
page_indices: torch.Tensor,
):
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
return index_buf_accessor.GetS.execute(
self, buf, seq_len=seq_len, page_indices=page_indices
)
# TODO rename later (currently use diff name to avoid confusion)
def set_index_k_and_scale_buffer(
self,
layer_id: int,
loc: torch.Tensor,
index_k: torch.Tensor,
index_k_scale: torch.Tensor,
) -> None:
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
index_buf_accessor.SetKAndS.execute(
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
)
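The packed page layout documented in the buffer comment above can be illustrated with a small decoding sketch (an illustration under the stated assumptions — page_size=64, index_head_dim=128, one fp32 scale per 128-element quant block; this helper does not exist in the commit):
import torch
def split_index_page(page_row: torch.Tensor, page_size: int = 64, head_dim: int = 128):
    # One page row is uint8 of length page_size*head_dim + page_size*4:
    # the fp8 index_k payload comes first, followed by the fp32 scales.
    data_bytes = page_size * head_dim
    k_fp8 = page_row[:data_bytes].view(torch.float8_e4m3fn).view(page_size, head_dim)
    k_scale = page_row[data_bytes:].view(torch.float32).view(page_size, 1)
    return k_fp8, k_scale
index_buf_accessor.GetK and GetS presumably read these same two regions in bulk for a set of pages.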
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
def __init__(
self,
......@@ -1194,6 +1316,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
dtype: torch.dtype,
kv_lora_rank: int,
qk_rope_head_dim: int,
index_head_dim: Optional[int],
layer_num: int,
device: str,
enable_memory_saver: bool,
......@@ -1213,6 +1336,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.index_head_dim = index_head_dim
self.custom_mem_pool = None
......@@ -1240,6 +1364,18 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
dtype=self.store_dtype,
device=self.device,
)
if self.index_head_dim is not None:
self.index_k_buffer = torch.zeros(
(
layer_num,
self.size // self.page_size + 1,
self.page_size,
1,
self.index_head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
self._finalize_allocation_log(size)
......@@ -1251,6 +1387,10 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
kv_size_bytes += get_tensor_size_bytes(k_cache)
for v_cache in self.v_buffer:
kv_size_bytes += get_tensor_size_bytes(v_cache)
if self.index_head_dim is not None:
assert hasattr(self, "index_k_buffer")
for index_k_cache in self.index_k_buffer:
kv_size_bytes += get_tensor_size_bytes(index_k_cache)
return kv_size_bytes
def get_kv_buffer(self, layer_id: int):
......@@ -1277,6 +1417,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
return self.v_buffer[layer_id - self.start_layer]
def get_index_k_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
return self.index_k_buffer[layer_id - self.start_layer]
# for disagg
def get_contiguous_buf_infos(self):
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
......@@ -1289,6 +1437,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
self.v_buffer[i][0].nbytes for i in range(self.layer_num)
]
if self.index_head_dim is not None:
kv_data_ptrs += [
self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
]
kv_data_lens += [
self.index_k_buffer[i].nbytes for i in range(self.layer_num)
]
kv_item_lens += [
self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def set_kv_buffer(
......@@ -1325,6 +1483,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
cache_v.view(-1, 1, self.qk_rope_head_dim),
)
def set_index_k_buffer(
self,
layer_id: int,
loc: torch.Tensor,
index_k: torch.Tensor,
):
if index_k.dtype != self.dtype:
index_k = index_k.to(self.dtype)
if self.store_dtype != self.dtype:
index_k = index_k.view(self.store_dtype)
torch_npu.npu_scatter_nd_update_(
self.index_k_buffer[layer_id - self.start_layer].view(
-1, 1, self.index_head_dim
),
loc.view(-1, 1),
index_k.view(-1, 1, self.index_head_dim),
)
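For readers less familiar with npu_scatter_nd_update_, the write above is functionally equivalent to an indexed assignment on the flattened per-layer buffer (a plain-PyTorch sketch, assuming head_dim equals self.index_head_dim; not code from this commit):
import torch
def set_index_k_reference(index_k_buffer: torch.Tensor, loc: torch.Tensor, index_k: torch.Tensor, head_dim: int = 128):
    # index_k_buffer: one layer's buffer of shape (pages, page_size, 1, head_dim);
    # loc: global token slot indices assigned by the allocator.
    flat = index_k_buffer.view(-1, 1, head_dim)
    flat[loc.view(-1)] = index_k.view(-1, 1, head_dim)
    return index_k_buffer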
class DoubleSparseTokenToKVPool(KVCache):
def __init__(
......
......@@ -522,6 +522,7 @@ class CudaGraphRunner:
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
seq_lens_cpu = self.seq_lens_cpu[:bs]
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
if self.is_encoder_decoder:
......@@ -592,6 +593,7 @@ class CudaGraphRunner:
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
next_token_logits_buffer=next_token_logits_buffer,
orig_seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
......