Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
if TYPE_CHECKING:
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
def support_triton(self):
"""Check if the current backend supports triton."""
return True
def get_indexer_metadata(
self,
layer_id: int,
forward_batch: ForwardBatch,
) -> Optional[BaseIndexerMetadata]:
"""Get the indexer metadata. None means don't support indexer."""
return None
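For context, a minimal sketch of how a concrete backend might override this new hook; the MyIndexerBackend class and its _indexer_metadata attribute are hypothetical and not part of this change:

class MyIndexerBackend(AttentionBackend):
    # Hypothetical backend: assumes the metadata object was built earlier,
    # e.g. in init_forward_metadata(). Returning it opts this backend into
    # the NSA indexer path; the base-class default of None opts out.
    def get_indexer_metadata(
        self,
        layer_id: int,
        forward_batch: ForwardBatch,
    ) -> Optional[BaseIndexerMetadata]:
        return self._indexer_metadata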
@@ -692,13 +692,8 @@ class FlashAttentionBackend(AttentionBackend):
k_descale, v_descale = None, None
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
-# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
-# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
-if (
-    self.kv_cache_dtype_str != "auto"
-    and layer.head_dim <= 256
-    and self.fa_impl_ver != 4
-):
+# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
if layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
...
@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.utils import (
get_int_env_var,
is_flashinfer_available,
@@ -344,7 +344,9 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
@@ -451,7 +453,9 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
@@ -669,7 +673,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -684,7 +690,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -710,7 +718,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -760,7 +770,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -794,7 +806,9 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
seq_lens_cpu: Optional[torch.Tensor],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
@@ -905,7 +919,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
fixed_split_size: Optional[int] = None,
):
# Keep the signature for type checking. It will be assigned during runtime.
@@ -921,7 +937,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
fixed_split_size: Optional[int] = None,
):
if use_ragged:
@@ -959,7 +977,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1006,7 +1026,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1049,7 +1071,9 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
-spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+spec_info: Optional[
+    Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
):
@@ -1078,7 +1102,7 @@ class FlashInferIndicesUpdaterPrefill:
custom_mask = None
else:
assert isinstance(
-    spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
+    spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
...
@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
return backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
def get_indexer_metadata(
self, layer_id: int, forward_batch: ForwardBatch
) -> Optional[BaseIndexerMetadata]:
backend = self._select_backend(forward_batch.forward_mode)
return backend.get_indexer_metadata(layer_id, forward_batch)
@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
self.rotary_emb = rotary_emb
self.layer_id = layer_id
self.has_preprocess_weights = False
self.dtype = None
self.q_lora_rank = self.q_b_proj.input_size # 1536
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
self.num_local_heads = num_local_heads # tp
self.qk_nope_head_dim = qk_nope_head_dim # 128
self.qk_rope_head_dim = qk_rope_head_dim # 64
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
def preprocess_weights(self, hidden_states):
self.dummy = torch.empty(
@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
return k_cache, v_cache, slot_mapping
-def forward(self, positions, hidden_states, forward_batch, zero_allocator):
+def forward_absorb_prepare_npu_rms_norm_cache(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch,
zero_allocator,
):
bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
self.dtype = hidden_states.dtype
self.cos, self.sin = self.get_sin_cos(positions)
self.kvCache, self.kvCacheRope, self.slotmapping = (
self.get_kv_cache_and_cache_idx(forward_batch)
)
if not self.has_preprocess_weights:
self.has_preprocess_weights = True
cos, sin = self.cos, self.sin
if self.q_lora_rank is not None:
fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q = self.q_a_layernorm(q_lowrank)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
) # b*s,n,d
q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D)
q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
latent_cache = latent_cache.view(
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
) # (B*S,N,1,D)
cache_mode = "PA_BNSD"
self.kvCache = self.kvCache.view(
-1,
forward_batch.attn_backend.page_size,
1,
forward_batch.attn_backend.kv_lora_rank,
)
self.kvCacheRope = self.kvCacheRope.view(
-1,
forward_batch.attn_backend.page_size,
1,
forward_batch.attn_backend.qk_rope_head_dim,
)
k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
latent_cache,
self.kv_a_layernorm.weight,
cos,
sin,
self.slotmapping.to(torch.int64),
self.kvCacheRope,
self.kvCache,
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
input_dtype = hidden_states.dtype
if not self.has_preprocess_weights:
self.preprocess_weights(hidden_states)
@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
zero_allocator,
positions,
)
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
_is_w8a8 = (
hasattr(self.qkv_a_proj.quant_method, "quantization_config")
and self.qkv_a_proj.quant_method.quantization_config.get_name()
== "w8a8_int8"
)
if _is_w8a8:
return self.forward_mlapo(
positions, hidden_states, forward_batch, zero_allocator
)
else:
return self.forward_absorb_prepare_npu_rms_norm_cache(
positions, hidden_states, forward_batch, zero_allocator
)
from .topk import fast_topk, fast_topk_transform
__all__ = ["fast_topk", "fast_topk_transform"]
from __future__ import annotations
from typing import Any
import torch
from .utils import load_kernel_module
def _load_topk_module() -> Any:
"""
Load the top-k kernel module.
"""
return load_kernel_module("topk.cu", "topk_kernel")
# TODO(dark): figure out why my cuda impl is a little slower....
# I believe it has something to do with unrolling loops (?)
_USE_TL = True
def fast_topk(
score: torch.Tensor,
indices: torch.Tensor,
lengths: torch.Tensor,
) -> torch.Tensor:
return _load_topk_module().fast_topk(score, indices, lengths, _USE_TL)
def fast_topk_transform(
score: torch.Tensor,
lengths: torch.Tensor,
dst_page_table: torch.Tensor,
src_page_table: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
return _load_topk_module().fast_topk_transform(
score, lengths, dst_page_table, src_page_table, cu_seqlens, _USE_TL
)
from __future__ import annotations
import os
from functools import lru_cache
from typing import Any, Iterable
@lru_cache()
def _prepare_for_load() -> str:
import os
import warnings
warnings.filterwarnings(
"ignore", category=UserWarning, module="torch.utils.cpp_extension"
)
return os.path.dirname(os.path.abspath(__file__))
@lru_cache()
def load_kernel_module(
path: str | Iterable[str],
name: str,
*,
build: str = "build",
cflags: Iterable[str] | None = None,
cuda_flags: Iterable[str] | None = None,
ldflags: Iterable[str] | None = None,
) -> Any:
from torch.utils.cpp_extension import load
if isinstance(path, str):
path = (path,)
abs_path = _prepare_for_load()
build_dir = f"{abs_path}/{build}"
os.makedirs(build_dir, exist_ok=True)
return load(
name=name,
sources=[f"{abs_path}/csrc/{p}" for p in path],
extra_cflags=list(cflags or []) or ["-O3", "-std=c++17"],
extra_cuda_cflags=list(cuda_flags or []) or ["-O3", "-std=c++17"],
extra_ldflags=list(ldflags or []) or None,
build_directory=build_dir,
)
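For reference, a hedged usage sketch of the helper above; the file and module names mirror the ones used by the top-k wrapper in this change, while the profiling variant and its extra flag are purely illustrative:

# First call compiles csrc/topk.cu into <package dir>/build via
# torch.utils.cpp_extension.load; repeated calls are served from the lru_cache.
topk_module = load_kernel_module("topk.cu", "topk_kernel")

# Flags can be overridden explicitly, e.g. adding -lineinfo for profiling builds
# (illustrative only; the defaults above are -O3 -std=c++17).
topk_profiling_module = load_kernel_module(
    "topk.cu",
    "topk_kernel_lineinfo",
    cuda_flags=["-O3", "-std=c++17", "-lineinfo"],
)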
import torch
from sglang.srt.utils import align
# NOTE(dark): flashmla P requires `params.topk % (2*B_TOPK) == 0`,
# where `B_TOPK=64`. So we align to 128 by default.
_TOPK_ALIGNMENT = 128
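Concretely, with B_TOPK = 64 the kernel needs topk to be a multiple of 2 * 64 = 128, so callers pass aligned values such as 2048, and the torch fallback below pads the sequence dimension with the same alignment. A small illustration of the rounding; the _round_up helper is an illustrative stand-in for sglang.srt.utils.align, assuming it rounds up to the next multiple:

def _round_up(x: int, multiple: int) -> int:
    # Illustrative stand-in for align(): round x up to the next multiple.
    return (x + multiple - 1) // multiple * multiple

assert _round_up(2000, _TOPK_ALIGNMENT) == 2048  # 2048 % (2 * 64) == 0
assert _round_up(2048, _TOPK_ALIGNMENT) == 2048  # already-aligned values are unchanged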
# TODO(dark): maybe this torch_op can support torch.compile
def _fast_topk_torch(
input: torch.Tensor, seq_lens: torch.Tensor, topk: int, alignment: int
) -> torch.Tensor:
# Fallback to torch.topk
bs, max_seq_len = input.shape
assert len(seq_lens) == bs
# set those out-of-bound input to -inf
padded_max_seq_len = align(max_seq_len, alignment)
positions = torch.arange(
padded_max_seq_len, device=input.device, dtype=seq_lens.dtype
)
positions = positions.unsqueeze(0).expand(bs, -1)
mask = positions >= seq_lens.unsqueeze(1)
# NOTE(dark): just return all valid indices as an optimization
if padded_max_seq_len <= topk:
return positions.masked_fill(mask, -1)
assert topk % alignment == 0
# in-place operation: mask invalid inputs to -inf
input = input.masked_fill_(mask[:, :max_seq_len], float("-inf"))
result = input.topk(topk, dim=-1, sorted=True)
return result.indices.masked_fill_(mask[:, :topk], -1)
def fast_topk_impl(
input: torch.Tensor,
seq_lens: torch.Tensor,
topk: int,
alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
return _fast_topk_torch(input, seq_lens, topk, alignment)
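A hedged usage sketch of the torch fallback above; the shapes and values are illustrative, and the comments assume align() rounds 300 up to 384:

# Two requests with different numbers of valid tokens; scores past each
# request's seq_len are masked out before top-k selection.
scores = torch.randn(2, 300, device="cuda")
seq_lens = torch.tensor([300, 150], device="cuda", dtype=torch.int32)

# padded_max_seq_len (384) <= topk (512), so the early-return path fires:
# all valid position indices are returned and padded slots are set to -1.
indices = fast_topk_impl(scores, seq_lens, topk=512)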
def fast_topk_transform_fused_cuda(
input: torch.Tensor,
seq_lens: torch.Tensor,
topk: int,
dst_page_table: torch.Tensor,
src_page_table: torch.Tensor,
cu_seqlens_q: torch.Tensor,
alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
from sglang.srt.layers.attention.nsa.cuda import fast_topk_transform
assert topk == 2048 and topk % alignment == 0
return fast_topk_transform(
score=input,
lengths=seq_lens,
dst_page_table=dst_page_table,
src_page_table=src_page_table,
cu_seqlens=cu_seqlens_q,
)