from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, Optional import torch import torch.nn.functional as F from einops import rearrange from torch import nn from sglang.srt.custom_op import CustomOp from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu if is_cuda(): try: import deep_gemm except ImportError as e: deep_gemm = e from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.rotary_embedding import get_rope_wrapper from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import get_global_server_args if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0 class BaseIndexerMetadata(ABC): @abstractmethod def get_seqlens_int32(self) -> torch.Tensor: """ Return: (batch_size,) int32 tensor """ @abstractmethod def get_page_table_64(self) -> torch.Tensor: """ Return: (batch_size, num_blocks) int32, page table. The page size of the table is 64. """ @abstractmethod def get_seqlens_expanded(self) -> torch.Tensor: """ Return: (sum_extend_seq_len,) int32 tensor """ @abstractmethod def topk_transform( self, logits: torch.Tensor, topk: int, ) -> torch.Tensor: """ Perform topk selection on the logits and possibly transform the result. NOTE that attention backend may override this function to do some transformation, which means the result of this topk_transform may not be the topk indices of the input logits. Return: Anything, since it will be passed to the attention backend for further processing on sparse attention computation. Don't assume it is the topk indices of the input logits. """ def rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 from fast_hadamard_transform import hadamard_transform hidden_size = x.size(-1) assert ( hidden_size & (hidden_size - 1) ) == 0, "Hidden size must be a power of 2 for Hadamard transform." return hadamard_transform(x, scale=hidden_size**-0.5) class V32LayerNorm(nn.Module): """ Layer Normalization. """ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) def forward(self, x: torch.Tensor): return F.layer_norm( x.float(), (self.dim,), self.weight, self.bias, self.eps ).type_as(x) class Indexer(CustomOp): def __init__( self, hidden_size: int, index_n_heads: int, index_head_dim: int, rope_head_dim: int, index_topk: int, q_lora_rank: int, max_position_embeddings: int, rope_theta: float, layer_id: int, scale_fmt: Optional[str], block_size: int = 128, rope_scaling: Optional[Dict[str, Any]] = None, prefix: str = "", quant_config: Optional[QuantizationConfig] = None, alt_stream: Optional[torch.cuda.Stream] = None, ): super().__init__() self.hidden_size = hidden_size self.n_heads = index_n_heads self.head_dim = index_head_dim self.rope_head_dim = rope_head_dim self.index_topk = index_topk self.q_lora_rank = q_lora_rank self.layer_id = layer_id self.alt_stream = alt_stream if is_cuda(): self.sm_count = deep_gemm.get_num_sms() self.half_device_sm_count = align(self.sm_count // 2, 8) self.wq_b = ReplicatedLinear( self.q_lora_rank, self.n_heads * self.head_dim, bias=False, quant_config=quant_config, prefix=add_prefix("wq_b", prefix), ) self.wk = ReplicatedLinear( self.hidden_size, self.head_dim, bias=False, quant_config=quant_config, prefix=add_prefix("wk", prefix), ) self.k_norm = V32LayerNorm(self.head_dim) # NOTE: weight_proj is not quantized self.weights_proj = ReplicatedLinear( self.hidden_size, self.n_heads, bias=False, prefix=add_prefix("weights_proj", prefix), ) self.rotary_emb = get_rope_wrapper( rope_head_dim, rotary_dim=rope_head_dim, max_position=max_position_embeddings, base=rope_theta, # type: ignore rope_scaling=rope_scaling, is_neox_style=False, device=get_global_server_args().device, ) self.block_size = block_size self.scale_fmt = scale_fmt self.softmax_scale = self.head_dim**-0.5 @torch.compile(dynamic=True) def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): weights, _ = self.weights_proj(x) weights = weights * self.n_heads**-0.5 weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale return weights def _get_q_k_bf16( self, q_lora: torch.Tensor, x: torch.Tensor, positions: torch.Tensor, enable_dual_stream: bool, ): if enable_dual_stream: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) with deep_gemm_wrapper.configure_deep_gemm_num_sms( self.half_device_sm_count ): query, _ = self.wq_b(q_lora) query = rearrange(query, "l (h d) -> l h d", d=self.head_dim) q_rope, _ = torch.split( query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1, ) with torch.cuda.stream(self.alt_stream): # TODO we should also put DeepGEMM half SM here? key, _ = self.wk(x) key = self.k_norm(key) k_rope, _ = torch.split( key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1, ) current_stream.wait_stream(self.alt_stream) else: query, _ = self.wq_b(q_lora) query = rearrange(query, "l (h d) -> l h d", d=self.head_dim) q_rope, _ = torch.split( query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1 ) key, _ = self.wk(x) key = self.k_norm(key) k_rope, _ = torch.split( key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1 ) q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope) query[..., : self.rope_head_dim] = q_rope key[..., : self.rope_head_dim] = k_rope if enable_dual_stream: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) query = rotate_activation(query) with torch.cuda.stream(self.alt_stream): key = rotate_activation(key) current_stream.wait_stream(self.alt_stream) else: query = rotate_activation(query) key = rotate_activation(key) return query, key def _get_topk_paged( self, forward_batch: ForwardBatch, layer_id: int, q_fp8: torch.Tensor, weights: torch.Tensor, metadata: BaseIndexerMetadata, ) -> torch.Tensor: if TYPE_CHECKING: assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool) page_size = forward_batch.token_to_kv_pool.page_size # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm assert page_size == 64, "only support page size 64" # NOTE(dark): this support extend/decode/decode+graph block_tables = metadata.get_page_table_64() max_seq_len = block_tables.shape[1] * page_size kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer( layer_id=layer_id ) blocksize = page_size if forward_batch.forward_mode.is_target_verify(): seqlens_32 = metadata.get_seqlens_expanded() else: seqlens_32 = metadata.get_seqlens_int32() # NOTE(dark): 132 is SM count on H200/B200, not magic number schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( seqlens_32, blocksize, self.sm_count ) assert len(q_fp8.shape) == 3 q_fp8 = q_fp8.unsqueeze(1) # the next_n dim is 1 now assert len(kv_cache_fp8.shape) == 2 block_kv = 64 num_heads_kv = 1 head_dim_with_sf = 132 kv_cache_fp8 = kv_cache_fp8.view( kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf ) assert len(weights.shape) == 3 weights = weights.squeeze(2) logits = deep_gemm.fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, weights, seqlens_32, block_tables, schedule_metadata, max_seq_len, clean_logits=False, ) # NOTE(dark): logits should be cleaned in topk_transform topk_result = metadata.topk_transform(logits, self.index_topk) return topk_result def _get_topk_ragged( self, forward_batch: ForwardBatch, layer_id: int, q_fp8: torch.Tensor, weights: torch.Tensor, metadata: BaseIndexerMetadata, ) -> torch.Tensor: if TYPE_CHECKING: assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool) page_size = forward_batch.token_to_kv_pool.page_size assert page_size == 64, "only support page size 64" assert len(weights.shape) == 3 weights = weights.squeeze(-1) k_fp8_list = [] k_scale_list = [] ks_list = [] ke_list = [] offset = 0 seq_lens_expanded = metadata.get_seqlens_expanded() block_tables = metadata.get_page_table_64() assert ( forward_batch.seq_lens_cpu is not None and forward_batch.extend_seq_lens_cpu is not None ) for i in range(forward_batch.batch_size): seq_len = forward_batch.seq_lens_cpu[i].item() assert isinstance(seq_len, int) k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous( layer_id, seq_len, block_tables[i], ) k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous( layer_id, seq_len, block_tables[i], ) extend_seq_len = forward_batch.extend_seq_lens_cpu[i] ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda") ke = ks + seq_lens_expanded[offset : offset + extend_seq_len] k_fp8_list.append(k_fp8) k_scale_list.append(k_scale) ks_list.append(ks) ke_list.append(ke) offset += extend_seq_len k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) kv_fp8 = (k_fp8, k_scale) ks = torch.cat(ks_list, dim=0) ke = torch.cat(ke_list, dim=0) logits = deep_gemm.fp8_mqa_logits( q_fp8[:offset], kv_fp8, weights[:offset], ks, ke, clean_logits=False, ) token_nums, _, _ = q_fp8.shape assert logits.shape[0] == len(seq_lens_expanded) raw_topk_result = metadata.topk_transform(logits, self.index_topk) topk_result = torch.full( (token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32 ) topk_result[:offset] = raw_topk_result return topk_result def forward_indexer( self, q_fp8: torch.Tensor, weights: torch.Tensor, forward_batch: ForwardBatch, topk: int, layer_id: int, ) -> Optional[torch.Tensor]: if not is_npu(): from sglang.srt.layers.attention.nsa.tilelang_kernel import fp8_index page_size = forward_batch.token_to_kv_pool.page_size assert page_size == 64, "only support page size 64" assert len(weights.shape) == 3 weights = weights.squeeze(-1) # logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) k_fp8_list = [] k_scale_list = [] topk_indices_list = [] block_tables = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : ] strided_indices = torch.arange( 0, block_tables.shape[-1], page_size, device="cuda" ) block_tables = block_tables[:, strided_indices] // page_size q_len_start = 0 for i in range(forward_batch.batch_size): seq_len = forward_batch.seq_lens[i].item() q_len = ( forward_batch.extend_seq_lens_cpu[i] if forward_batch.forward_mode.is_extend() else 1 ) q_len_end = q_len_start + q_len q_fp8_partial = q_fp8[q_len_start:q_len_end] q_fp8_partial = q_fp8_partial.unsqueeze(0).contiguous() weights_partial = weights[q_len_start:q_len_end] weights_partial = weights_partial.squeeze(-1).unsqueeze(0).contiguous() k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous( layer_id, seq_len, block_tables[i], ) k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous( layer_id, seq_len, block_tables[i], ) k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous() k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous() index_score = fp8_index( q_fp8_partial, weights_partial, k_fp8, k_scale, ) end_pos = seq_len topk_indices = index_score.topk(min(topk, end_pos), dim=-1)[1].squeeze(0) pad_len = align(topk_indices.shape[-1], 2048) - topk_indices.shape[-1] topk_indices = torch.nn.functional.pad( topk_indices, (0, pad_len), "constant", -1 ) topk_indices_list.append(topk_indices) q_len_start = q_len_end topk_indices = torch.cat(topk_indices_list, dim=0) return topk_indices def forward_cuda( self, x: torch.Tensor, q_lora: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, layer_id: int, ) -> Optional[torch.Tensor]: if is_hip(): from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant elif not is_npu(): from sglang.srt.layers.attention.nsa.triton_kernel import act_quant if TYPE_CHECKING: assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool) metadata = forward_batch.attn_backend.get_indexer_metadata( layer_id, forward_batch ) enable_dual_stream = ( NSA_DUAL_STREAM and self.alt_stream is not None and get_is_capture_mode() and q_lora.shape[0] > 0 and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD ) # skip NSA if attention backend choose to skip this batch if metadata is None: return None query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream) if enable_dual_stream: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt) with torch.cuda.stream(self.alt_stream): k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt) current_stream.wait_stream(self.alt_stream) else: q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt) # k_fp8: (seq_len, head_dim) fp8_e4m3fn # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn if not forward_batch.out_cache_loc.is_contiguous(): forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous() forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer( layer_id=layer_id, loc=forward_batch.out_cache_loc, index_k=k_fp8, index_k_scale=k_scale, ) weights = self._get_logits_head_gate(x, q_scale) if is_cuda(): assert forward_batch.seq_lens_cpu is not None if len(forward_batch.seq_lens_cpu) == 0: # this seems b/c max-pad, no worries? # if x.shape[0] != 0: # print( # "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result" # ) return torch.full( (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda" ) if ( forward_batch.forward_mode.is_decode_or_idle() or forward_batch.forward_mode.is_target_verify() ): topk_result = self._get_topk_paged( forward_batch, layer_id, q_fp8, weights, metadata ) else: topk_result = self._get_topk_ragged( forward_batch, layer_id, q_fp8, weights, metadata ) else: topk_result = self.forward_indexer( q_fp8.contiguous(), weights, forward_batch, topk=self.index_topk, layer_id=layer_id, ) return topk_result def forward_npu( self, x: torch.Tensor, q_lora: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, layer_id: int, ) -> torch.Tensor: import custom_ops # noqa: F401 import torch_npu from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, ) from sglang.srt.utils import get_bool_env_var if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None: actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens else: actual_seq_lengths_kv = ( forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int ) enable_index_cp = ( get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4 ) is_prefill = forward_batch.forward_mode.is_extend() attention_tp_rank = get_attention_tp_rank() attention_tp_size = get_attention_tp_size() cos_sin = self.rotary_emb.cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) if is_prefill and enable_index_cp: slice_length = cos.shape[0] // attention_tp_size cos = cos[ slice_length * attention_tp_rank : slice_length * (attention_tp_rank + 1) ] sin = sin[ slice_length * attention_tp_rank : slice_length * (attention_tp_rank + 1) ] slot_mapping = forward_batch.out_cache_loc block_table = forward_batch.attn_backend.forward_metadata.block_tables bs = x.shape[0] q = self.wq_b(q_lora)[0] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128] q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128] q_pe, q_nope = torch.split( q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1, ) # [bs, 64, 64 + 64] q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim) q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view( bs, self.n_heads, self.rope_head_dim ) # [bs, n, d] q = torch.cat([q_pe, q_nope], dim=-1) k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128] k = self.k_norm(k_proj) k_pe, k_nope = torch.split( k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1, ) # [bs, 64 + 64] k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim) k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view( bs, 1, self.rope_head_dim ) # [bs, 1, d] k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128] if is_prefill and enable_index_cp: k, local_k = ( torch.empty( (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]), dtype=k.dtype, device=k.device, ), k, ) get_attention_tp_group().all_gather_into_tensor(k, local_k) forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k) indexer_input = {} if is_prefill: actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device) actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to( device=q.device ) if enable_index_cp: actual_seq_lengths_q -= bs * attention_tp_rank actual_seq_lengths_q = torch.max( actual_seq_lengths_q, torch.zeros_like(actual_seq_lengths_q).to( device=actual_seq_lengths_q.device ), ) actual_seq_lengths_q = torch.min( actual_seq_lengths_q, torch.full(actual_seq_lengths_q.shape, bs).to( device=actual_seq_lengths_q.device ), ) else: if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None: actual_seq_lengths_q = torch.tensor( [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device ) else: actual_seq_lengths_q = ( forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q ) past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id) x = x.view(-1, self.hidden_size) weights = self.weights_proj(x)[0] block_table = ( block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table ) topk_indices = torch.ops.custom.npu_lightning_indexer( query=q.view(-1, self.n_heads, self.head_dim), key=past_key_states, weights=weights, actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32), actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32), block_table=block_table, layout_query="TND", layout_key="PA_BSND", sparse_count=self.index_topk, sparse_mode=3, ) if is_prefill and enable_index_cp: topk_indices, local_topk_indices = ( torch.empty( ( topk_indices.shape[0] * attention_tp_size, topk_indices.shape[1], topk_indices.shape[2], ), dtype=topk_indices.dtype, device=topk_indices.device, ), topk_indices, ) get_attention_tp_group().all_gather_into_tensor( topk_indices, local_topk_indices ) return topk_indices