Unverified commit efbc687c, authored by fzyzcjy and committed by GitHub
parent 292a867a
......@@ -49,6 +49,30 @@ class ModelImpl(str, Enum):
TRANSFORMERS = "transformers"
def is_deepseek_nsa(config: PretrainedConfig) -> bool:
return (
config.architectures is not None
and config.architectures[0]
in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
and getattr(config, "index_topk", None) is not None
)
def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_head_dim
def get_nsa_index_topk(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_topk
def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_n_heads
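A minimal sketch of how these helpers behave, using a SimpleNamespace as a stand-in for transformers.PretrainedConfig; the attribute values are illustrative, not taken from a real config.json.

from types import SimpleNamespace

cfg = SimpleNamespace(
    architectures=["DeepseekV32ForCausalLM"],
    index_topk=2048,
    index_head_dim=128,
    index_n_heads=64,
)
assert is_deepseek_nsa(cfg)
assert get_nsa_index_topk(cfg) == 2048
assert get_nsa_index_head_dim(cfg) == 128
assert get_nsa_index_n_heads(cfg) == 64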
class ModelConfig:
def __init__(
self,
......@@ -271,6 +295,7 @@ class ModelConfig:
# FIXME: temporary special judge for MLA architecture
if (
"DeepseekV2ForCausalLM" in self.hf_config.architectures
or "DeepseekV32ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
or "LongcatFlashForCausalLM" in self.hf_config.architectures
......@@ -283,6 +308,11 @@ class ModelConfig:
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
self.index_head_dim = (
get_nsa_index_head_dim(self.hf_config)
if is_deepseek_nsa(self.hf_config)
else None
)
# Handle rope scaling with yarn
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
......
......@@ -2,9 +2,19 @@ import logging
import os
from typing import List, Optional
import torch
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
try:
from mf_adapter import TransferEngine
import_error = None
except ImportError as e:
import_error = e
pass
logger = logging.getLogger(__name__)
......@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
def __init__(
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
):
try:
    from mf_adapter import TransferEngine
except ImportError as e:
    raise ImportError(
        "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
    ) from e
if import_error is not None:
    logger.warning(
        "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
    )
    raise import_error
self.engine = TransferEngine()
self.hostname = hostname
......@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
self.initialize()
def initialize(self) -> None:
from sglang.srt.layers.dp_attention import (
get_tensor_model_parallel_world_size,
get_tp_group,
)
transfer_protocol = self._get_transfer_protocol()
if transfer_protocol is None or transfer_protocol == "sdma":
trans_op_type = TransferEngine.TransDataOpType.SDMA
else:
trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
"""with device RDMA for PD transfer"""
tmp_tensor = torch.zeros(1, device="npu")
output_tensor_list = [
torch.empty_like(tmp_tensor)
for _ in range(get_tensor_model_parallel_world_size())
]
# Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
torch.distributed.all_gather(
output_tensor_list, tmp_tensor, group=get_tp_group().device_group
)
"""Initialize the ascend transfer instance."""
ret_value = self.engine.initialize(
    self.store_url,
    self.session_id,
    self.role,
    self.npu_id,
)
ret_value = self.engine.initialize(
    self.store_url, self.session_id, self.role, self.npu_id, trans_op_type
)
if ret_value != 0:
logger.error("Ascend Transfer Engine initialization failed.")
......@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
ret_value = -1
if ret_value != 0:
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
@staticmethod
def _get_transfer_protocol():
protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
allowed_protocols = {"device_rdma", "sdma"}
if protocol and protocol.lower() in allowed_protocols:
return protocol.lower()
else:
logger.warning(
"Invalid or no transfer protocol specified, using default protocol."
)
return None
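A small usage sketch of the protocol selection above, assuming the module imports; the second value is deliberately not in the allow-list and is only used for illustration.

import os

os.environ["ASCEND_MF_TRANSFER_PROTOCOL"] = "DEVICE_RDMA"
assert AscendTransferEngine._get_transfer_protocol() == "device_rdma"

os.environ["ASCEND_MF_TRANSFER_PROTOCOL"] = "npu_rdma"  # not in the allow-list
assert AscendTransferEngine._get_transfer_protocol() is None  # warns and falls back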
......@@ -36,6 +36,8 @@ class ForwardMetadata:
seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_list: Optional[List[int]] = None
seq_lens_list_cumsum: Optional[List[int]] = None
seq_lens: Optional[torch.Tensor] = None
actual_seq_lengths_q: Optional[torch.Tensor] = None
class AscendAttnBackend(AttentionBackend):
......@@ -67,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
if self.use_mla:
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.q_head_dim = (
self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
)
self.native_attn = TorchNativeAttnBackend(model_runner)
self.graph_metadata = {}
self.max_context_len = model_runner.model_config.context_len
......@@ -102,10 +107,6 @@ class AscendAttnBackend(AttentionBackend):
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
if forward_batch.is_extend_in_batch:
seq_lens_list_cumsum[-1] = (
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
) * tp_size
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
self.graph_mode = False
......@@ -133,6 +134,10 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
metadata.seq_lens = seq_lens
metadata.actual_seq_lengths_q = torch.tensor(
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
)
self.graph_metadata[bs] = metadata
self.forward_metadata = metadata
......@@ -161,6 +166,8 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
metadata.block_tables[bs:, :].fill_(0)
metadata.seq_lens[:bs].copy_(seq_lens[:bs])
self.forward_metadata = metadata
self.graph_mode = True
......@@ -168,6 +175,64 @@ class AscendAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_sparse(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi_head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
is_prefill = forward_batch.forward_mode.is_extend()
if save_kv_cache:
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
q_nope, q_pe = q, q_rope
k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
block_table = self.forward_metadata.block_tables
if is_prefill:
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
else:
if self.forward_metadata.actual_seq_lengths_q is None:
actual_seq_qlen = (
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
)
else:
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_lengths_kv = self.forward_metadata.seq_lens
else:
actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
attn_out = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
query_rope=q_pe,
key_rope=k_pe,
sparse_indices=topk_indices,
scale_value=layer.scaling,
actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
block_table=block_table,
sparse_block_size=1,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
return attn_out
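A small illustration of the cumulative actual_seq_lengths_query convention used above (the variable names and numbers here are illustrative, not from this change): in the "TND" layout each request's query tokens are laid out back to back, so the kernel receives cumulative end offsets, i.e. a cumsum over per-request token counts for prefill and 1..bs when decode has one token per request.

import torch

extend_seq_lens = torch.tensor([3, 5, 2])                 # prefill: query tokens per request
prefill_qlen = torch.cumsum(extend_seq_lens, dim=0)       # tensor([ 3,  8, 10])

bs = 4                                                    # decode: one query token per request
decode_qlen = torch.arange(1, bs + 1, dtype=torch.int32)  # tensor([1, 2, 3, 4])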
def forward_extend(
self,
q,
......@@ -176,7 +241,23 @@ class AscendAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi_head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
......@@ -437,10 +518,23 @@ class AscendAttnBackend(AttentionBackend):
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
if is_mla_preprocess_enabled():
# MLAPO already saves the kv_cache itself
save_kv_cache = False
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if self.graph_mode:
return self.forward_decode_graph(
......
......@@ -66,6 +66,13 @@ def create_ascend_backend(runner):
return AscendAttnBackend(runner)
@register_attention_backend("nsa")
def create_nsa_backend(runner):
from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
return NativeSparseAttnBackend(runner)
@register_attention_backend("triton")
def create_triton_backend(runner):
assert not runner.model_config.is_encoder_decoder, (
......
......@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
if TYPE_CHECKING:
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInput
......@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
def support_triton(self):
"""Check if the current backend supports triton."""
return True
def get_indexer_metadata(
self,
layer_id: int,
forward_batch: ForwardBatch,
) -> Optional[BaseIndexerMetadata]:
"""Get the indexer metadata. None means don't support indexer."""
return None
......@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
return backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
def get_indexer_metadata(
self, layer_id: int, forward_batch: ForwardBatch
) -> Optional[BaseIndexerMetadata]:
backend = self._select_backend(forward_batch.forward_mode)
return backend.get_indexer_metadata(layer_id, forward_batch)
......@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
self.rotary_emb = rotary_emb
self.layer_id = layer_id
self.has_preprocess_weights = False
self.dtype = None
self.q_lora_rank = self.q_b_proj.input_size # 1536
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
self.num_local_heads = num_local_heads # tp
self.qk_nope_head_dim = qk_nope_head_dim # 128
self.qk_rope_head_dim = qk_rope_head_dim # 64
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
def preprocess_weights(self, hidden_states):
self.dummy = torch.empty(
......@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
return k_cache, v_cache, slot_mapping
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
def forward_absorb_prepare_npu_rms_norm_cache(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch,
zero_allocator,
):
bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
self.dtype = hidden_states.dtype
self.cos, self.sin = self.get_sin_cos(positions)
self.kvCache, self.kvCacheRope, self.slotmapping = (
self.get_kv_cache_and_cache_idx(forward_batch)
)
if not self.has_preprocess_weights:
self.has_preprocess_weights = True
cos, sin = self.cos, self.sin
if self.q_lora_rank is not None:
fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q = self.q_a_layernorm(q_lowrank)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
) # b*s,n,d
q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D)
q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
latent_cache = latent_cache.view(
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
) # (B*S,N,1,D)
cache_mode = "PA_BNSD"
self.kvCache = self.kvCache.view(
-1,
forward_batch.attn_backend.page_size,
1,
forward_batch.attn_backend.kv_lora_rank,
)
self.kvCacheRope = self.kvCacheRope.view(
-1,
forward_batch.attn_backend.page_size,
1,
forward_batch.attn_backend.qk_rope_head_dim,
)
k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
latent_cache,
self.kv_a_layernorm.weight,
cos,
sin,
self.slotmapping.to(torch.int64),
self.kvCacheRope,
self.kvCache,
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
input_dtype = hidden_states.dtype
if not self.has_preprocess_weights:
self.preprocess_weights(hidden_states)
......@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
zero_allocator,
positions,
)
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
_is_w8a8 = (
hasattr(self.qkv_a_proj.quant_method, "quantization_config")
and self.qkv_a_proj.quant_method.quantization_config.get_name()
== "w8a8_int8"
)
if _is_w8a8:
return self.forward_mlapo(
positions, hidden_states, forward_batch, zero_allocator
)
else:
return self.forward_absorb_prepare_npu_rms_norm_cache(
positions, hidden_states, forward_batch, zero_allocator
)
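A minimal, self-contained sketch of the dispatch rule in forward() above; the _uses_mlapo helper and the SimpleNamespace stand-ins are illustrative, not part of this change: quantized w8a8_int8 weights take the fused MLAPO path, everything else takes the npu_kv_rmsnorm_rope_cache path.

from types import SimpleNamespace

def _uses_mlapo(qkv_a_proj) -> bool:
    # mirrors the check in forward() above
    return (
        hasattr(qkv_a_proj.quant_method, "quantization_config")
        and qkv_a_proj.quant_method.quantization_config.get_name() == "w8a8_int8"
    )

_w8a8 = SimpleNamespace(
    quant_method=SimpleNamespace(
        quantization_config=SimpleNamespace(get_name=lambda: "w8a8_int8")
    )
)
_bf16 = SimpleNamespace(quant_method=SimpleNamespace())  # no quantization_config attribute
assert _uses_mlapo(_w8a8) and not _uses_mlapo(_bf16)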
import torch
import triton
import triton.language as tl
from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST
def dequantize_k_cache(quant_k_cache):
if NSA_DEQUANT_K_CACHE_FAST:
return _dequantize_k_cache_fast_wrapped(quant_k_cache)
else:
return _dequantize_k_cache_slow(quant_k_cache)
def _dequantize_k_cache_slow(
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
dv: int = 512,
tile_size: int = 128,
d: int = 576,
) -> torch.Tensor:
"""
De-quantize the k-cache
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1
result = torch.empty(
(num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
)
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_nope * cur_scales
)
result = result.view(num_blocks, block_size, 1, d)
return result
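A worked check of the hard-coded byte layout that the fast path below asserts: a 512-dim fp8 nope part, one fp32 scale per 128-wide tile, and a 64-dim rope part kept in bf16.

dv, tile_size, d = 512, 128, 576
num_tiles = dv // tile_size        # 4 scale tiles
nope_bytes = dv                    # fp8 nope part, 1 byte per element
scale_bytes = num_tiles * 4        # one fp32 scale per tile
rope_bytes = (d - dv) * 2          # rope part kept in bf16, 2 bytes per element
assert nope_bytes + scale_bytes + rope_bytes == 656  # the dim_quant asserted below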
def _dequantize_k_cache_fast_wrapped(
quant_k_cache: torch.Tensor,
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
# TODO the final API may be 2D instead of 4D, thus we convert them here
num_blocks, block_size, _, dim_quant = quant_k_cache.shape
assert dv == 512
assert dim_quant == 656
assert tile_size == 128
quant_k_cache = quant_k_cache.view((-1, dim_quant))
output = _dequantize_k_cache_fast(quant_k_cache)
return output.view(num_blocks, block_size, 1, -1)
def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
num_tokens, dim_quant = quant_k_cache.shape
assert quant_k_cache.dtype == torch.float8_e4m3fn
dim_nope = 512
dim_rope = 64
num_tiles = dim_nope // group_size
assert dim_quant == 656
output = torch.empty(
(num_tokens, dim_nope + dim_rope),
dtype=torch.bfloat16,
device=quant_k_cache.device,
)
num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
assert num_blocks_per_token == 5
assert dim_nope % group_size == 0
NUM_NOPE_BLOCKS = dim_nope // group_size
input_nope_q = quant_k_cache[:, :dim_nope]
input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
torch.float32
)
input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)
_dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
output,
input_nope_q,
input_nope_s,
input_rope,
output.stride(0),
input_nope_q.stride(0),
input_nope_s.stride(0),
input_rope.stride(0),
NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
GROUP_SIZE=group_size,
DIM_NOPE=dim_nope,
DIM_ROPE=dim_rope,
)
return output
@triton.jit
def _dequantize_k_cache_fast_kernel(
output_ptr,
input_nope_q_ptr,
input_nope_s_ptr,
input_rope_ptr,
output_stride_0: int,
input_nope_q_stride_0: int,
input_nope_s_stride_0: int,
input_rope_stride_0: int,
NUM_NOPE_BLOCKS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
DIM_NOPE: tl.constexpr,
DIM_ROPE: tl.constexpr,
):
token_id = tl.program_id(0)
raw_block_id = tl.program_id(1)
if raw_block_id < NUM_NOPE_BLOCKS:
# a. dequant nope
effective_block_id = raw_block_id
offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs_q < DIM_NOPE
ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q
ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id
y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
y_s = tl.load(ptr_s)
y = (y_q * y_s).to(output_ptr.dtype.element_ty)
dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
tl.store(dst_ptr, y, mask=mask)
else:
# b. copy rope
effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_ROPE
src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs
dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs
data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
tl.store(dst_ptr, data, mask=mask)
if __name__ == "__main__":
raise Exception("UT is in quant_k_cache.py")
from typing import TYPE_CHECKING
import torch
import triton
import triton.language as tl
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
"""
k: data, 128 item per token, fp8
s: scale, 1 item per token, fp32
"""
class GetK:
@classmethod
def execute(cls, *args, **kwargs):
return cls.torch_fast(*args, **kwargs)
@classmethod
def slow(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
num_pages = (seq_len + pool.page_size - 1) // pool.page_size
seq_len_ = num_pages * pool.page_size
index_k_fp8 = torch.empty(
(seq_len_, pool.index_head_dim),
dtype=torch.uint8,
device=pool.device,
)
for i in range(num_pages):
page_index = page_indices[i]
index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
page_index
][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)
return index_k_fp8[:seq_len]
@classmethod
def torch_fast(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
"""
:param page_indices: (num_pages,), int32
:return: (seq_len, index_head_dim), uint8
"""
# can handle per 128B instead of per element
# page_indices: (num_pages,), element := a page index
buf_numel_per_page = buf.shape[1]
num_k_bytes_per_page = pool.page_size * pool.index_head_dim
num_k_bytes_per_token = pool.index_head_dim
# buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
# flat_buf: (whatever,), uint8
flat_buf = buf.flatten()
# flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access
flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
num_k_bytes_per_page, dtype=torch.int32, device="cuda"
)[None, :]
flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]
out = flat_buf[flat_indices]
return out.view(-1, 128)
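A reviewer-style consistency sketch for GetK (CUDA only; the pool is faked with a SimpleNamespace, and the sizes mirror the page_size 64 / head_dim 128 asserts used elsewhere in this file): the flat-index gather should match the per-page copy loop.

import torch
from types import SimpleNamespace

if torch.cuda.is_available():
    pool = SimpleNamespace(page_size=64, index_head_dim=128, device="cuda")
    buf = torch.randint(
        0, 256, (4, 64 * 128 + 64 * 4), dtype=torch.uint8, device="cuda"
    )
    page_indices = torch.tensor([2, 0, 3], dtype=torch.int32, device="cuda")
    seq_len = 130  # spans three pages
    assert torch.equal(
        GetK.slow(pool, buf, seq_len, page_indices),
        GetK.torch_fast(pool, buf, seq_len, page_indices),
    )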
class GetS:
@classmethod
def execute(cls, *args, **kwargs):
return cls.torch_fast(*args, **kwargs)
@classmethod
def slow(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
num_pages = (seq_len + pool.page_size - 1) // pool.page_size
seq_len_ = num_pages * pool.page_size
assert pool.index_head_dim // pool.quant_block_size == 1
index_k_scale_fp8 = torch.empty(
(seq_len_, 4),
dtype=torch.uint8,
device=pool.device,
)
for i in range(num_pages):
page_index = page_indices[i]
index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
page_index
][pool.page_size * pool.index_head_dim :].view(-1, 4)
return index_k_scale_fp8[:seq_len]
@classmethod
def torch_fast(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
"""
:param page_indices: (num_pages,), int32
:return: (seq_len, index_head_dim // quant_block_size), uint8
"""
buf_numel_per_page = buf.shape[1]
num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
s_offset_in_page = pool.page_size * pool.index_head_dim
flat_buf = buf.flatten()
flat_indices = (
(page_indices * buf_numel_per_page)[:, None]
+ torch.arange(num_s_bytes_per_page, dtype=torch.int32, device="cuda")[
None, :
]
+ s_offset_in_page
)
flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]
out = flat_buf[flat_indices]
return out.view(-1, 4)
class SetK:
@classmethod
def execute(cls, *args, buf, **kwargs):
return cls.torch_fast(*args, **kwargs, buf=buf)
@classmethod
def slow(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k: torch.Tensor,
):
for i in range(len(loc)):
page_index = loc[i] // pool.page_size
offset = loc[i] % pool.page_size
buf[
page_index,
offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
] = index_k[i].view(torch.uint8)
@classmethod
def torch_fast(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k: torch.Tensor,
):
(num_tokens_to_write,) = loc.shape
buf_numel_per_page = buf.shape[1]
num_k_bytes_per_token = pool.index_head_dim
# loc: (num_tokens_to_write,), int32, element := the token index to write to
loc_page_index = loc // pool.page_size
loc_token_offset_in_page = loc % pool.page_size
flat_buf = buf.flatten()
flat_indices = (
(loc_page_index * buf_numel_per_page)[:, None]
+ (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
+ torch.arange(num_k_bytes_per_token, dtype=torch.int32, device="cuda")[
None, :
]
)
num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
flat_indices = flat_indices.flatten()[:num_k_bytes_total]
flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()
class SetS:
@classmethod
def execute(cls, *args, buf, **kwargs):
return cls.torch_fast(*args, **kwargs, buf=buf)
@classmethod
def slow(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k_scale: torch.Tensor,
):
for i in range(len(loc)):
page_index = loc[i] // pool.page_size
offset = loc[i] % pool.page_size
start = pool.page_size * pool.index_head_dim
buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
index_k_scale[i].view(torch.uint8)
)
@classmethod
def torch_fast(
cls,
pool: "NSATokenToKVPool",
buf: torch.Tensor,
loc: torch.Tensor,
index_k_scale: torch.Tensor,
):
(num_tokens_to_write,) = loc.shape
buf_numel_per_page = buf.shape[1]
num_s_bytes_per_token = 4
s_offset_in_page = pool.page_size * pool.index_head_dim
# loc: (num_tokens_to_write,), int32, element := the token index to write to
loc_page_index = loc // pool.page_size
loc_token_offset_in_page = loc % pool.page_size
flat_buf = buf.flatten()
flat_indices = (
(loc_page_index * buf_numel_per_page)[:, None]
+ s_offset_in_page
+ (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
+ torch.arange(num_s_bytes_per_token, dtype=torch.int32, device="cuda")[
None, :
]
)
number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
flat_indices = flat_indices.flatten()[:number_s_bytes_total]
flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()
class SetKAndS:
@classmethod
def execute(cls, *args, buf, **kwargs):
if 0:
# print("SetK, SetS comparison test")
buf_cloned = buf.clone()
cls.vanilla(*args, **kwargs, buf=buf)
cls.triton(*args, **kwargs, buf=buf_cloned)
def _clear_token_0(target):
target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0
_clear_token_0(buf)
_clear_token_0(buf_cloned)
assert torch.all(
buf == buf_cloned
), f"{buf=} {buf_cloned=} {kwargs['loc'].to_list()=}"
return
cls.triton(*args, **kwargs, buf=buf)
@classmethod
def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)
@classmethod
def triton(cls, pool, buf, loc, index_k, index_k_scale):
_set_k_and_s_triton(
buf=buf,
loc=loc,
index_k=index_k,
index_k_scale=index_k_scale,
page_size=pool.page_size,
)
def _set_k_and_s_triton(
buf: torch.Tensor,
loc: torch.Tensor,
index_k: torch.Tensor,
index_k_scale: torch.Tensor,
page_size: int,
):
"""
:param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
:param loc: (num_tokens_to_write,), int, element := the token index to write to
:param index_k: (num_tokens_to_write, 128 elem), fp8
:param index_k_scale: (num_tokens_to_write, 1 elem), fp32
:return:
"""
num_pages, buf_numel_per_page = buf.shape
(num_tokens_to_write,) = loc.shape
num_tokens_to_write_, index_head_dim = index_k.shape
num_tokens_to_write__, scale_dim = index_k_scale.shape
assert buf_numel_per_page == 64 * (128 + 4)
assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
assert index_head_dim == 128
assert scale_dim == 1
assert page_size == 64
assert buf.dtype == torch.uint8
assert loc.dtype == torch.int64, f"{loc.dtype=}" # can be int32
assert index_k.dtype == torch.float8_e4m3fn
assert index_k_scale.dtype == torch.float32
assert buf.is_contiguous()
assert loc.is_contiguous()
assert index_k.is_contiguous()
assert index_k_scale.is_contiguous()
buf_fp8 = buf.view(torch.float8_e4m3fn)
buf_fp32 = buf.view(torch.float32)
_set_k_and_s_triton_kernel[(num_tokens_to_write,)](
buf_fp8,
buf_fp32,
loc,
index_k,
index_k_scale,
index_k.stride(0),
PAGE_SIZE=page_size,
BUF_NUMEL_PER_PAGE=buf_numel_per_page,
NUM_K_ELEMS_PER_TOKEN=index_head_dim,
S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
)
@triton.jit
def _set_k_and_s_triton_kernel(
buf_fp8_ptr,
buf_fp32_ptr,
loc_ptr,
index_k_ptr,
index_k_scale_ptr,
index_k_ptr_stride_0,
PAGE_SIZE: tl.constexpr,
BUF_NUMEL_PER_PAGE: tl.constexpr,
NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
):
token_id = tl.program_id(0)
loc = tl.load(loc_ptr + token_id)
in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
# no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2
k = tl.load(index_k_ptr + in_k_offsets)
k_scale = tl.load(index_k_scale_ptr + token_id)
loc_page_index = loc // PAGE_SIZE
loc_token_offset_in_page = loc % PAGE_SIZE
out_k_offsets = (
loc_page_index * BUF_NUMEL_PER_PAGE
+ loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
+ tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
)
# "//4" b/c it is fp32 instead of uint8
out_s_offset = (
loc_page_index * BUF_NUMEL_PER_PAGE // 4
+ S_OFFSET_NBYTES_IN_PAGE // 4
+ loc_token_offset_in_page
)
tl.store(buf_fp8_ptr + out_k_offsets, k)
tl.store(buf_fp32_ptr + out_s_offset, k_scale)
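For reference, the per-page byte layout the kernel above writes (page_size 64, head_dim 128): the fp8 K region comes first, followed by one fp32 scale per token.

PAGE_SIZE, HEAD_DIM = 64, 128
k_region_bytes = PAGE_SIZE * HEAD_DIM   # 8192 bytes of fp8 K data per page
s_region_bytes = PAGE_SIZE * 4          # 256 bytes of fp32 scales per page
assert k_region_bytes + s_region_bytes == 64 * (128 + 4)  # == BUF_NUMEL_PER_PAGE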
This diff is collapsed.
import torch
import triton
import triton.language as tl
from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST
def quantize_k_cache(cache_k):
# TODO upstream can skip concat([k_nope, k_pe]) since we split them here
if NSA_QUANT_K_CACHE_FAST:
return _quantize_k_cache_fast_wrapped(cache_k)
else:
return _quantize_k_cache_slow(cache_k)
# Copied from original
def _quantize_k_cache_slow(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
"""
Quantize the k-cache
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, d = input_k_cache.shape
assert h_k == 1
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty(
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
dtype=torch.float8_e4m3fn,
device=input_k_cache.device,
)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = (
torch.abs(
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
)
.max(dim=-1)
.values
/ 448.0
) # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (
input_k_cache[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].float()
/ cur_scale_factors_inv.float()
).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_quantized_nope
)
result = result.view(num_blocks, block_size, 1, -1)
return result
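A tiny, self-contained illustration of the per-tile scale scheme used above, on one short made-up "tile"; 448 is the fp8 e4m3 maximum used as the divisor.

import torch

x = torch.tensor([1.0, -2.0, 4.0, 0.5], dtype=torch.bfloat16)
scale = x.abs().max().float() / 448.0            # 448 == torch.finfo(torch.float8_e4m3fn).max
q = (x.float() / scale).to(torch.float8_e4m3fn)  # quantized tile
deq = q.float() * scale                          # recovers x up to fp8 rounding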
def _quantize_k_cache_fast_wrapped(
input_k_cache: torch.Tensor,
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
# TODO the final API may be 2D instead of 4D, thus we convert them here
num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape
assert dv == 512
assert dim_nope_and_rope == 512 + 64
assert tile_size == 128
input_k_cache = input_k_cache.view((-1, dim_nope_and_rope))
# TODO deliberately split into two tensors, then upstream can provide the two tensors instead of concat into one
k_nope = input_k_cache[:, :dv]
k_rope = input_k_cache[:, dv:]
output = _quantize_k_cache_fast(k_nope=k_nope, k_rope=k_rope)
return output.view(num_blocks, block_size, 1, -1)
def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):
"""
:param k_nope: (num_tokens, dim_nope 512)
:param k_rope: (num_tokens, dim_rope 64)
"""
assert k_nope.dtype == torch.bfloat16
assert k_rope.dtype == torch.bfloat16
num_tokens, dim_nope = k_nope.shape
num_tokens_, dim_rope = k_rope.shape
assert num_tokens == num_tokens_
assert dim_nope == 512
assert dim_rope == 64
assert k_nope.dtype == k_rope.dtype
num_tiles = dim_nope // group_size
assert k_nope.stride(1) == 1
assert k_rope.stride(1) == 1
output = torch.empty(
(num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),
dtype=torch.float8_e4m3fn,
device=k_nope.device,
)
output_nope_q = output[..., :dim_nope]
output_nope_s = output[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
output_rope = output[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16)
num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
assert num_blocks_per_token == 5
assert dim_nope % group_size == 0
NUM_NOPE_BLOCKS = dim_nope // group_size
_quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
output_nope_q,
output_nope_s,
output_rope,
k_nope,
k_rope,
output_nope_q.stride(0),
output_nope_s.stride(0),
output_rope.stride(0),
k_nope.stride(0),
k_rope.stride(0),
NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
GROUP_SIZE=group_size,
DIM_NOPE=dim_nope,
DIM_ROPE=dim_rope,
FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
)
return output
@triton.jit
def _quantize_k_cache_fast_kernel(
output_nope_q_ptr,
output_nope_s_ptr,
output_rope_ptr,
k_nope_ptr,
k_rope_ptr,
output_nope_q_stride_0: int,
output_nope_s_stride_0: int,
output_rope_stride_0: int,
k_nope_stride_0: int,
k_rope_stride_0: int,
NUM_NOPE_BLOCKS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
DIM_NOPE: tl.constexpr,
DIM_ROPE: tl.constexpr,
FP8_MIN: tl.constexpr,
FP8_MAX: tl.constexpr,
):
token_id = tl.program_id(0)
raw_block_id = tl.program_id(1)
if raw_block_id < NUM_NOPE_BLOCKS:
# a. quant nope
effective_block_id = raw_block_id
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_NOPE
ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs
y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)
# the ref impl do not have a `tl.maximum(... eps)`, so we remove it here
y_s = tl.max(tl.abs(y)) / FP8_MAX
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(
output_nope_q_ptr.dtype.element_ty
)
dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs
dst_s_ptr = (
output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id
)
tl.store(dst_q_ptr, y_q, mask=mask)
tl.store(dst_s_ptr, y_s)
else:
# b. copy rope
effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_ROPE
src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs
dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs
data = tl.load(src_ptr, mask=mask)
tl.store(dst_ptr, data, mask=mask)
if __name__ == "__main__":
for num_blocks, block_size in [
(1, 1),
(10, 64),
]:
dim_nope_and_rope = 512 + 64
input_k_cache = torch.randn(
(num_blocks, block_size, 1, dim_nope_and_rope),
dtype=torch.bfloat16,
device="cuda",
)
# temp debug
# input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)
ref_quant = _quantize_k_cache_slow(input_k_cache)
actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
# print(f"{input_k_cache=}")
# print(f"{ref_quant=}")
# print(f"{actual_quant=}")
# print(f"{ref_quant == actual_quant=}")
# print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
# print(f"{ref_quant.view(torch.bfloat16)=}")
# print(f"{actual_quant.view(torch.bfloat16)=}")
# assert torch.all(ref_quant == actual_quant)
import dequant_k_cache
ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(
actual_quant
)
print(f"{ref_ref_dequant=}")
print(f"{actual_actual_dequant=}")
print(f"{actual_actual_dequant - ref_ref_dequant=}")
print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}")
# TODO too different?
torch.testing.assert_close(
ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2
)
torch.testing.assert_close(
ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
)
print("Passed")
This diff is collapsed.
from typing import List, Optional
import torch
import triton
import triton.language as tl
def transform_index_page_table_prefill(**kwargs):
return transform_index_page_table_prefill_ref(**kwargs)
def transform_index_page_table_decode(**kwargs):
return transform_index_page_table_decode_ref(**kwargs)
@triton.jit
def transform_index_page_table_decode_kernel(
page_table_ptr: torch.Tensor,
topk_indices_ptr: torch.Tensor,
result_ptr: torch.Tensor,
page_size: tl.constexpr,
max_seqlen_k: tl.constexpr,
):
TOPK: tl.constexpr = 2048
req_id = tl.program_id(0)
page_table_ptr = page_table_ptr + req_id * max_seqlen_k
topk_indices_ptr = topk_indices_ptr + req_id * TOPK
result_ptr = result_ptr + req_id * TOPK
offset = tl.arange(0, TOPK) # topk should be 2048
loaded_topk_indices = tl.load(topk_indices_ptr + offset)
mask = loaded_topk_indices >= 0
loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
tl.store(result_ptr + offset, -1, mask=~mask)
def transform_index_page_table_decode_fast(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
result: Optional[torch.Tensor] = None,
page_size: int = 1,
) -> torch.Tensor:
"""
Transform the page table according to topk indices for sparse topk attention.
Args:
page_table: [qo_len, max_seqlen_k], the original page table
topk_indices: [qo_len, topk], the topk indices for each query position
Returns:
transformed_page_table: [qo_len, topk], the transformed page table
For out-of-bound indices in topk_indices, this should be filled with -1.
"""
assert page_size == 1
assert page_table.shape[0] == topk_indices.shape[0]
assert topk_indices.shape[1] == 2048
qo_len = topk_indices.shape[0]
max_seqlen_k = page_table.shape[1]
if result is None:
result = torch.empty_like(topk_indices, dtype=torch.int32)
# Launch triton kernel
grid = (qo_len,)
transform_index_page_table_decode_kernel[grid](
page_table,
topk_indices,
result,
page_size,
max_seqlen_k=max_seqlen_k,
)
return result
def transform_index_page_table_prefill_fast(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
extend_lens_cpu: List[int],
page_size: int = 1,
) -> torch.Tensor:
# TODO(baizhou): can be implemented with another triton kernel
assert page_size == 1
result = torch.empty_like(topk_indices, dtype=torch.int32)
assert len(extend_lens_cpu) == page_table.shape[0]
offset = 0
for i, l in enumerate(extend_lens_cpu):
transform_index_page_table_decode_fast(
page_table[i].unsqueeze(0).expand(l, -1),
topk_indices[offset : offset + l],
result=result[offset : offset + l],
)
offset += l
assert offset == topk_indices.shape[0]
return result
def transform_index_page_table_decode_ref(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
result: Optional[torch.Tensor] = None,
page_size: int = 1,
) -> torch.Tensor:
assert page_size == 1
assert page_table.shape[0] == topk_indices.shape[0]
if result is None:
result = torch.empty_like(topk_indices, dtype=torch.int32)
assert result.shape == topk_indices.shape
torch.gather(
page_table,
dim=1,
index=topk_indices.clamp(min=0),
out=result,
)
result[topk_indices < 0] = -1
return result
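A worked example of the gather-with-(-1)-passthrough convention implemented by the reference above (illustrative values; page_size == 1, so page-table entries are token slots).

import torch

page_table = torch.tensor([[10, 11, 12, 13]], dtype=torch.int32)  # one request, 4 slots
topk_indices = torch.tensor([[2, 0, -1, -1]])                     # -1 marks padding
out = transform_index_page_table_decode_ref(page_table, topk_indices)
assert out.tolist() == [[12, 10, -1, -1]]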
def transform_index_page_table_prefill_ref(
page_table: torch.Tensor,
topk_indices: torch.Tensor,
extend_lens_cpu: List[int],
page_size: int = 1,
) -> torch.Tensor:
assert page_size == 1
result = torch.empty_like(topk_indices, dtype=torch.int32)
assert len(extend_lens_cpu) == page_table.shape[0]
offset = 0
for i, l in enumerate(extend_lens_cpu):
transform_index_page_table_decode_ref(
page_table[i].unsqueeze(0).expand(l, -1),
topk_indices[offset : offset + l],
result=result[offset : offset + l],
)
offset += l
assert offset == topk_indices.shape[0]
return result
if __name__ == "__main__":
bs, topk, max_seqlen = 10, 2048, 3000
page_table = torch.randint(0, 100, (bs, max_seqlen), device="cuda")
topk_indices = torch.full((bs, topk), -1, device="cuda")
topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1)
ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)
result = transform_index_page_table_decode_fast(page_table, topk_indices)
assert torch.all(result == ref_result)
print("Passed")
# temp NSA debugging environ
from sglang.srt.utils import get_bool_env_var
NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var(
"SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true"
)
NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "true")
NSA_DEQUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "true")
def print_nsa_bool_env_vars():
msg = ""
for k, v in globals().items():
if k.startswith("NSA_") and isinstance(v, bool):
msg += f"{k}={v} "
print(msg, flush=True)
def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):
return original_seq_lens.clamp(max=nsa_index_topk)
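For example, with index_topk = 2048 only sequences longer than 2048 tokens are clamped; shorter ones keep their true length (illustrative numbers).

import torch

seq_lens = torch.tensor([100, 2048, 5000])
assert compute_nsa_seqlens(seq_lens, nsa_index_topk=2048).tolist() == [100, 2048, 2048]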
This diff is collapsed.
......@@ -813,45 +813,69 @@ class DeepEPMoE(EPMoE):
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
else:
# dynamic quant
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
hidden_states.device
)
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
......@@ -860,47 +884,72 @@ class DeepEPMoE(EPMoE):
assert isinstance(dispatch_output, DeepEPLLOutput)
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list = group_list.to(torch.int64)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
......
......@@ -112,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
]
# Put some global args for easy access
......
......@@ -76,35 +76,49 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu,
page_size=self.page_size,
prefix_lens=prefix_lens_cpu,
)
if self.need_sort and num_new_pages > len(self.free_pages):
num_new_pages = (
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
).sum()
num_new_pages_item = num_new_pages.item()
if self.need_sort and num_new_pages_item > len(self.free_pages):
self.merge_and_sort_free()
if num_new_pages > len(self.free_pages):
if num_new_pages_item > len(self.free_pages):
return None
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
self.device,
)
if num_new_pages_item < 200:
import sgl_kernel_npu
torch.ops.npu.alloc_extend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
self.page_size,
out_indices,
num_new_pages,
)
else:
alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
self.device,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
self.free_pages = self.free_pages[num_new_pages:]
self.free_pages = self.free_pages[num_new_pages_item:]
return out_indices
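A worked example of the page-count formula introduced above (illustrative page_size and lengths): the number of newly required pages is the difference between the page counts of the full and prefix sequence lengths.

import torch

page_size = 128  # illustrative
seq_lens = torch.tensor([300, 100])
prefix_lens = torch.tensor([250, 0])
num_new_pages = (
    (seq_lens + page_size - 1) // page_size
    - (prefix_lens + page_size - 1) // page_size
).sum()
assert num_new_pages.item() == 2  # (3 - 2) pages for req 0, (1 - 0) for req 1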
def alloc_decode(
......
......@@ -15,6 +15,8 @@ limitations under the License.
from __future__ import annotations
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
......@@ -1030,6 +1032,8 @@ class MLATokenToKVPool(KVCache):
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
use_nsa: bool = False,
override_kv_cache_dim: Optional[int] = None,
):
super().__init__(
size,
......@@ -1044,6 +1048,14 @@ class MLATokenToKVPool(KVCache):
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.use_nsa = use_nsa
self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
# TODO do not hardcode
self.kv_cache_dim = (
656
if self.use_nsa and self.nsa_kv_cache_store_fp8
else (kv_lora_rank + qk_rope_head_dim)
)
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
......@@ -1067,7 +1079,7 @@ class MLATokenToKVPool(KVCache):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
(size + page_size, 1, self.kv_cache_dim),
dtype=self.store_dtype,
device=device,
)
......@@ -1130,6 +1142,7 @@ class MLATokenToKVPool(KVCache):
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
......@@ -1147,16 +1160,28 @@ class MLATokenToKVPool(KVCache):
cache_k_rope: torch.Tensor,
):
layer_id = layer.layer_id
if cache_k_nope.dtype != self.dtype:
cache_k_nope = cache_k_nope.to(self.dtype)
cache_k_rope = cache_k_rope.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k_nope = cache_k_nope.view(self.store_dtype)
cache_k_rope = cache_k_rope.view(self.store_dtype)
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
)
if self.use_nsa and self.nsa_kv_cache_store_fp8:
# original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
# TODO no need to cat
cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
cache_k = cache_k.view(self.store_dtype)
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
else:
if cache_k_nope.dtype != self.dtype:
cache_k_nope = cache_k_nope.to(self.dtype)
cache_k_rope = cache_k_rope.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k_nope = cache_k_nope.view(self.store_dtype)
cache_k_rope = cache_k_rope.view(self.store_dtype)
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer],
loc,
cache_k_nope,
cache_k_rope,
)
def get_cpu_copy(self, indices):
torch.cuda.synchronize()
......@@ -1186,6 +1211,103 @@ class MLATokenToKVPool(KVCache):
torch.cuda.synchronize()
class NSATokenToKVPool(MLATokenToKVPool):
def __init__(
self,
size: int,
page_size: int,
kv_lora_rank: int,
dtype: torch.dtype,
qk_rope_head_dim: int,
layer_num: int,
device: str,
index_head_dim: int,
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
):
super().__init__(
size,
page_size,
dtype,
kv_lora_rank,
qk_rope_head_dim,
layer_num,
device,
enable_memory_saver,
start_layer,
end_layer,
use_nsa=True,
)
# self.index_k_dtype = torch.float8_e4m3fn
# self.index_k_scale_dtype = torch.float32
self.index_head_dim = index_head_dim
# num head == 1 and head dim == 128 for index_k in NSA
assert index_head_dim == 128
self.quant_block_size = 128
assert self.page_size == 64
self.index_k_with_scale_buffer = [
torch.zeros(
# Layout:
# ref: test_attention.py :: kv_cache_cast_to_fp8
# shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
# data: for page i,
# * buf[i, :page_size * head_dim] for fp8 data
# * buf[i, page_size * head_dim:].view(float32) for scale
(
(size + page_size + 1) // self.page_size,
self.page_size
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
),
dtype=torch.uint8,
device=device,
)
for _ in range(layer_num)
]
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return self.index_k_with_scale_buffer[layer_id - self.start_layer]
def get_index_k_continuous(
self,
layer_id: int,
seq_len: int,
page_indices: torch.Tensor,
):
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
return index_buf_accessor.GetK.execute(
self, buf, seq_len=seq_len, page_indices=page_indices
)
def get_index_k_scale_continuous(
self,
layer_id: int,
seq_len: int,
page_indices: torch.Tensor,
):
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
return index_buf_accessor.GetS.execute(
self, buf, seq_len=seq_len, page_indices=page_indices
)
# TODO rename later (currently use diff name to avoid confusion)
def set_index_k_and_scale_buffer(
self,
layer_id: int,
loc: torch.Tensor,
index_k: torch.Tensor,
index_k_scale: torch.Tensor,
) -> None:
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
index_buf_accessor.SetKAndS.execute(
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
)
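A shape/dtype sketch (illustrative sizes, not part of this change) of what the setter above stores per token: a 128-dim fp8 index key plus a single fp32 scale, i.e. 128 + 4 = 132 bytes per token and 64 * 132 bytes per page.

import torch

num_tokens, index_head_dim = 3, 128
index_k = torch.randn(num_tokens, index_head_dim).to(torch.float8_e4m3fn)  # fp8 keys
index_k_scale = torch.rand(num_tokens, 1, dtype=torch.float32)             # one scale per token
assert index_k.element_size() == 1 and index_k_scale.element_size() == 4
# per token: 128 fp8 bytes + 4 scale bytes = 132 bytes; per 64-token page: 64 * 132 = 8448 bytes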
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
def __init__(
self,
......@@ -1194,6 +1316,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
dtype: torch.dtype,
kv_lora_rank: int,
qk_rope_head_dim: int,
index_head_dim: Optional[int],
layer_num: int,
device: str,
enable_memory_saver: bool,
......@@ -1213,6 +1336,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.index_head_dim = index_head_dim
self.custom_mem_pool = None
......@@ -1240,6 +1364,18 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
dtype=self.store_dtype,
device=self.device,
)
if self.index_head_dim is not None:
self.index_k_buffer = torch.zeros(
(
layer_num,
self.size // self.page_size + 1,
self.page_size,
1,
self.index_head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
self._finalize_allocation_log(size)
......@@ -1251,6 +1387,10 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
kv_size_bytes += get_tensor_size_bytes(k_cache)
for v_cache in self.v_buffer:
kv_size_bytes += get_tensor_size_bytes(v_cache)
if self.index_head_dim is not None:
assert hasattr(self, "index_k_buffer")
for index_k_cache in self.index_k_buffer:
kv_size_bytes += get_tensor_size_bytes(index_k_cache)
return kv_size_bytes
def get_kv_buffer(self, layer_id: int):
......@@ -1277,6 +1417,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
return self.v_buffer[layer_id - self.start_layer]
def get_index_k_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
return self.index_k_buffer[layer_id - self.start_layer]
# for disagg
def get_contiguous_buf_infos(self):
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
......@@ -1289,6 +1437,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
self.v_buffer[i][0].nbytes for i in range(self.layer_num)
]
if self.index_head_dim is not None:
kv_data_ptrs += [
self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
]
kv_data_lens += [
self.index_k_buffer[i].nbytes for i in range(self.layer_num)
]
kv_item_lens += [
self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def set_kv_buffer(
......@@ -1325,6 +1483,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
cache_v.view(-1, 1, self.qk_rope_head_dim),
)
def set_index_k_buffer(
self,
layer_id: int,
loc: torch.Tensor,
index_k: torch.Tensor,
):
if index_k.dtype != self.dtype:
index_k = index_k.to(self.dtype)
if self.store_dtype != self.dtype:
index_k = index_k.view(self.store_dtype)
torch_npu.npu_scatter_nd_update_(
self.index_k_buffer[layer_id - self.start_layer].view(
-1, 1, self.index_head_dim
),
loc.view(-1, 1),
index_k.view(-1, 1, self.index_head_dim),
)
class DoubleSparseTokenToKVPool(KVCache):
def __init__(
......
......@@ -522,6 +522,7 @@ class CudaGraphRunner:
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
seq_lens_cpu = self.seq_lens_cpu[:bs]
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
if self.is_encoder_decoder:
......@@ -592,6 +593,7 @@ class CudaGraphRunner:
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
next_token_logits_buffer=next_token_logits_buffer,
orig_seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
......