change / sglang · Commits · 852a49c5

Commit 852a49c5, authored Sep 30, 2025 by maxiao

    adapt to dsv32 on dcu

Parent: 8f7453e3

Showing 20 changed files with 4298 additions and 66 deletions (+4298, -66)
python/sglang/srt/layers/attention/base_attn_backend.py             +9    -0
python/sglang/srt/layers/attention/flashattention_backend.py        +2    -7
python/sglang/srt/layers/attention/flashinfer_backend.py            +38   -14
python/sglang/srt/layers/attention/flashmla_backend.py              +161  -44
python/sglang/srt/layers/attention/hybrid_attn_backend.py           +7    -0
python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py        +94   -1
python/sglang/srt/layers/attention/nsa/cuda/__init__.py             +3    -0
python/sglang/srt/layers/attention/nsa/cuda/csrc/topk.cu            +505  -0
python/sglang/srt/layers/attention/nsa/cuda/topk.py                 +39   -0
python/sglang/srt/layers/attention/nsa/cuda/utils.py                +44   -0
python/sglang/srt/layers/attention/nsa/dequant_k_cache.py           +163  -0
python/sglang/srt/layers/attention/nsa/index_buf_accessor.py        +354  -0
python/sglang/srt/layers/attention/nsa/nsa_indexer.py               +682  -0
python/sglang/srt/layers/attention/nsa/quant_k_cache.py             +255  -0
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py           +774  -0
python/sglang/srt/layers/attention/nsa/topk.py                      +65   -0
python/sglang/srt/layers/attention/nsa/transform_index.py           +144  -0
python/sglang/srt/layers/attention/nsa/unit_test/get_logits_ut.py   +57   -0
python/sglang/srt/layers/attention/nsa/utils.py                     +32   -0
python/sglang/srt/layers/attention/nsa_backend.py                   +870  -0
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
     def support_triton(self):
         """Check if the current backend supports triton."""
         return True
+
+    def get_indexer_metadata(
+        self,
+        layer_id: int,
+        forward_batch: ForwardBatch,
+    ) -> Optional[BaseIndexerMetadata]:
+        """Get the indexer metadata. None means don't support indexer."""
+        return None
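
The new get_indexer_metadata hook defaults to returning None, so existing backends keep working unchanged, while an NSA-aware backend can return per-batch indexer metadata. The following is a minimal, self-contained sketch of that contract (not code from the commit; the two stub classes are illustrative only):

from typing import Optional


class _FakeIndexerMetadata:
    """Stand-in for BaseIndexerMetadata, for illustration only."""


class _DenseBackend:
    def get_indexer_metadata(self, layer_id: int, forward_batch) -> Optional[_FakeIndexerMetadata]:
        return None  # default behaviour added by this commit: no indexer support


class _NsaBackend:
    def get_indexer_metadata(self, layer_id: int, forward_batch) -> Optional[_FakeIndexerMetadata]:
        return _FakeIndexerMetadata()  # an NSA-aware backend exposes real metadata


for backend in (_DenseBackend(), _NsaBackend()):
    md = backend.get_indexer_metadata(layer_id=0, forward_batch=None)
    print(type(backend).__name__, "indexer supported:", md is not None)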
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -692,13 +692,8 @@ class FlashAttentionBackend(AttentionBackend):
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
-        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
-        # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
-        if (
-            self.kv_cache_dtype_str != "auto"
-            and layer.head_dim <= 256
-            and self.fa_impl_ver != 4
-        ):
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
 from sglang.srt.utils import (
     get_int_env_var,
     is_flashinfer_available,

@@ -344,7 +344,9 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []

@@ -451,7 +453,9 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():

@@ -669,7 +673,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):

@@ -684,7 +690,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):

@@ -710,7 +718,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):

@@ -760,7 +770,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):

@@ -794,7 +806,9 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,

@@ -905,7 +919,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.

@@ -921,7 +937,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
     ):
         if use_ragged:

@@ -959,7 +977,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):

@@ -1006,7 +1026,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):

@@ -1049,7 +1071,9 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
     ):

@@ -1078,7 +1102,7 @@ class FlashInferIndicesUpdaterPrefill:
             custom_mask = None
         else:
             assert isinstance(
-                spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
+                spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
             )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
python/sglang/srt/layers/attention/flashmla_backend.py

@@ -5,13 +5,20 @@ Support attention backend for FlashMLA.
 """
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import triton
 from flash_mla import flash_mla_with_kvcache, get_mla_metadata
 
+from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
+from sglang.srt.layers.attention.nsa.utils import (
+    NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+    NSA_KV_CACHE_STORE_FP8,
+    compute_nsa_seqlens,
+)
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

@@ -74,10 +81,17 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+        self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
+        self.nsa_index_topk = (
+            get_nsa_index_topk(model_runner.model_config.hf_config)
+            if self.use_nsa
+            else None
+        )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size

@@ -100,10 +114,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 max_seqlen_pad,
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                forward_batch.seq_lens.to(torch.int32),
-                self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+                seq_len_q=1,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             self.forward_metadata = FlashMLADecodeMetadata(
                 mla_metadata,

@@ -130,10 +146,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 max_seqlen_pad,
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=self.num_draft_tokens,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             # Use FlashMLADecodeMetadata which has the attributes forward_extend expects

@@ -162,20 +180,28 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         cuda_graph_kv_indices = block_kv_indices
 
         if self.num_draft_tokens:
-            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
-                torch.ones(
-                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
-                ),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
+            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
+                _get_mla_metadata_wrapped(
+                    cache_seqlens=torch.ones(
+                        max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+                    ),
+                    seq_len_q=self.num_draft_tokens,
+                    num_heads_q=self.num_q_heads,
+                    num_heads_k=1,
+                    nsa_index_topk=self.nsa_index_topk,
+                )
             )
         else:
-            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
-                torch.ones(
-                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
-                ),
-                self.num_q_heads,
-                1,
+            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
+                _get_mla_metadata_wrapped(
+                    cache_seqlens=torch.ones(
+                        max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+                    ),
+                    seq_len_q=1,
+                    num_heads_q=self.num_q_heads,
+                    num_heads_k=1,
+                    nsa_index_topk=self.nsa_index_topk,
+                )
             )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices

@@ -201,10 +227,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=1,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)

@@ -226,10 +254,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=self.num_draft_tokens,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
            )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)

@@ -275,10 +305,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=1,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)

@@ -300,10 +332,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
+            mla_metadata, num_splits = _get_mla_metadata_wrapped(
+                cache_seqlens=seq_lens.to(torch.int32),
+                seq_len_q=self.num_draft_tokens,
+                num_heads_q=self.num_q_heads,
+                num_heads_k=1,
+                nsa_index_topk=self.nsa_index_topk,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)

@@ -335,6 +369,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc

@@ -349,13 +384,14 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
 
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        k_cache = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
-        if self.data_type == torch.float8_e4m3fn:
+        if (not self.use_nsa) and self.data_type == torch.float8_e4m3fn:
             reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
             o, _ = flash_mla_with_kvcache(
                 q=reshape_q_fp8,
-                k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+                k_cache=k_cache,
                 block_table=self.forward_metadata.block_kv_indices[:bs],
                 cache_seqlens=forward_batch.seq_lens.to(torch.int32),
                 head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.

@@ -369,17 +405,49 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
+            block_table = self.forward_metadata.block_kv_indices[:bs]
+            cache_seqlens = forward_batch.seq_lens.to(torch.int32)
+            extra_kwargs: Dict
+            if self.use_nsa:
+                assert topk_indices is not None
+                extra_kwargs = dict(
+                    indices=_compute_indices_in_kvcache(
+                        block_table=block_table,
+                        topk_indices=topk_indices.to(torch.int32),
+                        page_size=self.page_size,
+                    ),
+                    # doc says it is not used, but if pass in None then error
+                    block_table=block_table,
+                    is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+                )
+                cache_seqlens = compute_nsa_seqlens(
+                    cache_seqlens, nsa_index_topk=self.nsa_index_topk
+                )
+            else:
+                extra_kwargs = dict(
+                    block_table=block_table,
+                    causal=True,
+                )
+            if (
+                self.use_nsa
+                and NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8
+                and not NSA_KV_CACHE_STORE_FP8
+            ):
+                # inefficiently quantize the whole cache
+                k_cache = quantize_k_cache(k_cache)
+
             # todo: need check all causal True or False?
             o, _ = flash_mla_with_kvcache(
                 q=reshape_q,
-                k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
-                block_table=self.forward_metadata.block_kv_indices[:bs],
-                cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+                k_cache=k_cache,
+                cache_seqlens=cache_seqlens,
                 head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
                 tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
                 num_splits=self.forward_metadata.num_splits,
                 softmax_scale=layer.scaling,
-                causal=True,
+                **extra_kwargs,
             )
 
             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

@@ -539,3 +607,52 @@ class FlashMLAMultiStepDraftBackend:
             )
 
         self.common_template(forward_batch, call_fn)
+
+
+def _get_mla_metadata_wrapped(
+    *,
+    cache_seqlens: torch.Tensor,
+    seq_len_q: int,
+    num_heads_q: int,
+    num_heads_k: int,
+    nsa_index_topk: Optional[int],
+):
+    if nsa_index_topk is not None:
+        assert nsa_index_topk is not None
+        return get_mla_metadata(
+            cache_seqlens=cache_seqlens,
+            # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
+            # but the name looks like need seq_len_q?
+            num_q_tokens_per_head_k=seq_len_q * num_heads_q // num_heads_k,
+            num_heads_k=num_heads_k,
+            num_heads_q=num_heads_q,
+            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+            topk=nsa_index_topk,
+        )
+    else:
+        assert nsa_index_topk is None
+        return get_mla_metadata(
+            cache_seqlens=cache_seqlens,
+            num_heads_per_head_k=seq_len_q * num_heads_q // num_heads_k,
+            num_heads_k=num_heads_k,
+        )
+
+
+# TODO speedup
+def _compute_indices_in_kvcache(block_table, topk_indices, page_size):
+    topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
+    idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
+        1
+    )
+    block_idx = block_table[idx0, topk_indices_safe // page_size]
+    offset = topk_indices_safe % page_size
+    indices_in_kvcache = block_idx * page_size + offset
+
+    # the kernel requires invalid entry to be -1
+    assert indices_in_kvcache.shape == topk_indices.shape
+    indices_in_kvcache[topk_indices == -1] = -1
+
+    # return: (batch_size, seqlen_q_ori, topk)
+    indices_in_kvcache = indices_in_kvcache[:, None, :]
+    return indices_in_kvcache
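
The helper _compute_indices_in_kvcache maps the indexer's per-request top-k token positions into flat slot indices of the paged KV cache: each position is split into a logical page (topk_indices // page_size) and an in-page offset, the page is looked up in block_table, and entries padded with -1 stay -1. A small self-contained numeric sketch of that mapping (the block table, page size, and indices are made up for illustration, not values from the commit):

import torch

# One request, page_size 4, logical pages 0, 1, 2 stored at physical pages 7, 3, 9.
page_size = 4
block_table = torch.tensor([[7, 3, 9]])
topk_indices = torch.tensor([[0, 5, 10, -1]])  # -1 marks a padded (invalid) slot

safe = topk_indices.masked_fill(topk_indices == -1, 0)
idx0 = torch.arange(block_table.size(0)).unsqueeze(1)
slots = block_table[idx0, safe // page_size] * page_size + safe % page_size
slots[topk_indices == -1] = -1

print(slots)  # tensor([[28, 13, 38, -1]]): e.g. token 5 -> page 1 -> physical 3*4+1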
python/sglang/srt/layers/attention/hybrid_attn_backend.py

@@ -3,6 +3,7 @@ from typing import Optional, Union
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner

@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
         return backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
+
+    def get_indexer_metadata(
+        self, layer_id: int, forward_batch: ForwardBatch
+    ) -> Optional[BaseIndexerMetadata]:
+        backend = self._select_backend(forward_batch.forward_mode)
+        return backend.get_indexer_metadata(layer_id, forward_batch)
python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py

@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
         self.rotary_emb = rotary_emb
         self.layer_id = layer_id
         self.has_preprocess_weights = False
+        self.dtype = None
         self.q_lora_rank = self.q_b_proj.input_size  # 1536
         self.kv_lora_rank = self.kv_a_layernorm.hidden_size  # 512
         self.num_local_heads = num_local_heads  # tp
         self.qk_nope_head_dim = qk_nope_head_dim  # 128
         self.qk_rope_head_dim = qk_rope_head_dim  # 64
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
 
     def preprocess_weights(self, hidden_states):
         self.dummy = torch.empty(

@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
         slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
         return k_cache, v_cache, slot_mapping
 
-    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
+    def forward_absorb_prepare_npu_rms_norm_cache(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch,
+        zero_allocator,
+    ):
+        bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
+        self.dtype = hidden_states.dtype
+        self.cos, self.sin = self.get_sin_cos(positions)
+        self.kvCache, self.kvCacheRope, self.slotmapping = (
+            self.get_kv_cache_and_cache_idx(forward_batch)
+        )
+        if not self.has_preprocess_weights:
+            self.has_preprocess_weights = True
+        cos, sin = self.cos, self.sin
+
+        if self.q_lora_rank is not None:
+            fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
+            q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q = self.q_a_layernorm(q_lowrank)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+        # b*s,n,d
+        q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
+        q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
+
+        q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
+        cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
+        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
+        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin)  # (B,N,S,D)
+        q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
+
+        latent_cache = latent_cache.view(
+            -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+        )  # (B*S,N,1,D)
+        cache_mode = "PA_BNSD"
+        self.kvCache = self.kvCache.view(
+            -1,
+            forward_batch.attn_backend.page_size,
+            1,
+            forward_batch.attn_backend.kv_lora_rank,
+        )
+        self.kvCacheRope = self.kvCacheRope.view(
+            -1,
+            forward_batch.attn_backend.page_size,
+            1,
+            forward_batch.attn_backend.qk_rope_head_dim,
+        )
+        k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            latent_cache,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            self.slotmapping.to(torch.int64),
+            self.kvCacheRope,
+            self.kvCache,
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
+        return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
+
+    def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
         input_dtype = hidden_states.dtype
         if not self.has_preprocess_weights:
             self.preprocess_weights(hidden_states)

@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
             zero_allocator,
             positions,
         )
+
+    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
+        _is_w8a8 = (
+            hasattr(self.qkv_a_proj.quant_method, "quantization_config")
+            and self.qkv_a_proj.quant_method.quantization_config.get_name()
+            == "w8a8_int8"
+        )
+        if _is_w8a8:
+            return self.forward_mlapo(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        else:
+            return self.forward_absorb_prepare_npu_rms_norm_cache(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
python/sglang/srt/layers/attention/nsa/cuda/__init__.py  0 → 100644

from .topk import fast_topk, fast_topk_transform

__all__ = ["fast_topk", "fast_topk_transform"]
python/sglang/srt/layers/attention/nsa/cuda/csrc/topk.cu  0 → 100644

#include <ATen/core/TensorBase.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/python.h>

namespace {

constexpr int TopK = 2048;
constexpr int kThreadsPerBlock = 1024;
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t);  // 128KB

struct FastTopKParams {
  const float* __restrict__ input;  // [B, input_stride]
  int32_t* __restrict__ indices;    // [B, TopK]
  int32_t* __restrict__ lengths;    // [B]
  int64_t input_stride;
  bool use_tilelang;
};

// when length <= TopK, we can directly write the indices
__device__ void naive_topk_cuda(
    const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
  const auto tid = threadIdx.x;
  for (int i = tid; i < TopK; i += kThreadsPerBlock) {
    indice[i] = (i < length) ? i : -1;
  }
}

// keep the first `length` entries, set others to -1
__device__ void naive_topk_transform(
    const float* __restrict__ score,
    int32_t length,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table) {
  const auto tid = threadIdx.x;
  for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
    dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
  }
}

__device__ __forceinline__ uint8_t convert_to_uint8(float x) {
  __half h = __float2half_rn(x);
  uint16_t bits = __half_as_ushort(h);
  uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits & 0xFFFF)
                                 : static_cast<uint16_t>(bits | 0x8000);
  return static_cast<uint8_t>(key >> 8);
}

__device__ __forceinline__ uint32_t convert_to_uint32(float x) {
  uint32_t bits = __float_as_uint(x);
  return (bits & 0x80000000u) ? (~bits & 0xFFFFFFFFu) : (bits | 0x80000000u);
}

template <bool Is_Epilogue = false, typename Indexer, typename Loader, int LENGTH, int MAX_REMAIN>
__device__ __forceinline__ auto radix_topk(
    Indexer indexer,
    Loader loader,
    uint32_t length,
    int topk,
    int* __restrict__ index,
    int& __restrict__ s_counter,
    int (&__restrict__ s_histogram)[LENGTH],
    int& __restrict__ s_remain_cnt,
    int (&__restrict__ s_remain_idx)[MAX_REMAIN]) -> int {
  constexpr auto RADIX = LENGTH - 1;
  static_assert(RADIX > 1 && (RADIX & (RADIX - 1)) == 0, "RADIX must be power of 2");
  static_assert(RADIX <= kThreadsPerBlock);
  __shared__ uint32_t s_threshold_bin_id;
  const auto tx = threadIdx.x;
  if (tx < RADIX + 1) s_histogram[tx] = 0;
  __syncthreads();

  /// NOTE: Use uint32_t as the index
  for (auto i = tx; i < length; i += kThreadsPerBlock) {
    const auto idx = indexer(i);
    const auto bin = loader(idx);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // cumsum (descending)
  if (tx == 0) {
    s_histogram[RADIX] = 0;
    s_remain_cnt = 0;
    for (int i = RADIX - 2; i >= 0; --i) {
      s_histogram[i] += s_histogram[i + 1];
    }
    // threshold bin
    for (int i = 0; i < RADIX; i++) {
      if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
        s_threshold_bin_id = i;
        break;
      }
    }
  }
  __syncthreads();

  const auto threshold_bin = s_threshold_bin_id;
  const auto new_topk = topk - s_histogram[threshold_bin + 1];
  for (auto i = tx; i < length; i += kThreadsPerBlock) {
    const auto idx = indexer(i);
    const auto bin_id = static_cast<uint32_t>(loader(idx));
    if (bin_id > threshold_bin) {
      index[::atomicAdd(&s_counter, 1)] = idx;
    } else if (bin_id == threshold_bin && new_topk > 0) {
      if constexpr (Is_Epilogue) {
        index[::atomicAdd(&s_counter, 1)] = idx;
      } else {
        if (const auto cnt = ::atomicAdd(&s_remain_cnt, 1); C10_LIKELY(cnt < MAX_REMAIN)) {
          s_remain_idx[cnt] = idx;
        }
      }
    }
  }
  __syncthreads();
  return new_topk;
}

__device__ void fast_topk_cuda(
    const float* __restrict__ input, int* __restrict__ index, int length, int topk = TopK) {
  constexpr auto RADIX = 256;
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
  __shared__ int s_histogram[RADIX + 1];
  __shared__ int s_num_input[2];
  __shared__ int s_counter;
  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
  s_counter = 0;

  // collect candidates
  const auto indexer = [](int idx) { return idx; };
  const auto loader = [&input](int idx) { return convert_to_uint8(input[idx]); };
  int new_topk = radix_topk(
      indexer, loader, length, topk, index, s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
  if (new_topk <= 0) return;

  // round 0
  const auto indexer_0 = [](int idx) { return s_input_idx[0][idx]; };
  const auto loader_0 = [&input](int idx) { return (convert_to_uint32(input[idx]) >> 24) & 0xFF; };
  new_topk = radix_topk(
      indexer_0, loader_0, s_num_input[0], new_topk, index, s_counter, s_histogram,
      s_num_input[1], s_input_idx[1]);
  if (new_topk <= 0) return;

  // round 1
  const auto indexer_1 = [](int idx) { return s_input_idx[1][idx]; };
  const auto loader_1 = [&input](int idx) { return (convert_to_uint32(input[idx]) >> 16) & 0xFF; };
  new_topk = radix_topk(
      indexer_1, loader_1, s_num_input[1], new_topk, index, s_counter, s_histogram,
      s_num_input[0], s_input_idx[0]);
  if (new_topk <= 0) return;

  // round 2
  const auto loader_2 = [&input](int idx) { return (convert_to_uint32(input[idx]) >> 8) & 0xFF; };
  new_topk = radix_topk(
      indexer_0, loader_2, s_num_input[0], new_topk, index, s_counter, s_histogram,
      s_num_input[1], s_input_idx[1]);
  if (new_topk <= 0) return;

  // round 3
  const auto loader_3 = [&input](int idx) { return convert_to_uint32(input[idx]) & 0xFF; };
  // epilogue
  radix_topk<true>(
      indexer_1, loader_3, s_num_input[1], new_topk, index, s_counter, s_histogram,
      s_num_input[0], s_input_idx[0]);
}

__device__ void fast_topk_cuda_tl(
    const float* __restrict__ input, int* __restrict__ index, int length, int topk = TopK) {
  constexpr auto BLOCK_SIZE = 1024;
  constexpr auto RADIX = 256;
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
  __shared__ int s_threshold_bin_id;
  __shared__ int s_histogram[RADIX + 1];
  __shared__ int s_num_input[2];
  __shared__ int s_counter;
  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];

  int tx = threadIdx.x;

  // stage 1: 8bit coarse histogram
  if (tx < RADIX + 1) s_histogram[tx] = 0;
  __syncthreads();
  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin = convert_to_uint8(input[idx]);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // cumsum (descending)
  if (tx == 0) {
    for (int i = RADIX - 2; i >= 0; --i) {
      s_histogram[i] += s_histogram[i + 1];
    }
    // threshold bin
    for (int i = 0; i < RADIX; i++) {
      if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
        s_threshold_bin_id = i;
        break;
      }
    }
    s_num_input[0] = 0;
    s_counter = 0;
  }
  __syncthreads();

  int threshold_bin = s_threshold_bin_id;
  int new_topk = topk - s_histogram[threshold_bin + 1];

  // collect candidates
  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin_id = static_cast<int>(convert_to_uint8(input[idx]));
    if (bin_id > threshold_bin) {
      int pos = ::atomicAdd(&s_counter, 1);
      index[pos] = idx;
    } else if (bin_id == threshold_bin && new_topk > 0) {
      int pos = ::atomicAdd(&s_num_input[0], 1);
      if (pos < SMEM_INPUT_SIZE) {
        [[likely]] s_input_idx[0][pos] = idx;
      }
    }
  }
  __syncthreads();

  // stage 2: refine with 8bit radix passes
#pragma unroll 4
  for (int round = 0; round < 4; ++round) {
    if (new_topk <= 0) break;
    int r_idx = round % 2;
    // reset
    if (tx < RADIX + 1) s_histogram[tx] = 0;
    __syncthreads();

    int num_input = s_num_input[r_idx];
    for (int i = tx; i < num_input; i += BLOCK_SIZE) {
      int idx = s_input_idx[r_idx][i];
      uint32_t bin32 = (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
      ::atomicAdd(&s_histogram[bin32], 1);
    }
    __syncthreads();

    if (tx == 0) {
      for (int i = RADIX - 2; i >= 0; --i) s_histogram[i] += s_histogram[i + 1];
      for (int i = 0; i < RADIX; i++) {
        if (s_histogram[i] >= new_topk && s_histogram[i + 1] < new_topk) {
          s_threshold_bin_id = i;
          break;
        }
      }
      s_num_input[r_idx ^ 1] = 0;
    }
    __syncthreads();

    new_topk -= s_histogram[s_threshold_bin_id + 1];
    int threshold_bin = s_threshold_bin_id;
    for (int i = tx; i < num_input; i += BLOCK_SIZE) {
      int idx = s_input_idx[r_idx][i];
      uint32_t bin32 = (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
      if (bin32 > threshold_bin) {
        int pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      } else if (bin32 == threshold_bin && new_topk > 0) {
        if (round == 3) {
          int pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        } else {
          int pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
          if (pos < SMEM_INPUT_SIZE) s_input_idx[r_idx ^ 1][pos] = idx;
        }
      }
    }
    __syncthreads();
  }
}

__global__ void topk_kernel(const FastTopKParams params) {
  const auto& [input, indices, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto length = *(lengths + bid);
  const auto indice = indices + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_cuda(score, indice, length);
  } else {
    if (use_tilelang) {
      return fast_topk_cuda_tl(score, indice, length);
    } else {
      return fast_topk_cuda(score, indice, length);
    }
  }
}

__global__ void topk_kernel_transform_decode(  // decode
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride) {
  const auto& [input, _, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto tid = threadIdx.x;
  const auto length = *(lengths + bid);
  const auto src_page_entry = src_page_table + bid * src_stride;
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    if (use_tilelang) {
      fast_topk_cuda_tl(score, s_indices, length);
    } else {
      fast_topk_cuda(score, s_indices, length);
    }
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}

__global__ void topk_kernel_transform_prefill(  // prefill
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride,
    const int32_t* __restrict__ cu_seqlens,
    const int64_t prefill_bs) {
  const auto& [input, _, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto tid = threadIdx.x;
  const auto length = *(lengths + bid);
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;

  /// NOTE: prefill bs is usually small, we can just use a simple loop here
  /// We ensure that last cu_seqlens is equal to number of blocks launched
  assert(gridDim.x == cu_seqlens[prefill_bs] && "Invalid cu_seqlens in topk-transform-prefill");
  __shared__ const int32_t* s_src_page_entry;
  if (tid == 0) {
    for (int64_t offset = 0; offset < prefill_bs; ++offset) {
      if (bid < cu_seqlens[offset + 1]) {
        s_src_page_entry = src_page_table + offset * src_stride;
        break;
      }
    }
  }
  __syncthreads();
  const auto src_page_entry = s_src_page_entry;

  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    if (use_tilelang) {
      fast_topk_cuda_tl(score, s_indices, length);
    } else {
      fast_topk_cuda(score, s_indices, length);
    }
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}

auto get_params(
    at::Tensor score,
    at::Tensor lengths,
    bool use_tilelang,
    std::optional<at::Tensor> indices_opt = std::nullopt) -> FastTopKParams {
  const auto B = score.size(0);
  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
  TORCH_CHECK(lengths.size(0) == B);
  int32_t* indices_data_ptr = nullptr;
  if (indices_opt.has_value()) {
    const auto& indices = indices_opt.value();
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
    TORCH_CHECK(indices.size(0) == B);
    TORCH_CHECK(indices.size(1) == TopK);
    indices_data_ptr = indices.data_ptr<int32_t>();
  }
  return FastTopKParams{
      .input = score.data_ptr<float>(),
      .indices = indices_data_ptr,
      .lengths = lengths.data_ptr<int32_t>(),
      .input_stride = score.stride(0),
      .use_tilelang = use_tilelang,
  };
}

template <auto* f, size_t max_dynamic_smem>
auto setup_kernel_smem_once() -> void {
  [[maybe_unused]] static const auto result = [] {
    return ::cudaFuncSetAttribute(
        f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
  }();
  TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result));
}

auto fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths, bool use_tilelang)
    -> void {
  const auto params = get_params(score, lengths, use_tilelang, indices);
  const auto B = score.size(0);
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  setup_kernel_smem_once<topk_kernel, kSmem>();
  topk_kernel<<<grid, block, kSmem, stream>>>(params);
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}

auto fast_topk_transform_interface(
    at::Tensor score,
    at::Tensor lengths,
    at::Tensor dst_page_table,
    at::Tensor src_page_table,
    at::Tensor cu_seqlens,
    bool use_tilelang) -> void {
  const auto params = get_params(score, lengths, use_tilelang);
  const auto B = score.size(0);
  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
  TORCH_CHECK(cu_seqlens.dim() == 1 && cu_seqlens.is_contiguous());
  const auto prefill_bs = cu_seqlens.size(0) - 1;
  TORCH_CHECK(dst_page_table.size(0) == B);
  TORCH_CHECK(dst_page_table.size(1) == TopK);
  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
  TORCH_CHECK(prefill_bs <= B);  // prefill_bs should be smaller than expanded bs

  // launch kernel
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  const auto src_stride = src_page_table.stride(0);

  // dispatch to decode or prefill
  const auto is_decode = (prefill_bs == B);
  if (is_decode) {
    setup_kernel_smem_once<topk_kernel_transform_decode, kSmem>();
    topk_kernel_transform_decode<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(), src_page_table.data_ptr<int32_t>(), src_stride);
  } else {
    setup_kernel_smem_once<topk_kernel_transform_prefill, kSmem>();
    topk_kernel_transform_prefill<<<grid, block, kSmem, stream>>>(
        params,
        dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(),
        src_stride,
        cu_seqlens.data_ptr<int32_t>(),
        prefill_bs);
  }
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}

}  // namespace

PYBIND11_MODULE(topk_kernel, m) {
  m.def("fast_topk", &fast_topk_interface);
  m.def("fast_topk_transform", &fast_topk_transform_interface);
}
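
The kernel above selects the top 2048 scores per row with an iterative radix select: the float bits are remapped to a monotone unsigned key, a coarse 8-bit histogram finds the threshold bucket, everything strictly above the threshold is emitted directly, and only the ties in the threshold bucket are refined in further 8-bit passes. The following is a small PyTorch sketch of the same idea (single row, plain Python loop, no shared memory); it is a reading aid for the kernel, not the kernel itself:

import torch

def radix_select_topk_indices(score: torch.Tensor, k: int) -> torch.Tensor:
    """Reference radix select: indices of the k largest entries (order not guaranteed)."""
    # Map float32 to uint32 keys so larger floats compare larger as unsigned ints,
    # mirroring convert_to_uint32() in the CUDA source.
    bits = score.float().view(torch.int32).to(torch.int64) & 0xFFFFFFFF
    keys = torch.where((bits & 0x80000000) != 0, (~bits) & 0xFFFFFFFF, bits | 0x80000000)

    selected = []
    candidates = torch.arange(score.numel())
    for shift in (24, 16, 8, 0):                  # four 8-bit passes, coarse to fine
        digits = (keys[candidates] >> shift) & 0xFF
        hist = torch.bincount(digits, minlength=256)
        above = hist.flip(0).cumsum(0).flip(0)    # count of candidates with digit >= d
        thr = int((above >= k).nonzero().max())   # threshold digit for this pass
        selected.append(candidates[digits > thr]) # definitely in the top-k
        k -= int((digits > thr).sum())
        candidates = candidates[digits == thr]    # ties: refine in the next pass
        if k <= 0:
            break
    selected.append(candidates[:k])               # remaining ties fill the last slots
    return torch.cat(selected)

scores = torch.randn(10000)
idx = radix_select_topk_indices(scores, 2048)
assert torch.equal(scores[idx].sort().values, scores.topk(2048).values.sort().values)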
python/sglang/srt/layers/attention/nsa/cuda/topk.py  0 → 100644

from __future__ import annotations

from typing import Any

import torch

from .utils import load_kernel_module


def _load_topk_module() -> Any:
    """
    Load the index manipulation module.
    """
    return load_kernel_module("topk.cu", "topk_kernel")


# TODO(dark): configure out why my cuda impl is a little slower....
# I believe it has something to do with unrolling loops (?)
_USE_TL = True


def fast_topk(
    score: torch.Tensor,
    indices: torch.Tensor,
    lengths: torch.Tensor,
) -> torch.Tensor:
    return _load_topk_module().fast_topk(score, indices, lengths, _USE_TL)


def fast_topk_transform(
    score: torch.Tensor,
    lengths: torch.Tensor,
    dst_page_table: torch.Tensor,
    src_page_table: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    return _load_topk_module().fast_topk_transform(
        score, lengths, dst_page_table, src_page_table, cu_seqlens, _USE_TL
    )
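
These wrappers JIT-compile topk.cu on first use and then forward to the extension. The shapes implied by the C++ checks are: score is (B, L) float32, indices is (B, 2048) int32, and lengths is (B,) int32, with unused output slots written as -1. A hedged usage sketch follows; the tensor sizes are arbitrary, and the call assumes a CUDA device plus a working toolchain for the torch cpp_extension JIT build:

# Illustrative call into the new wrappers (sizes are made up; TopK is fixed at
# 2048 by the CUDA source).
import torch
from sglang.srt.layers.attention.nsa.cuda import fast_topk

B, L, TOPK = 4, 8192, 2048
score = torch.randn(B, L, dtype=torch.float32, device="cuda")
lengths = torch.full((B,), L, dtype=torch.int32, device="cuda")
indices = torch.empty(B, TOPK, dtype=torch.int32, device="cuda")

fast_topk(score, indices, lengths)  # fills `indices` in place; -1 marks unused slots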
python/sglang/srt/layers/attention/nsa/cuda/utils.py  0 → 100644

from __future__ import annotations

import os
from functools import lru_cache
from typing import Any, Iterable


@lru_cache()
def _prepare_for_load() -> str:
    import os
    import warnings

    warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.cpp_extension")
    return os.path.dirname(os.path.abspath(__file__))


@lru_cache()
def load_kernel_module(
    path: str | Iterable[str],
    name: str,
    *,
    build: str = "build",
    cflags: Iterable[str] | None = None,
    cuda_flags: Iterable[str] | None = None,
    ldflags: Iterable[str] | None = None,
) -> Any:
    from torch.utils.cpp_extension import load

    if isinstance(path, str):
        path = (path,)
    abs_path = _prepare_for_load()
    build_dir = f"{abs_path}/{build}"
    os.makedirs(build_dir, exist_ok=True)
    return load(
        name=name,
        sources=[f"{abs_path}/csrc/{p}" for p in path],
        extra_cflags=list(cflags or []) or ["-O3", "-std=c++17"],
        extra_cuda_cflags=list(cuda_flags or []) or ["-O3", "-std=c++17"],
        extra_ldflags=list(ldflags or []) or None,
        build_directory=build_dir,
    )
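
load_kernel_module is a thin, lru_cache-backed wrapper over torch.utils.cpp_extension.load that compiles sources from the sibling csrc/ directory into a per-package build/ folder. A minimal sketch of how another kernel in the same package could reuse it; the "example.cu" file and "my_ext" module name are hypothetical and not part of this commit:

# Hypothetical reuse of the loader for another kernel placed in csrc/.
from sglang.srt.layers.attention.nsa.cuda.utils import load_kernel_module

ext = load_kernel_module("example.cu", "my_ext")        # first call: JIT build
ext_again = load_kernel_module("example.cu", "my_ext")  # lru_cache returns the same module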
python/sglang/srt/layers/attention/nsa/dequant_k_cache.py
0 → 100644
View file @
852a49c5
import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST


def dequantize_k_cache(quant_k_cache):
    if NSA_DEQUANT_K_CACHE_FAST:
        return _dequantize_k_cache_fast_wrapped(quant_k_cache)
    else:
        return _dequantize_k_cache_slow(quant_k_cache)


def _dequantize_k_cache_slow(
    quant_k_cache: torch.Tensor,  # (num_blocks, block_size, 1, bytes_per_token)
    dv: int = 512,
    tile_size: int = 128,
    d: int = 576,
) -> torch.Tensor:
    """
    De-quantize the k-cache
    """
    assert dv % tile_size == 0
    num_tiles = dv // tile_size
    num_blocks, block_size, h_k, _ = quant_k_cache.shape
    assert h_k == 1
    result = torch.empty(
        (num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
    )
    quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
    input_nope = quant_k_cache[..., :dv]
    input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
    input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
    result[..., dv:] = input_rope

    for tile_idx in range(0, num_tiles):
        cur_nope = input_nope[..., tile_idx * tile_size : (tile_idx + 1) * tile_size].to(
            torch.float32
        )
        cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
        result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
            cur_nope * cur_scales
        )

    result = result.view(num_blocks, block_size, 1, d)
    return result


def _dequantize_k_cache_fast_wrapped(
    quant_k_cache: torch.Tensor,
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    # TODO the final API may be 2D instead of 4D, thus we convert them here
    num_blocks, block_size, _, dim_quant = quant_k_cache.shape
    assert dv == 512
    assert dim_quant == 656
    assert tile_size == 128
    quant_k_cache = quant_k_cache.view((-1, dim_quant))
    output = _dequantize_k_cache_fast(quant_k_cache)
    return output.view(num_blocks, block_size, 1, -1)


def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
    num_tokens, dim_quant = quant_k_cache.shape
    assert quant_k_cache.dtype == torch.float8_e4m3fn
    dim_nope = 512
    dim_rope = 64
    num_tiles = dim_nope // group_size
    assert dim_quant == 656

    output = torch.empty(
        (num_tokens, dim_nope + dim_rope),
        dtype=torch.bfloat16,
        device=quant_k_cache.device,
    )

    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
    assert num_blocks_per_token == 5
    assert dim_nope % group_size == 0
    NUM_NOPE_BLOCKS = dim_nope // group_size

    input_nope_q = quant_k_cache[:, :dim_nope]
    input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
    input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)

    _dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
        output,
        input_nope_q,
        input_nope_s,
        input_rope,
        output.stride(0),
        input_nope_q.stride(0),
        input_nope_s.stride(0),
        input_rope.stride(0),
        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
    )

    return output


@triton.jit
def _dequantize_k_cache_fast_kernel(
    output_ptr,
    input_nope_q_ptr,
    input_nope_s_ptr,
    input_rope_ptr,
    output_stride_0: int,
    input_nope_q_stride_0: int,
    input_nope_s_stride_0: int,
    input_rope_stride_0: int,
    NUM_NOPE_BLOCKS: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
):
    token_id = tl.program_id(0)
    raw_block_id = tl.program_id(1)

    if raw_block_id < NUM_NOPE_BLOCKS:
        # a. dequant nope
        effective_block_id = raw_block_id
        offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs_q < DIM_NOPE
        ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q
        ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id
        y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
        y_s = tl.load(ptr_s)
        y = (y_q * y_s).to(output_ptr.dtype.element_ty)
        dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
        tl.store(dst_ptr, y, mask=mask)
    else:
        # b. copy rope
        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_ROPE
        src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs
        dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs
        data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
        tl.store(dst_ptr, data, mask=mask)


if __name__ == "__main__":
    raise Exception("UT is in quant_k_cache.py")
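A small sanity sketch of the 656-byte quantized row layout this file assumes (512 fp8 "nope" bytes, one fp32 scale per 128-wide tile, then the bf16 rope tail); this is illustrative only, not part of the committed code:

# Layout arithmetic behind the `dim_quant == 656` asserts above,
# assuming dv=512, tile_size=128 and a 64-dim bf16 rope tail.
dv, tile_size, rope_dim = 512, 128, 64
num_tiles = dv // tile_size          # 4 per-tile scale factors
nope_bytes = dv                      # fp8  -> 1 byte per element
scale_bytes = num_tiles * 4          # fp32 -> 4 bytes per scale
rope_bytes = rope_dim * 2            # bf16 -> 2 bytes per element
assert nope_bytes + scale_bytes + rope_bytes == 656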
python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
0 → 100644
from typing import TYPE_CHECKING

import torch
import triton
import triton.language as tl

if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool

"""
k: data, 128 item per token, fp8
s: scale, 1 item per token, fp32
"""


class GetK:
    @classmethod
    def execute(cls, *args, **kwargs):
        return cls.torch_fast(*args, **kwargs)

    @classmethod
    def slow(cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor):
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size
        index_k_fp8 = torch.empty(
            (seq_len_, pool.index_head_dim),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[page_index][
                : pool.page_size * pool.index_head_dim
            ].view(-1, pool.index_head_dim)
        return index_k_fp8[:seq_len]

    @classmethod
    def torch_fast(cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor):
        """
        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim), uint8
        """
        # can handle per 128B instead of per element
        # page_indices: (num_pages,), element := a page index
        buf_numel_per_page = buf.shape[1]
        num_k_bytes_per_page = pool.page_size * pool.index_head_dim
        num_k_bytes_per_token = pool.index_head_dim

        # buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
        # flat_buf: (whatever,), uint8
        flat_buf = buf.flatten()
        # flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access
        flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
            num_k_bytes_per_page, dtype=torch.int32, device="cuda"
        )[None, :]
        flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]
        out = flat_buf[flat_indices]
        return out.view(-1, 128)


class GetS:
    @classmethod
    def execute(cls, *args, **kwargs):
        return cls.torch_fast(*args, **kwargs)

    @classmethod
    def slow(cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor):
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size
        assert pool.index_head_dim // pool.quant_block_size == 1
        index_k_scale_fp8 = torch.empty(
            (seq_len_, 4),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
                page_index
            ][pool.page_size * pool.index_head_dim :].view(-1, 4)
        return index_k_scale_fp8[:seq_len]

    @classmethod
    def torch_fast(cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor):
        """
        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim // quant_block_size), uint8
        """
        buf_numel_per_page = buf.shape[1]
        num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
        num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
        s_offset_in_page = pool.page_size * pool.index_head_dim

        flat_buf = buf.flatten()
        flat_indices = (
            (page_indices * buf_numel_per_page)[:, None]
            + torch.arange(num_s_bytes_per_page, dtype=torch.int32, device="cuda")[None, :]
            + s_offset_in_page
        )
        flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]
        out = flat_buf[flat_indices]
        return out.view(-1, 4)


class SetK:
    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)

    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            buf[
                page_index,
                offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
            ] = index_k[i].view(torch.uint8)

    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_k_bytes_per_token = pool.index_head_dim

        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size

        flat_buf = buf.flatten()
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
            + torch.arange(num_k_bytes_per_token, dtype=torch.int32, device="cuda")[None, :]
        )
        num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
        flat_indices = flat_indices.flatten()[:num_k_bytes_total]
        flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()


class SetS:
    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)

    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            start = pool.page_size * pool.index_head_dim
            buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
                index_k_scale[i].view(torch.uint8)
            )

    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_s_bytes_per_token = 4
        s_offset_in_page = pool.page_size * pool.index_head_dim

        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size

        flat_buf = buf.flatten()
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + s_offset_in_page
            + (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
            + torch.arange(num_s_bytes_per_token, dtype=torch.int32, device="cuda")[None, :]
        )
        number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
        flat_indices = flat_indices.flatten()[:number_s_bytes_total]
        flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()


class SetKAndS:
    @classmethod
    def execute(cls, *args, buf, **kwargs):
        if 0:
            # print("SetK, SetS comparison test")
            buf_cloned = buf.clone()
            cls.vanilla(*args, **kwargs, buf=buf)
            cls.triton(*args, **kwargs, buf=buf_cloned)

            def _clear_token_0(target):
                target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0

            _clear_token_0(buf)
            _clear_token_0(buf_cloned)

            assert torch.all(
                buf == buf_cloned
            ), f"{buf=} {buf_cloned=} {kwargs['loc'].tolist()=}"

        return cls.triton(*args, **kwargs, buf=buf)

    @classmethod
    def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
        SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
        SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)

    @classmethod
    def triton(cls, pool, buf, loc, index_k, index_k_scale):
        _set_k_and_s_triton(
            buf=buf,
            loc=loc,
            index_k=index_k,
            index_k_scale=index_k_scale,
            page_size=pool.page_size,
        )


def _set_k_and_s_triton(
    buf: torch.Tensor,
    loc: torch.Tensor,
    index_k: torch.Tensor,
    index_k_scale: torch.Tensor,
    page_size: int,
):
    """
    :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
    :param loc: (num_tokens_to_write,), int, element := the token index to write to
    :param index_k: (num_tokens_to_write, 128 elem), fp8
    :param index_k_scale: (num_tokens_to_write, 1 elem), fp32
    :return:
    """
    num_pages, buf_numel_per_page = buf.shape
    (num_tokens_to_write,) = loc.shape
    num_tokens_to_write_, index_head_dim = index_k.shape
    num_tokens_to_write__, scale_dim = index_k_scale.shape

    assert buf_numel_per_page == 64 * (128 + 4)
    assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
    assert index_head_dim == 128
    assert scale_dim == 1
    assert page_size == 64

    assert buf.dtype == torch.uint8
    assert loc.dtype == torch.int64, f"{loc.dtype=}"  # can be int32
    assert index_k.dtype == torch.float8_e4m3fn
    assert index_k_scale.dtype == torch.float32

    assert buf.is_contiguous()
    assert loc.is_contiguous()
    assert index_k.is_contiguous()
    assert index_k_scale.is_contiguous()

    buf_fp8 = buf.view(torch.float8_e4m3fn)
    buf_fp32 = buf.view(torch.float32)

    _set_k_and_s_triton_kernel[(num_tokens_to_write,)](
        buf_fp8,
        buf_fp32,
        loc,
        index_k,
        index_k_scale,
        index_k.stride(0),
        PAGE_SIZE=page_size,
        BUF_NUMEL_PER_PAGE=buf_numel_per_page,
        NUM_K_ELEMS_PER_TOKEN=index_head_dim,
        S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
    )


@triton.jit
def _set_k_and_s_triton_kernel(
    buf_fp8_ptr,
    buf_fp32_ptr,
    loc_ptr,
    index_k_ptr,
    index_k_scale_ptr,
    index_k_ptr_stride_0,
    PAGE_SIZE: tl.constexpr,
    BUF_NUMEL_PER_PAGE: tl.constexpr,
    NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
    S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
):
    token_id = tl.program_id(0)

    loc = tl.load(loc_ptr + token_id)
    in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
    # no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2
    k = tl.load(index_k_ptr + in_k_offsets)
    k_scale = tl.load(index_k_scale_ptr + token_id)

    loc_page_index = loc // PAGE_SIZE
    loc_token_offset_in_page = loc % PAGE_SIZE

    out_k_offsets = (
        loc_page_index * BUF_NUMEL_PER_PAGE
        + loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
        + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
    )
    # "//4" b/c it is fp32 instead of uint8
    out_s_offset = (
        loc_page_index * BUF_NUMEL_PER_PAGE // 4
        + S_OFFSET_NBYTES_IN_PAGE // 4
        + loc_token_offset_in_page
    )

    tl.store(buf_fp8_ptr + out_k_offsets, k)
    tl.store(buf_fp32_ptr + out_s_offset, k_scale)
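The torch_fast paths above all rely on the same flat-index gather/scatter trick: convert page indices to absolute byte offsets in the flattened paged buffer, then index once. A self-contained toy sketch (toy sizes on CPU, not the real pool) of the GetK case:

# Toy illustration of the flat-index gather used by GetK.torch_fast above.
import torch

page_size, head_dim = 4, 8                               # real code uses 64 and 128
buf = torch.arange(3 * page_size * head_dim, dtype=torch.uint8).view(3, -1)
page_indices = torch.tensor([2, 0], dtype=torch.int32)   # pages of a 2-page sequence

flat_indices = (page_indices * buf.shape[1])[:, None] + torch.arange(
    page_size * head_dim, dtype=torch.int32
)[None, :]
gathered = buf.flatten()[flat_indices.flatten()].view(-1, head_dim)
# `gathered[:seq_len]` would then be the per-token rows, as in GetK.torch_fast.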
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
0 → 100644
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.utils import add_prefix, is_npu

if not is_npu():
    from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant

    import deep_gemm

from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import add_prefix, align, is_cuda

try:
    import deep_gemm_v32
except ImportError as e:
    print("Error when importing deep_gemm_v32, try deep_gemm")
    try:
        import deep_gemm as deep_gemm_v32
    except ImportError as e:
        print("Error when importing deep_gemm, skip")

if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool

DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0


class BaseIndexerMetadata(ABC):
    @abstractmethod
    def get_seqlens_int32(self) -> torch.Tensor:
        """
        Return: (batch_size,) int32 tensor
        """

    @abstractmethod
    def get_page_table_64(self) -> torch.Tensor:
        """
        Return: (batch_size, num_blocks) int32, page table.
        The page size of the table is 64.
        """

    @abstractmethod
    def get_seqlens_expanded(self) -> torch.Tensor:
        """
        Return: (sum_extend_seq_len,) int32 tensor
        """

    @abstractmethod
    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        """
        Perform topk selection on the logits and possibly transform the result.

        NOTE that attention backend may override this function to do some
        transformation, which means the result of this topk_transform may not
        be the topk indices of the input logits.

        Return: Anything, since it will be passed to the attention backend
                for further processing on sparse attention computation.
                Don't assume it is the topk indices of the input logits.
        """
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    assert x.dtype == torch.bfloat16
    from fast_hadamard_transform import hadamard_transform

    hidden_size = x.size(-1)
    assert (
        hidden_size & (hidden_size - 1)
    ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
    return hadamard_transform(x, scale=hidden_size**-0.5)


class V32LayerNorm(nn.Module):
    """
    Layer Normalization.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        return F.layer_norm(
            x.float(), (self.dim,), self.weight, self.bias, self.eps
        ).type_as(x)


class Indexer(CustomOp):
    def __init__(
        self,
        hidden_size: int,
        index_n_heads: int,
        index_head_dim: int,
        rope_head_dim: int,
        index_topk: int,
        q_lora_rank: int,
        max_position_embeddings: int,
        rope_theta: float,
        layer_id: int,
        scale_fmt: Optional[str],
        block_size: int = 128,
        rope_scaling: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        quant_config: Optional[QuantizationConfig] = None,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = index_n_heads
        self.head_dim = index_head_dim
        self.rope_head_dim = rope_head_dim
        self.index_topk = index_topk
        self.q_lora_rank = q_lora_rank
        self.layer_id = layer_id
        self.alt_stream = alt_stream
        if not is_npu():
            self.sm_count = deep_gemm.get_num_sms()
            self.half_device_sm_count = align(self.sm_count // 2, 8)

        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wq_b", prefix),
        )
        self.wk = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wk", prefix),
        )
        self.k_norm = V32LayerNorm(self.head_dim)
        # NOTE: weights_proj is not quantized
        self.weights_proj = ReplicatedLinear(
            self.hidden_size,
            self.n_heads,
            bias=False,
            prefix=add_prefix("weights_proj", prefix),
        )
        self.rotary_emb = get_rope_wrapper(
            rope_head_dim,
            rotary_dim=rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,  # type: ignore
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )
        self.block_size = block_size
        self.scale_fmt = scale_fmt
        self.softmax_scale = self.head_dim**-0.5

    def _forward_fake(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ):
        bs = x.shape[0]
        assert self.index_topk == 2048
        ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
            None, ...
        ].repeat(bs, 1)
        if forward_batch.forward_mode.is_extend():
            assert (
                forward_batch.extend_seq_lens_cpu is not None
                and forward_batch.seq_lens_cpu is not None
            )
            which = 0
            for i, (kv_len, qo_len) in enumerate(
                zip(
                    forward_batch.seq_lens_cpu.tolist(),
                    forward_batch.extend_seq_lens_cpu,
                    strict=True,
                )
            ):
                for j in range(kv_len - qo_len, kv_len):
                    ans[which, j + 1 :] = -1
                    which += 1
            assert which == ans.shape[0]
        else:
            assert forward_batch.seq_lens_cpu is not None
            for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
                ans[i, seq_len:] = -1
        return ans

    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
        weights, _ = self.weights_proj(x)
        weights = weights * self.n_heads**-0.5
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        return weights

    def _get_q_k_bf16(
        self,
        q_lora: torch.Tensor,
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
    ):
        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)

            with deep_gemm_wrapper.configure_deep_gemm_num_sms(
                self.half_device_sm_count
            ):
                query, _ = self.wq_b(q_lora)
                query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
                q_rope, _ = torch.split(
                    query,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )
            with torch.cuda.stream(self.alt_stream):
                key, _ = self.wk(x)
                key = self.k_norm(key)
                k_rope, _ = torch.split(
                    key,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )
            current_stream.wait_stream(self.alt_stream)
        else:
            query, _ = self.wq_b(q_lora)
            if dumper._enable:
                after_wq_b = query.clone()
            query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
            q_rope, _ = torch.split(
                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )
            key, _ = self.wk(x)
            if dumper._enable:
                after_wk = key.clone()
            key = self.k_norm(key)
            if dumper._enable:
                after_k_norm = key.clone()
            k_rope, _ = torch.split(
                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )

        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
        query[..., : self.rope_head_dim] = q_rope
        key[..., : self.rope_head_dim] = k_rope

        if dumper._enable:
            q_before_hadamard = query.clone()
            k_before_hadamard = key.clone()

        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            query = rotate_activation(query)
            with torch.cuda.stream(self.alt_stream):
                key = rotate_activation(key)
            current_stream.wait_stream(self.alt_stream)
        else:
            query = rotate_activation(query)
            key = rotate_activation(key)

        return query, key

    def _get_topk_paged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
        page_size = forward_batch.token_to_kv_pool.page_size
        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm_v32
        assert page_size == 64, "only support page size 64"

        # NOTE(dark): this support extend/decode/decode+graph
        block_tables = metadata.get_page_table_64()
        max_seq_len = block_tables.shape[1] * page_size
        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
            layer_id=layer_id
        )
        blocksize = page_size
        seqlens_32 = metadata.get_seqlens_int32()
        # NOTE(dark): 132 is SM count on H200/B200, not magic number
        schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
            seqlens_32, blocksize, self.sm_count
        )

        assert len(q_fp8.shape) == 3
        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
        assert len(kv_cache_fp8.shape) == 2
        block_kv = 64
        num_heads_kv = 1
        head_dim_with_sf = 132
        kv_cache_fp8 = kv_cache_fp8.view(
            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
        )
        assert len(weights.shape) == 3
        weights = weights.squeeze(2)

        logits = deep_gemm_v32.fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8,
            weights,
            seqlens_32,
            block_tables,
            schedule_metadata,
            max_seq_len,
            clean_logits=False,
        )
        # NOTE(dark): logits should be cleaned in topk_transform
        topk_result = metadata.topk_transform(logits, self.index_topk)
        return topk_result

    def _get_topk_ragged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
        page_size = forward_batch.token_to_kv_pool.page_size
        assert page_size == 64, "only support page size 64"
        assert len(weights.shape) == 3
        weights = weights.squeeze(-1)
        k_fp8_list = []
        k_scale_list = []
        ks_list = []
        offset = 0
        block_tables = metadata.get_page_table_64()

        assert (
            forward_batch.seq_lens_cpu is not None
            and forward_batch.extend_seq_lens_cpu is not None
        )
        for i in range(forward_batch.batch_size):
            seq_len = forward_batch.seq_lens_cpu[i].item()
            assert isinstance(seq_len, int)
            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )
            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )
            extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
            ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
            k_fp8_list.append(k_fp8)
            k_scale_list.append(k_scale)
            ks_list.append(ks)
            offset += extend_seq_len

        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
        k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
        kv_fp8 = (k_fp8, k_scale)
        ks = torch.cat(ks_list, dim=0)
        seq_lens_expanded = metadata.get_seqlens_expanded()
        ke = ks + seq_lens_expanded

        logits = deep_gemm_v32.fp8_mqa_logits(
            q_fp8,
            kv_fp8,
            weights,
            ks,
            ke,
            clean_logits=False,
        )
        assert logits.shape[0] == len(seq_lens_expanded)
        topk_result = metadata.topk_transform(logits, self.index_topk)
        return topk_result

    def _forward(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
        metadata = forward_batch.attn_backend.get_indexer_metadata(
            layer_id, forward_batch
        )

        enable_dual_stream = (
            NSA_DUAL_STREAM
            and self.alt_stream is not None
            and get_is_capture_mode()
            and q_lora.shape[0] > 0
            and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
        )

        # skip NSA if attention backend choose to skip this batch
        if metadata is None:
            return None

        if not NSA_USE_REAL_INDEXER:  # temporary
            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)

        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)

        q_fp8 = query.to(torch.float32)
        k_fp8 = key.to(torch.float32)
        q_scale = torch.ones((query.shape[0], 1), dtype=torch.float32, device="cuda")
        k_scale = torch.ones((key.shape[0], 1), dtype=torch.float32, device="cuda")

        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            with torch.cuda.stream(self.alt_stream):
                k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
            current_stream.wait_stream(self.alt_stream)
        else:
            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)

        # k_fp8: (seq_len, head_dim) fp8_e4m3fn
        # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
        # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
            layer_id=layer_id,
            loc=forward_batch.out_cache_loc,
            index_k=k_fp8,
            index_k_scale=k_scale,
        )

        weights = self._get_logits_head_gate(x, q_scale)

        assert forward_batch.seq_lens_cpu is not None
        if len(forward_batch.seq_lens_cpu) == 0:
            # this seems b/c max-pad, no worries?
            # if x.shape[0] != 0:
            #     print(
            #         "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
            #     )
            return torch.full(
                (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
            )

        if forward_batch.forward_mode.is_decode_or_idle():
            topk_result = self._get_topk_paged(
                forward_batch, layer_id, q_fp8, weights, metadata
            )
        else:
            topk_result = self._get_topk_ragged(
                forward_batch, layer_id, q_fp8, weights, metadata
            )

        return topk_result

    def forward_cuda(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        return self._forward(x, q_lora, positions, forward_batch, layer_id)

    def forward_npu(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> torch.Tensor:
        import custom_ops
        import torch_npu

        from sglang.srt.layers.dp_attention import (
            get_attention_tp_rank,
            get_attention_tp_size,
        )
        from sglang.srt.utils import get_bool_env_var

        if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
            actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
        else:
            actual_seq_lengths_kv = (
                forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
            )
        enable_index_cp = (
            get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
        )
        is_prefill = forward_batch.forward_mode.is_extend()

        attention_tp_rank = get_attention_tp_rank()
        attention_tp_size = get_attention_tp_size()

        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
        sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
        if is_prefill and enable_index_cp:
            slice_length = cos.shape[0] // attention_tp_size
            cos = cos[
                slice_length * attention_tp_rank : slice_length * (attention_tp_rank + 1)
            ]
            sin = sin[
                slice_length * attention_tp_rank : slice_length * (attention_tp_rank + 1)
            ]
        slot_mapping = forward_batch.out_cache_loc
        block_table = forward_batch.attn_backend.forward_metadata.block_tables

        bs = x.shape[0]

        q = self.wq_b(q_lora)[0]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
        q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]
        q_pe, q_nope = torch.split(
            q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
        )  # [bs, 64, 64 + 64]

        q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
            bs, self.n_heads, self.rope_head_dim
        )  # [bs, n, d]
        q = torch.cat([q_pe, q_nope], dim=-1)

        k_proj = self.wk(x)[0]  # [b, s, 7168] @ [7168, 128] = [b, s, 128]
        k = self.k_norm(k_proj)
        k_pe, k_nope = torch.split(
            k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
        )  # [bs, 64 + 64]

        k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
        k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
            bs, 1, self.rope_head_dim
        )  # [bs, 1, d]
        k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1)  # [bs, 1, 128]

        if is_prefill and enable_index_cp:
            k, local_k = (
                torch.empty(
                    (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
                    dtype=k.dtype,
                    device=k.device,
                ),
                k,
            )
            get_attention_tp_group().all_gather_into_tensor(k, local_k)

        forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)

        indexer_input = {}
        if is_prefill:
            actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
            actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
                device=q.device
            )
            if enable_index_cp:
                actual_seq_lengths_q -= bs * attention_tp_rank
                actual_seq_lengths_q = torch.max(
                    actual_seq_lengths_q,
                    torch.zeros_like(actual_seq_lengths_q).to(
                        device=actual_seq_lengths_q.device
                    ),
                )
                actual_seq_lengths_q = torch.min(
                    actual_seq_lengths_q,
                    torch.full(actual_seq_lengths_q.shape, bs).to(
                        device=actual_seq_lengths_q.device
                    ),
                )
        else:
            if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
                actual_seq_lengths_q = torch.tensor(
                    [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
                )
            else:
                actual_seq_lengths_q = (
                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
                )

        past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)

        x = x.view(-1, self.hidden_size)
        weights = self.weights_proj(x)[0]
        block_table = (
            block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
        )

        topk_indices = torch.ops.custom.npu_lightning_indexer(
            query=q.view(-1, self.n_heads, self.head_dim),
            key=past_key_states,
            weights=weights,
            actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
            actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
            block_table=block_table,
            layout_query="TND",
            layout_key="PA_BSND",
            sparse_count=self.index_topk,
            sparse_mode=3,
        )

        if is_prefill and enable_index_cp:
            topk_indices, local_topk_indices = (
                torch.empty(
                    (
                        topk_indices.shape[0] * attention_tp_size,
                        topk_indices.shape[1],
                        topk_indices.shape[2],
                    ),
                    dtype=topk_indices.dtype,
                    device=topk_indices.device,
                ),
                topk_indices,
            )
            get_attention_tp_group().all_gather_into_tensor(
                topk_indices, local_topk_indices
            )

        return topk_indices
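The `_forward_fake` path above builds the causal candidate indices with Python loops; an equivalent vectorized sketch for the decode case (illustrative only, not part of the commit) is:

# Vectorized equivalent of the decode branch of _forward_fake (illustrative):
# row i keeps candidate indices [0, seq_len_i) and marks the rest as -1.
import torch

def fake_decode_indices(seq_lens: torch.Tensor, index_topk: int) -> torch.Tensor:
    bs = seq_lens.numel()
    ans = torch.arange(index_topk, dtype=torch.int32, device=seq_lens.device)
    ans = ans[None, :].repeat(bs, 1)
    ans[ans >= seq_lens[:, None]] = -1
    return ans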
python/sglang/srt/layers/attention/nsa/quant_k_cache.py
0 → 100644
import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST


def quantize_k_cache(cache_k):
    # TODO upstream can skip concat([k_nope, k_pe]) since we split them here
    if NSA_QUANT_K_CACHE_FAST:
        return _quantize_k_cache_fast_wrapped(cache_k)
    else:
        return _quantize_k_cache_slow(cache_k)


# Copied from original
def _quantize_k_cache_slow(
    input_k_cache: torch.Tensor,  # (num_blocks, block_size, h_k, d)
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    """
    Quantize the k-cache
    Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
    For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
    """
    assert dv % tile_size == 0
    num_tiles = dv // tile_size
    num_blocks, block_size, h_k, d = input_k_cache.shape
    assert h_k == 1
    input_k_cache = input_k_cache.squeeze(2)  # [num_blocks, block_size, d]
    input_elem_size = input_k_cache.element_size()

    result = torch.empty(
        (num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
        dtype=torch.float8_e4m3fn,
        device=input_k_cache.device,
    )
    result_k_nope_part = result[..., :dv]
    result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
    result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
    result_k_rope_part[:] = input_k_cache[..., dv:]

    for tile_idx in range(0, num_tiles):
        cur_scale_factors_inv = (
            torch.abs(input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size])
            .max(dim=-1)
            .values
            / 448.0
        )  # [num_blocks, block_size]
        result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv

        cur_scale_factors_inv.unsqueeze_(-1)  # [num_blocks, block_size, 1]
        cur_quantized_nope = (
            input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size].float()
            / cur_scale_factors_inv.float()
        ).to(torch.float8_e4m3fn)
        result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
            cur_quantized_nope
        )

    result = result.view(num_blocks, block_size, 1, -1)
    return result


def _quantize_k_cache_fast_wrapped(
    input_k_cache: torch.Tensor,
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    # TODO the final API may be 2D instead of 4D, thus we convert them here
    num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape
    assert dv == 512
    assert dim_nope_and_rope == 512 + 64
    assert tile_size == 128
    input_k_cache = input_k_cache.view((-1, dim_nope_and_rope))
    # TODO deliberately split into two tensors, then upstream can provide the two tensors instead of concat into one
    k_nope = input_k_cache[:, :dv]
    k_rope = input_k_cache[:, dv:]
    output = _quantize_k_cache_fast(k_nope=k_nope, k_rope=k_rope)
    return output.view(num_blocks, block_size, 1, -1)


def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):
    """
    :param k_nope: (num_tokens, dim_nope 512)
    :param k_rope: (num_tokens, dim_rope 64)
    """
    assert k_nope.dtype == torch.bfloat16
    assert k_rope.dtype == torch.bfloat16
    num_tokens, dim_nope = k_nope.shape
    num_tokens_, dim_rope = k_rope.shape
    assert num_tokens == num_tokens_
    assert dim_nope == 512
    assert dim_rope == 64
    assert k_nope.dtype == k_rope.dtype
    num_tiles = dim_nope // group_size

    assert k_nope.stride(1) == 1
    assert k_rope.stride(1) == 1

    output = torch.empty(
        (num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),
        dtype=torch.float8_e4m3fn,
        device=k_nope.device,
    )
    output_nope_q = output[..., :dim_nope]
    output_nope_s = output[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
    output_rope = output[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16)

    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
    assert num_blocks_per_token == 5
    assert dim_nope % group_size == 0
    NUM_NOPE_BLOCKS = dim_nope // group_size

    _quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
        output_nope_q,
        output_nope_s,
        output_rope,
        k_nope,
        k_rope,
        output_nope_q.stride(0),
        output_nope_s.stride(0),
        output_rope.stride(0),
        k_nope.stride(0),
        k_rope.stride(0),
        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
    )

    return output


@triton.jit
def _quantize_k_cache_fast_kernel(
    output_nope_q_ptr,
    output_nope_s_ptr,
    output_rope_ptr,
    k_nope_ptr,
    k_rope_ptr,
    output_nope_q_stride_0: int,
    output_nope_s_stride_0: int,
    output_rope_stride_0: int,
    k_nope_stride_0: int,
    k_rope_stride_0: int,
    NUM_NOPE_BLOCKS: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    token_id = tl.program_id(0)
    raw_block_id = tl.program_id(1)

    if raw_block_id < NUM_NOPE_BLOCKS:
        # a. quant nope
        effective_block_id = raw_block_id
        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_NOPE
        ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs
        y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)
        # the ref impl do not have a `tl.maximum(... eps)`, so we remove it here
        y_s = tl.max(tl.abs(y)) / FP8_MAX
        y_s_inv = 1.0 / y_s
        y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(
            output_nope_q_ptr.dtype.element_ty
        )
        dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs
        dst_s_ptr = (
            output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id
        )
        tl.store(dst_q_ptr, y_q, mask=mask)
        tl.store(dst_s_ptr, y_s)
    else:
        # b. copy rope
        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_ROPE
        src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs
        dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs
        data = tl.load(src_ptr, mask=mask)
        tl.store(dst_ptr, data, mask=mask)


if __name__ == "__main__":
    for num_blocks, block_size in [
        (1, 1),
        (10, 64),
    ]:
        dim_nope_and_rope = 512 + 64
        input_k_cache = torch.randn(
            (num_blocks, block_size, 1, dim_nope_and_rope),
            dtype=torch.bfloat16,
            device="cuda",
        )
        # temp debug
        # input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)
        ref_quant = _quantize_k_cache_slow(input_k_cache)
        actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
        # print(f"{input_k_cache=}")
        # print(f"{ref_quant=}")
        # print(f"{actual_quant=}")
        # print(f"{ref_quant == actual_quant=}")
        # print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
        # print(f"{ref_quant.view(torch.bfloat16)=}")
        # print(f"{actual_quant.view(torch.bfloat16)=}")
        # assert torch.all(ref_quant == actual_quant)

        import dequant_k_cache

        ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
        ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
        actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(
            actual_quant
        )

        print(f"{ref_ref_dequant=}")
        print(f"{actual_actual_dequant=}")
        print(f"{actual_actual_dequant - ref_ref_dequant=}")
        print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}")

        # TODO too different?
        torch.testing.assert_close(ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2)
        torch.testing.assert_close(
            ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
        )

        print("Passed")
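A single-tile sketch of the per-tile absmax quantization performed above, mirroring one 128-wide tile of `_quantize_k_cache_slow` (illustrative only):

# Quantize one 128-wide bf16 tile to fp8 with a per-tile absmax scale,
# and recover it by multiplying with the scale (mirrors the slow path above).
import torch

def quant_one_tile(tile_bf16: torch.Tensor):
    scale = tile_bf16.abs().max().float() / 448.0           # fp8_e4m3 max magnitude
    q = (tile_bf16.float() / scale).to(torch.float8_e4m3fn)
    return q, scale                                          # dequant: q.float() * scale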
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
0 → 100644
from typing import Optional, Tuple

import tilelang
import tilelang.language as T
import torch

tilelang.set_log_level("WARNING")

pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
}

BF16 = "bfloat16"
FP8 = "float8_e4m3"
FP32 = "float32"


def fast_log2_ceil(x):
    bits_x = T.reinterpret("uint32", x)
    exp_x = (bits_x >> 23) & 0xFF
    man_bits = bits_x & ((1 << 23) - 1)
    return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))


def fast_pow2(x):
    bits_x = (x + 127) << 23
    return T.reinterpret("float32", bits_x)


def fast_round_scale(amax, fp8_max_inv):
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))


@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False):
    M = T.symbolic("M")
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1 / fp8_max
    num_stages = 0 if round_scale else 2
    blk_m = 32
    group_size = 128

    @T.prim_func
    def act_quant_kernel_(
        X: T.Tensor[(M, N), in_dtype],
        Y: T.Tensor[(M, N), out_dtype],
        S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
            pid_m,
            pid_n,
        ):
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
            amax_local = T.alloc_fragment((blk_m,), scale_dtype)
            s_local = T.alloc_fragment((blk_m,), scale_dtype)
            y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
            y_shared = T.alloc_shared((blk_m, group_size), out_dtype)

            for _ in T.Pipelined(1, num_stages=num_stages):
                T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
                T.copy(x_shared, x_local)
                T.reduce_absmax(x_local, amax_local, dim=1)
                for i in T.Parallel(blk_m):
                    amax_local[i] = T.max(amax_local[i], 1e-4)
                    if round_scale:
                        s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
                    else:
                        s_local[i] = amax_local[i] * fp8_max_inv
                for i, j in T.Parallel(blk_m, group_size):
                    y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max)
                for i in T.Parallel(blk_m):
                    S[pid_m * blk_m + i, pid_n] = s_local[i]
                T.copy(y_local, y_shared)
                T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])

    return act_quant_kernel_


def act_quant(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
        scale_fmt (Optional[str], optional): The format of the scale. Default is None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert (
        x.size(-1) % block_size == 0
    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
    N = x.size(-1)
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
    kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
    kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
    return y, s
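A usage sketch of act_quant following the docstring above (the input tensor is an assumption for illustration and needs a CUDA device):

# Quantize a (M, N) bf16 activation into fp8 values plus one fp32 scale per
# 128-wide block, as described in the docstring above (illustrative only).
x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
y, s = act_quant(x, block_size=128)
assert y.shape == (4, 256) and y.dtype == torch.float8_e4m3fn
assert s.shape == (4, 2) and s.dtype == torch.float32  # 256 // 128 = 2 scales per row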
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int):
    b = T.symbolic("b")
    m = T.symbolic("m")
    n = T.symbolic("n")

    blk_n1 = 512
    blk_n2 = 128

    @T.prim_func
    def fp8_index_kernel_(
        q: T.Tensor[(b, m, h, d), FP8],
        q_s: T.Tensor[(b, m, h), FP32],
        k: T.Tensor[(b, n, d), FP8],
        k_s: T.Tensor[(b, n), FP32],
        o: T.Tensor[(b, m, n), FP32],
    ) -> None:
        with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
            q_smem = T.alloc_shared((h, d), FP8)
            T.copy(q[i_b, i_m, 0, 0], q_smem)
            q_s_frag = T.alloc_fragment(h, FP32)
            T.copy(q_s[i_b, i_m, 0], q_s_frag)

            for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
                k_smem = T.alloc_shared((blk_n2, d), FP8)
                T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
                k_s_frag = T.alloc_fragment(blk_n2, FP32)
                T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)

                logits = T.alloc_fragment((blk_n2, h), FP32)
                T.gemm(
                    k_smem,
                    q_smem,
                    logits,
                    transpose_A=False,
                    transpose_B=True,
                    clear_accum=True,
                )

                for i_h, i3_n in T.Parallel(h, blk_n2):
                    logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]

                logits_sum = T.alloc_fragment(blk_n2, FP32)
                T.reduce_sum(logits, logits_sum, dim=1)
                for i3_n in T.Parallel(blk_n2):
                    logits_sum[i3_n] *= k_s_frag[i3_n]

                T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])

    return fp8_index_kernel_


def fp8_index(
    q: torch.Tensor,
    q_s: torch.Tensor,
    k: torch.Tensor,
    k_s: torch.Tensor,
) -> torch.Tensor:
    """
    Perform index score using FP8 precision.

    Args:
        q (torch.Tensor): The Q tensor, must be contiguous.
        q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
        k (torch.Tensor): The K tensor, must be contiguous.
        k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.

    fp8 q @ fp8 k -> fp32 logits
    relu(fp32 logits) * q_s (weights) -> fp32 logits
    fp32 logits -> fp32 logits_sum
    fp32 logits_sum * k_s (e8m0) -> fp32 index_score
    """
    return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
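The docstring above spells out the scoring pipeline; a plain-PyTorch reference of the same computation, ignoring the fp8 packing (illustrative only), is:

# Float reference of the fp8_index score, following the docstring above.
# q: (b, m, h, d), q_s: (b, m, h), k: (b, n, d), k_s: (b, n) -> (b, m, n)
import torch

def fp8_index_reference(q, q_s, k, k_s):
    logits = torch.einsum("bmhd,bnd->bmhn", q.float(), k.float())
    logits = torch.relu(logits) * q_s.float()[..., None]   # gate by per-head weight
    return logits.sum(dim=2) * k_s.float()[:, None, :]     # weighted index score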
@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def sparse_attention_fwd_kernel_v1(
    num_heads,
    dim,
    tail_dim,
    topk,
    *,
    kv_group=1,
    sm_scale=None,
    is_causal=True,
    block_I=64,
    num_stages=2,
    threads=256,
):
    assert dim == tilelang.math.next_power_of_2(
        dim
    ), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim
    ), f"haven't check padding correctness yet, dim={tail_dim}"
    assert is_causal == True, "non-casual is not supported"
    assert (
        topk % block_I == 0
    ), "otherwise will load some index=0 thus causing wrong kv to be loaded"
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)
    else:
        sm_scale = sm_scale * 1.44269504  # log2(e)

    batch = T.symbolic("batch")
    seq_len = T.symbolic("seq_len")
    seq_len_kv = T.symbolic("seq_len_kv")

    head_kv = num_heads // kv_group
    q_shape = [batch, seq_len, num_heads, dim + tail_dim]
    kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
    o_shape = [batch, seq_len, num_heads, dim]
    indices_shape = [batch, seq_len, kv_group, topk]
    indices_dtype = "int32"
    dtype = "bfloat16"
    accum_dtype = "float"

    H = head_kv
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert kv_group == 1
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    D = dim
    D_tail = tail_dim
    if head_kv > 64:
        assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = head_kv // 64
    else:
        REPLICATE_H = 1
    H_per_block = padded_H if REPLICATE_H == 1 else 64

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
    ):
        with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
            bx,
            by,
            bz,
        ):
            Q_shared = T.alloc_shared([H_per_block, D], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared = T.alloc_shared([BI, D], dtype)
            K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
            O_shared = T.alloc_shared([H_per_block, D], dtype)

            mask = T.alloc_fragment([BI], "bool")

            acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            S_shared = T.alloc_shared([H_per_block, BI], dtype)
            sumexp = T.alloc_fragment([H_per_block], accum_dtype)
            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
            alpha = T.alloc_fragment([H_per_block], accum_dtype)
            m_i = T.alloc_fragment([H_per_block], accum_dtype)
            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)

            T.fill(acc_o, 0)
            T.fill(sumexp, 0)
            T.fill(m_i, -(2**30))  # avoid -inf - inf to cause nan

            b_i, g_i = by, bz
            s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
            q_i = s_i
            max_kv_i = q_i

            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block

            T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)

            for i_i in T.Pipelined(NI, num_stages=num_stages):
                for bi_i in T.Parallel(BI):
                    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0

                for bi_i, d_i in T.Parallel(BI, D):
                    KV_shared[bi_i, d_i] = KV[
                        b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i
                    ]
                for bi_i, d_i in T.Parallel(BI, D_tail):
                    K_tail_shared[bi_i, d_i] = KV[
                        b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i
                    ]

                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.if_then_else(
                        mask[bi_i], 0, -T.infinity(acc_s.dtype)
                    )
                T.gemm(
                    Q_shared,
                    KV_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )
                T.gemm(
                    Q_tail_shared,
                    K_tail_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )
                T.copy(m_i, m_i_prev)
                T.reduce_max(acc_s, m_i, dim=1, clear=False)
                for h_i in T.Parallel(H_per_block):
                    alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.exp2(
                        acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                    )
                T.reduce_sum(acc_s, sumexp_i, dim=1)  # is this a accumulate operator?
                for h_i in T.Parallel(H_per_block):
                    sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
                for h_i, d_i in T.Parallel(H_per_block, D):
                    acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]

                T.copy(acc_s, S_shared)
                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

            # Rescale
            for h_i, d_i in T.Parallel(H_per_block, D):
                acc_o[h_i, d_i] /= sumexp[h_i]
            for h_i in T.Parallel(H_per_block):
                sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale

            T.copy(acc_o, O_shared)
            T.copy(acc_o, Output[b_i, s_i, H0:H1, :])

    return main
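For readability, the computation kernel v1 performs can be summarized by a slow PyTorch reference that gathers the selected KV rows and masks the `-1` padding indices (illustrative only; single batch element, kv_group == 1):

# Slow reference of the top-k sparse attention computed by kernel v1 above.
import torch

def sparse_attn_reference(q, kv, indices, dim):
    # q: (H, dim + tail), kv: (S_kv, dim + tail), indices: (topk,) with -1 padding
    valid = indices >= 0
    gathered = kv[indices.clamp(min=0)]                          # (topk, dim + tail)
    scores = (q.float() @ gathered.float().T) / (q.shape[-1] ** 0.5)
    scores = scores.masked_fill(~valid[None, :], float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return probs @ gathered[:, :dim].float()                     # (H, dim)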
@tilelang.jit(
    out_idx=[-1],
    compile_flags=[
        "-O3",
        "-Wno-deprecated-declarations",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--ptxas-options=-v,--register-usage-level=10",
        "-DNDEBUG",
    ],
)  # type: ignore
def sparse_attention_fwd_kernel_v2(
    num_heads: int,
    dim: int,
    tail_dim: int,
    topk: int,
    *,
    kv_group: int = 1,
    sm_scale: Optional[float] = None,
    block_I: int = 64,
):
    assert dim == tilelang.math.next_power_of_2(
        dim
    ), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim
    ), f"haven't check padding correctness yet, dim={tail_dim}"
    assert (
        topk % block_I == 0
    ), "otherwise will load some index=0 thus causing wrong kv to be loaded"
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)
    else:
        sm_scale = sm_scale * 1.44269504  # log2(e)

    threads = 384
    batch = T.symbolic("batch")
    qo_len = T.symbolic("seq_len")
    num_pages = T.symbolic("num_pages")

    q_shape = [batch, qo_len, num_heads, dim + tail_dim]
    kv_shape = [batch, num_pages, kv_group, dim + tail_dim]
    o_shape = [batch, qo_len, num_heads, dim]
    indices_shape = [batch, qo_len, kv_group, topk]
    indices_dtype = "int32"
    dtype = "bfloat16"
    accum_dtype = "float"

    H = num_heads
    padded_H = max(tilelang.math.next_power_of_2(num_heads), 16)
    if padded_H != H:
        assert kv_group == 1
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    assert NI % 2 == 0, "NI should be a multiple of 2"
    D = dim
    D_tail = tail_dim
    if num_heads > 64:
        assert num_heads % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = num_heads // 64
    else:
        REPLICATE_H = 1
    H_per_block = padded_H if REPLICATE_H == 1 else 64

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
    ):
        """
        Q: [b, qo_len, H, D + D_tail] (bfloat16)
        KV: [b, num_pages, kv_group, D + D_tail] (bfloat16)
        Indices: [b, qo_len, kv_group, topk] (int32)
        """
        with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz):  # type: ignore
            Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)
            K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)
            K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)
            O_shared_l = Q_shared_l
            O_shared_r = Q_shared_r
            is_kv_valid_0 = T.alloc_shared([BI], "bool", scope="shared")
            is_kv_valid_1 = T.alloc_shared([BI], "bool", scope="shared")

            acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
            acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            S_shared = T.alloc_shared([H_per_block, BI], dtype)
            sumexp = T.alloc_fragment([H_per_block], accum_dtype)
            sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)
            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
            alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared")
            alpha_local = T.alloc_fragment([H_per_block], accum_dtype)
            m_i = T.alloc_fragment([H_per_block], accum_dtype)
            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)

            indices_local = T.alloc_local([1], indices_dtype)
            indices_tmp = T.alloc_local([1], indices_dtype)

            bar_q = T.alloc_barrier(arrive_count=384)
            bar_k_0_ready = T.alloc_barrier(arrive_count=128)
            bar_k_1_ready = T.alloc_barrier(arrive_count=128)
            bar_k_0_free = T.alloc_barrier(arrive_count=256)
            bar_k_1_free = T.alloc_barrier(arrive_count=256)
            bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
            bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
            bar_0_128 = T.alloc_barrier(arrive_count=128)
            bar_1_128 = T.alloc_barrier(arrive_count=128)
            bar_2_128 = T.alloc_barrier(arrive_count=128)
            bar_final = T.alloc_barrier(arrive_count=128)

            b_i, g_i = by, bz
            s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H
            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block

            tx = T.get_thread_binding()

            T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
            T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
            T.barrier_arrive(bar_q)

            if tx < 128:
                T.set_max_nreg(240, 1)
                T.fill(sumexp, 0)
                T.fill(m_i, -(2**30))  # avoid -inf - inf to cause nan
                T.fill(acc_o_l, 0)
                T.barrier_wait(bar_q, 0)
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    # with sync_at(bar_0_128, 0):
                    T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
                    T.barrier_arrive(bar_0_128)
                    T.barrier_wait(bar_0_128, 0)

                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(
                            is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype)
                        )
                    T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1)
                    T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1)
                    T.gemm(
                        Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1,
                    )
                    T.wait_wgmma(0)

                    if i_i != 0:
                        T.barrier_arrive(bar_sScale_and_sS_free)
                        T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)

                    T.copy(m_i, m_i_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for h_i in T.Parallel(H_per_block):
                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.exp2(
                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                        )
                    T.reduce_sum(acc_s, sumexp_i, dim=1)  # is this a accumulate operator?
                    for h_i in T.Parallel(H_per_block):
                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_l[h_i, d_i] *= alpha_local[h_i]
                    T.copy(alpha_local, alpha_shared)
                    T.copy(acc_s, S_shared)

                    T.gemm(S_shared, KV_shared_0_l, acc_o_l)
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_arrive(bar_k_0_free[0])

                    # Buffer 1
                    T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
                    T.barrier_arrive(bar_0_128)
                    T.barrier_wait(bar_0_128, 1)

                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(
                            is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype)
                        )
                    T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1)
                    T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1)
                    T.gemm(
                        Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1,
                    )
                    T.wait_wgmma(0)

                    T.barrier_arrive(bar_sScale_and_sS_free)
                    T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)

                    T.copy(m_i, m_i_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for h_i in T.Parallel(H_per_block):
                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.exp2(
                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                        )
                    T.reduce_sum(acc_s, sumexp_i, dim=1)  # is this a accumulate operator?
                    for h_i in T.Parallel(H_per_block):
                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_l[h_i, d_i] *= alpha_local[h_i]
                    T.copy(alpha_local, alpha_shared)
                    T.copy(acc_s, S_shared)

                    T.gemm(S_shared, KV_shared_1_l, acc_o_l)
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_arrive(bar_k_1_free[0])

                # Rescale
                for h_i in T.Parallel(H_per_block):
                    sum_exp_shared[h_i] = sumexp[h_i]
                T.barrier_arrive(bar_final)
                for h_i, d_i in T.Parallel(H_per_block, D // 2):
                    acc_o_l[h_i, d_i] /= sumexp[h_i]
                for h_i in T.Parallel(H_per_block):
                    sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
                T.copy(acc_o_l, O_shared_l)
                T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])

            elif tx >= 128 and tx < 256:
                # T.set_max_nreg(168, 1)
                T.fill(acc_o_r, 0)
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
                    T.barrier_arrive(bar_1_128)
                    T.barrier_wait(bar_1_128, 0)
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]
                    T.gemm(S_shared, KV_shared_0_r, acc_o_r)
                    T.barrier_arrive(bar_k_0_free[0])
                    T.barrier_arrive(bar_sScale_and_sS_free)

                    # Buffer 1
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
                    T.barrier_arrive(bar_1_128)
                    T.barrier_wait(bar_1_128, 1)
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]
                    T.gemm(S_shared, KV_shared_1_r, acc_o_r)
                    T.barrier_arrive(bar_k_1_free[0])
                    if i_i != T.ceildiv(NI, 2) - 1:
                        T.barrier_arrive(bar_sScale_and_sS_free)

                # Rescale
                T.barrier_wait(bar_final, 0)
                for h_i, d_i in T.Parallel(H_per_block, D // 2):
                    acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
                T.copy(acc_o_r, O_shared_r)
                T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])

            elif tx >= 256:
                # producer
                T.set_max_nreg(80, 0)
                indices_local[0] = 0
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
                    T.barrier_arrive(bar_2_128)
                    T.barrier_wait(bar_2_128, 0)
                    for r in T.serial(4):
                        indices_tmp[0] = Indices[
                            b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8
                        ]
                        is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
                        if is_kv_valid_0[r * 16 + (tx - 256) // 8]:
                            indices_local[0] = indices_tmp[0]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for u in T.serial(4):
                                for v in T.vectorized(8):
                                    KV_shared_0_l[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                                    KV_shared_0_r[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for v in T.vectorized(8):
                                K_tail_shared_0[
                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
                                ] = KV[
                                    b_i,
                                    indices_local[0],
                                    g_i,
                                    D + (tx - 256) % 8 * 8 + v,
                                ]
                    T.cp_async_barrier_noinc(bar_k_0_ready[0])

                    # Buffer 1
                    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
                    T.barrier_arrive(bar_2_128)
                    T.barrier_wait(bar_2_128, 1)
                    for r in T.serial(4):
                        indices_tmp[0] = Indices[
                            b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8
                        ]
                        is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
                        if is_kv_valid_1[r * 16 + (tx - 256) // 8]:
                            indices_local[0] = indices_tmp[0]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for u in T.serial(4):
                                for v in T.vectorized(8):
                                    KV_shared_1_l[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                                    KV_shared_1_r[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for v in T.vectorized(8):
                                K_tail_shared_1[
                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
                                ] = KV[
                                    b_i,
                                    indices_local[0],
                                    g_i,
                                    D + (tx - 256) % 8 * 8 + v,
                                ]
                    T.cp_async_barrier_noinc(bar_k_1_ready[0])

    return main
def tilelang_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> torch.Tensor:
    assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3
    num_heads = q.shape[1]
    dim = q.shape[2]
    tail_dim = dim - d_v
    topk = indices.shape[-1]
    assert topk == 2048
    # NOTE(dark): v2 offers better performance than v1
    kernel = sparse_attention_fwd_kernel_v2(num_heads, d_v, tail_dim, topk, sm_scale=sm_scale)
    return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0))  # type: ignore
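# A hedged usage sketch for the wrapper above. The shapes follow the assertions in this file
# (unbatched q/kv/indices, topk == 2048) plus assumed DeepSeek-style dimensions d_v=512 and a
# 64-dim rope tail; running it requires a CUDA device with TileLang installed.
def _tilelang_sparse_fwd_usage_demo():
    import torch

    qo_len, num_heads, d_v, tail, topk = 4, 128, 512, 64, 2048
    num_tokens = 8192  # page size is 1 for this kernel, so "pages" are single tokens
    q = torch.randn(qo_len, num_heads, d_v + tail, device="cuda", dtype=torch.bfloat16)
    kv = torch.randn(num_tokens, 1, d_v + tail, device="cuda", dtype=torch.bfloat16)
    indices = torch.full((qo_len, 1, topk), -1, device="cuda", dtype=torch.int32)
    indices[..., :1024] = torch.arange(1024, device="cuda", dtype=torch.int32)  # -1 marks unused slots
    out = tilelang_sparse_fwd(q, kv, indices, sm_scale=(d_v + tail) ** -0.5)
    # the wrapper adds the batch dim itself, so `out` is [1, qo_len, num_heads, d_v]
    return out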
python/sglang/srt/layers/attention/nsa/topk.py
0 → 100644
import torch

from sglang.srt.utils import align

# NOTE(dark): flashmla P requires `params.topk % (2*B_TOPK) == 0`,
# where `B_TOPK=64`. So we align to 128 by default.
_TOPK_ALIGNMENT = 128


# TODO(dark): maybe this torch_op can support torch.compile
def _fast_topk_torch(
    input: torch.Tensor, seq_lens: torch.Tensor, topk: int, alignment: int
) -> torch.Tensor:
    # Fallback to torch.topk
    bs, max_seq_len = input.shape
    assert len(seq_lens) == bs
    # set those out-of-bound input to -inf
    padded_max_seq_len = align(max_seq_len, alignment)
    positions = torch.arange(
        padded_max_seq_len, device=input.device, dtype=seq_lens.dtype
    )
    positions = positions.unsqueeze(0).expand(bs, -1)
    mask = positions >= seq_lens.unsqueeze(1)
    # NOTE(dark): just return all valid indices as an optimization
    if padded_max_seq_len <= topk:
        return positions.masked_fill(mask, -1)
    assert topk % alignment == 0
    # in-place operation: mask invalid inputs to -inf
    input = input.masked_fill_(mask[:, :max_seq_len], float("-inf"))
    result = input.topk(topk, dim=-1, sorted=True)
    return result.indices.masked_fill_(mask[:, :topk], -1)


def fast_topk_impl(
    input: torch.Tensor,
    seq_lens: torch.Tensor,
    topk: int,
    alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
    return _fast_topk_torch(input, seq_lens, topk, alignment)


def fast_topk_transform_fused_cuda(
    input: torch.Tensor,
    seq_lens: torch.Tensor,
    topk: int,
    dst_page_table: torch.Tensor,
    src_page_table: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
    from sglang.srt.layers.attention.nsa.cuda import fast_topk_transform

    assert topk == 2048 and topk % alignment == 0
    return fast_topk_transform(
        score=input,
        lengths=seq_lens,
        dst_page_table=dst_page_table,
        src_page_table=src_page_table,
        cu_seqlens=cu_seqlens_q,
    )
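# A CPU-only illustration of the fallback's contract: output slots past `seq_lens` come back as
# -1, and rows are masked by position rather than by value. Toy sizes and an alignment of 2 are
# assumptions for the demo only (the module default above is 128).
def _fast_topk_semantics_demo():
    scores = torch.tensor([[0.1, 0.9, 0.3, 0.8, 0.2, 0.7],
                           [0.5, 0.4, 0.6, 0.0, 0.0, 0.0]])
    seq_lens = torch.tensor([6, 3])
    topk, alignment = 4, 2
    padded = ((scores.shape[1] + alignment - 1) // alignment) * alignment
    positions = torch.arange(padded).unsqueeze(0).expand(2, -1)
    mask = positions >= seq_lens.unsqueeze(1)
    masked_scores = scores.masked_fill(mask[:, : scores.shape[1]], float("-inf"))
    idx = masked_scores.topk(topk, dim=-1).indices.masked_fill(mask[:, :topk], -1)
    # request 0 keeps its 4 best tokens; request 1 has only 3 valid tokens, so it gets a -1 tail
    return idx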
python/sglang/srt/layers/attention/nsa/transform_index.py
0 → 100644
from typing import List, Optional

import torch
import triton
import triton.language as tl


def transform_index_page_table_prefill(**kwargs):
    return transform_index_page_table_prefill_ref(**kwargs)


def transform_index_page_table_decode(**kwargs):
    return transform_index_page_table_decode_ref(**kwargs)


@triton.jit
def transform_index_page_table_decode_kernel(
    page_table_ptr: torch.Tensor,
    topk_indices_ptr: torch.Tensor,
    result_ptr: torch.Tensor,
    page_size: tl.constexpr,
    max_seqlen_k: tl.constexpr,
):
    TOPK: tl.constexpr = 2048
    req_id = tl.program_id(0)
    page_table_ptr = page_table_ptr + req_id * max_seqlen_k
    topk_indices_ptr = topk_indices_ptr + req_id * TOPK
    result_ptr = result_ptr + req_id * TOPK
    offset = tl.arange(0, TOPK)  # topk should be 2048
    loaded_topk_indices = tl.load(topk_indices_ptr + offset)
    mask = loaded_topk_indices >= 0
    loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
    tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
    tl.store(result_ptr + offset, -1, mask=~mask)


def transform_index_page_table_decode_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    """
    Transform the page table according to topk indices for sparse topk attention.

    Args:
        page_table: [qo_len, max_seqlen_k], the original page table
        topk_indices: [qo_len, topk], the topk indices for each query position

    Returns:
        transformed_page_table: [qo_len, topk], the transformed page table
        For out-of-bound indices in topk_indices, this should be filled with -1.
    """
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    assert topk_indices.shape[1] == 2048
    qo_len = topk_indices.shape[0]
    max_seqlen_k = page_table.shape[1]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)

    # Launch triton kernel
    grid = (qo_len,)
    transform_index_page_table_decode_kernel[grid](
        page_table,
        topk_indices,
        result,
        page_size,
        max_seqlen_k=max_seqlen_k,
    )
    return result


def transform_index_page_table_prefill_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    # TODO(baizhou): can be implemented with another triton kernel
    assert page_size == 1
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert len(extend_lens_cpu) == page_table.shape[0]
    offset = 0
    for i, l in enumerate(extend_lens_cpu):
        transform_index_page_table_decode_fast(
            page_table[i].unsqueeze(0).expand(l, -1),
            topk_indices[offset : offset + l],
            result=result[offset : offset + l],
        )
        offset += l
    assert offset == topk_indices.shape[0]
    return result


def transform_index_page_table_decode_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert result.shape == topk_indices.shape
    torch.gather(
        page_table,
        dim=1,
        index=topk_indices.clamp(min=0),
        out=result,
    )
    result[topk_indices < 0] = -1
    return result


def transform_index_page_table_prefill_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    assert page_size == 1
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert len(extend_lens_cpu) == page_table.shape[0]
    offset = 0
    for i, l in enumerate(extend_lens_cpu):
        transform_index_page_table_decode_ref(
            page_table[i].unsqueeze(0).expand(l, -1),
            topk_indices[offset : offset + l],
            result=result[offset : offset + l],
        )
        offset += l
    assert offset == topk_indices.shape[0]
    return result


if __name__ == "__main__":
    bs, topk, max_seqlen = 10, 2048, 3000
    page_table = torch.randint(0, 100, (bs, max_seqlen), device="cuda")
    topk_indices = torch.full((bs, topk), -1, device="cuda")
    topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1)
    ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)
    result = transform_index_page_table_decode_fast(page_table, topk_indices)
    assert torch.all(result == ref_result)
    print("Passed")
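# The prefill variants simply reuse the decode transform per request: `extend_lens_cpu[i]` query
# rows share request i's page-table row. A CPU sketch of that slicing with toy sizes (the triton
# path assumes topk == 2048; the reference gather used here does not care):
def _prefill_transform_demo():
    page_table = torch.tensor([[10, 11, 12, 13],
                               [20, 21, 22, 23]])
    extend_lens = [2, 3]  # request 0 contributes 2 query rows, request 1 contributes 3
    topk_indices = torch.tensor([[0, 2, -1],
                                 [1, 3, -1],
                                 [0, 1, 2],
                                 [3, 0, -1],
                                 [2, -1, -1]])
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    offset = 0
    for i, l in enumerate(extend_lens):
        rows = topk_indices[offset : offset + l]
        gathered = page_table[i].unsqueeze(0).expand(l, -1).gather(1, rows.clamp(min=0))
        result[offset : offset + l] = gathered.masked_fill(rows < 0, -1)
        offset += l
    return result  # [[10, 12, -1], [11, 13, -1], [20, 21, 22], [23, 20, -1], [22, -1, -1]]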
python/sglang/srt/layers/attention/nsa/unit_test/get_logits_ut.py
0 → 100644
import torch
import torch.nn as nn


class DummyModel(nn.Module):
    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
        super().__init__()
        self.weights_proj = nn.Linear(d_in, 1024)
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale

    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
        weights = self.weights_proj(x)
        weights = weights * self.n_heads**-0.5
        q_scale = q_scale.unsqueeze(1)  # (B,1,1)
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        return weights

    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
        weights = self.weights_proj(x)
        q_scale = q_scale.unsqueeze(1)  # (B,1,1)
        scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale  # (B,1,1)
        weights = weights.unsqueeze(-1) * scale_const  # (B,1024,1)
        return weights


def main():
    torch.manual_seed(0)
    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
    x = torch.randn(128, 2048)  # batch=128, d_in=2048
    q_scale = torch.randn(128, 1)

    import time

    start = time.time()
    for _ in range(1000):
        out_orig = model._get_logits_head_gate_orig(x, q_scale)
    print("Original version time:", time.time() - start)

    start = time.time()
    for _ in range(1000):
        out_opt = model._get_logits_head_gate_opt(x, q_scale)
    print("Optimized version time:", time.time() - start)

    print("Difference:", (out_orig - out_opt).abs().max().item())
    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"


if __name__ == "__main__":
    main()

"""
Original version time: 0.49235057830810547
Optimized version time: 0.4087331295013428
Difference: 1.4901161193847656e-08
"""
python/sglang/srt/layers/attention/nsa/utils.py
0 → 100644
# temp NSA debugging environ
from sglang.srt.utils import get_bool_env_var

NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var(
    "SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true"
)
NSA_KV_CACHE_STORE_FP8 = get_bool_env_var("SGLANG_NSA_KV_CACHE_STORE_FP8", "false")
NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "false")
NSA_DEQUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "false")


def _print_bool_env_vars():
    msg = ""
    for k, v in globals().items():
        if k.startswith("NSA_") and isinstance(v, bool):
            msg += f"{k}={v} "
    print(msg, flush=True)


_print_bool_env_vars()

if not NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8:
    assert not NSA_KV_CACHE_STORE_FP8


def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):
    return original_seq_lens.clamp(max=nsa_index_topk)
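# `compute_nsa_seqlens` is what the attention backend uses to cap effective KV lengths once the
# indexer keeps at most `nsa_index_topk` tokens per query. A tiny illustration (index_topk=2048
# matches the DeepSeek NSA configuration assumed elsewhere in this commit):
def _compute_nsa_seqlens_demo():
    import torch

    seq_lens = torch.tensor([17, 2048, 5000, 131072], dtype=torch.int32)
    return compute_nsa_seqlens(seq_lens, 2048)  # tensor([17, 2048, 2048, 2048], dtype=torch.int32)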
python/sglang/srt/layers/attention/nsa_backend.py
0 → 100644
from __future__ import annotations

from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    TypeAlias,
    Union,
    override,
)

import torch
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.layers.attention.nsa.topk import (
    fast_topk_impl,
    fast_topk_transform_fused_cuda,
)
from sglang.srt.layers.attention.nsa.transform_index import (
    transform_index_page_table_decode,
    transform_index_page_table_prefill,
)
from sglang.srt.layers.attention.nsa.utils import (
    NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
    NSA_FUSE_TOPK,
    NSA_KV_CACHE_STORE_FP8,
    compute_nsa_seqlens,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.two_batch_overlap import global_server_args_dict

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
@dataclass(frozen=True)
class NSAFlashMLAMetadata:
    """Metadata only needed by FlashMLA"""

    flashmla_metadata: torch.Tensor
    num_splits: torch.Tensor

    def slice(self, sli):
        return NSAFlashMLAMetadata(
            flashmla_metadata=self.flashmla_metadata,
            num_splits=self.num_splits[sli],
        )

    def copy_(self, other: "NSAFlashMLAMetadata"):
        self.flashmla_metadata.copy_(other.flashmla_metadata)
        self.num_splits.copy_(other.num_splits)


@dataclass(frozen=True)
class NSAMetadata:
    page_size: int
    # Sequence lengths for the forward batch
    cache_seqlens_int32: torch.Tensor
    # Maximum sequence length for query
    max_seq_len_q: int
    # Maximum sequence length for key
    max_seq_len_k: int
    # Cumulative sequence lengths for query
    cu_seqlens_q: torch.Tensor
    # Cumulative sequence lengths for key
    cu_seqlens_k: torch.Tensor
    # Page table, the index of KV Cache Tables/Blocks
    # this table is always with page_size = 1
    page_table_1: torch.Tensor
    # NOTE(dark): This property will be used in:
    # 1. dense decode/prefill, we use paged flash attention, need real_page_table
    # 2. sparse decode/prefill, indexer need real_page_table to compute the score
    real_page_table: torch.Tensor

    # NSA metadata (nsa prefill are expanded)
    nsa_cache_seqlens_int32: torch.Tensor  # this seqlens is clipped to `topk`
    nsa_cu_seqlens_q: torch.Tensor  # must be arange(0, len(nsa_cu_seqlens_k))
    nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`
    nsa_extend_seq_lens_list: List[int]
    nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`
    nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend

    flashmla_metadata: Optional[NSAFlashMLAMetadata] = None


@dataclass(frozen=True)
class NSAIndexerMetadata(BaseIndexerMetadata):
    attn_metadata: NSAMetadata

    @override
    def get_seqlens_int32(self) -> torch.Tensor:
        return self.attn_metadata.cache_seqlens_int32

    @override
    def get_page_table_64(self) -> torch.Tensor:
        return self.attn_metadata.real_page_table

    @override
    def get_seqlens_expanded(self) -> torch.Tensor:
        return self.attn_metadata.nsa_seqlens_expanded

    @override
    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        if not NSA_FUSE_TOPK:
            return fast_topk_impl(logits, self.get_seqlens_expanded(), topk)

        # NOTE(dark): if fused, we return a transformed page table directly
        dst_page_table = torch.empty(
            (logits.shape[0], topk), dtype=torch.int32, device=logits.device
        )
        fast_topk_transform_fused_cuda(
            input=logits,
            seq_lens=self.get_seqlens_expanded(),
            topk=topk,
            dst_page_table=dst_page_table,
            src_page_table=self.attn_metadata.page_table_1,
            cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
        )
        return dst_page_table


def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
    assert seqlens.dtype == torch.int32 and seqlens.is_cuda
    return torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )


_NSA_IMPL_T: TypeAlias = Literal["flashmla_prefill", "flashmla_decode", "fa3", "tilelang"]
NSA_PREFILL_IMPL: _NSA_IMPL_T
NSA_DECODE_IMPL: _NSA_IMPL_T
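# `compute_cu_seqlens` packs per-request lengths into the cumulative form that varlen attention
# kernels expect: a leading zero followed by a running sum. The real helper insists on int32 CUDA
# tensors; this sketch drops the device requirement purely to show the arithmetic.
def _compute_cu_seqlens_demo():
    seqlens = torch.tensor([3, 1, 4], dtype=torch.int32)
    cu = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return cu  # tensor([0, 3, 4, 8], dtype=torch.int32)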
class NativeSparseAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.forward_metadata: NSAMetadata
        self.device = model_runner.device
        assert isinstance(model_runner.page_size, int)
        self.real_page_size = model_runner.page_size
        self.num_splits = (
            1 if model_runner.server_args.enable_deterministic_inference else 0
        )
        self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
        assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
        self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
        self.max_context_len = model_runner.model_config.context_len
        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim

        assert model_runner.req_to_token_pool is not None
        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode

        self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)

    def get_device_int32_arange(self, l: int) -> torch.Tensor:
        if l > len(self._arange_buf):
            next_pow_of_2 = 1 << (l - 1).bit_length()
            self._arange_buf = torch.arange(
                next_pow_of_2, device=self.device, dtype=torch.int32
            )
        return self._arange_buf[:l]

    def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
        page_size = self.real_page_size
        if page_size == 1:
            return page_table
        max_seqlen_k = page_table.shape[1]
        strided_indices = torch.arange(
            0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
        )
        return page_table[:, strided_indices] // page_size
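    # A hedged illustration of the page-table transform above, with a hypothetical page size of 4
    # (the FlashMLA decode path in this file actually requires 64). Not used by the backend.
    @staticmethod
    def _transform_table_1_to_real_demo():
        page_size = 4
        # one request whose 8 token slots live in two physical pages: slots 100..103 and 200..203
        page_table_1 = torch.tensor([[100, 101, 102, 103, 200, 201, 202, 203]])
        strided = torch.arange(0, page_table_1.shape[1], page_size)
        return page_table_1[:, strided] // page_size  # tensor([[25, 50]])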
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        batch_size = forward_batch.batch_size
        device = forward_batch.seq_lens.device
        assert (
            forward_batch.spec_info is None
        ), "Spec decoding is not supported for NSA backend now"

        cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
        assert forward_batch.seq_lens_cpu is not None
        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
        page_table = forward_batch.req_to_token_pool.req_to_token[
            forward_batch.req_pool_indices, :max_seqlen_k
        ]

        if forward_batch.forward_mode.is_decode_or_idle():
            extend_seq_lens_cpu = [1] * batch_size
            max_seqlen_q = 1
            cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
            seqlens_expanded = cache_seqlens_int32
        elif forward_batch.forward_mode.is_extend():
            assert (
                forward_batch.extend_seq_lens_cpu is not None
                and forward_batch.extend_seq_lens is not None
                and forward_batch.extend_prefix_lens_cpu is not None
            ), "All of them must not be None"
            extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
            assert forward_batch.extend_seq_lens is not None
            if any(forward_batch.extend_prefix_lens_cpu):
                max_seqlen_q = max(extend_seq_lens_cpu)
                cu_seqlens_q = compute_cu_seqlens(
                    forward_batch.extend_seq_lens.to(torch.int32)
                )
            else:
                max_seqlen_q = max_seqlen_k
                cu_seqlens_q = cu_seqlens_k
            seqlens_expanded = torch.cat(
                [
                    torch.arange(
                        kv_len - qo_len + 1,
                        kv_len + 1,
                        dtype=torch.int32,
                        device=device,
                    )
                    for qo_len, kv_len in zip(
                        forward_batch.extend_seq_lens_cpu,
                        forward_batch.seq_lens_cpu.tolist(),
                        strict=True,
                    )
                ]
            )
        else:
            assert False, f"Unsupported {forward_batch.forward_mode=}"

        # 1D, expanded seqlens (1D means cheap to compute, so always compute it)
        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
            original_seq_lens=seqlens_expanded,
            nsa_index_topk=self.nsa_index_topk,
        )
        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))

        metadata = NSAMetadata(
            page_size=self.real_page_size,
            cache_seqlens_int32=cache_seqlens_int32,
            max_seq_len_q=max_seqlen_q,
            max_seq_len_k=max_seqlen_k,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            page_table_1=page_table,
            flashmla_metadata=(
                self._compute_flashmla_metadata(
                    cache_seqlens=nsa_cache_seqlens_int32,
                    seq_len_q=1,  # TODO handle MTP which is not 1
                )
                if NSA_DECODE_IMPL == "flashmla_decode"
                else None
            ),
            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
            nsa_seqlens_expanded=seqlens_expanded,
            nsa_extend_seq_lens_list=extend_seq_lens_cpu,
            real_page_table=self._transform_table_1_to_real(page_table),
        )

        self.forward_metadata = metadata
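    # The `seqlens_expanded` built in the extend branch above assigns one KV length per query
    # token rather than per request: a request extending by qo_len tokens against kv_len total
    # tokens contributes kv_len - qo_len + 1, ..., kv_len, so each new token attends causally up
    # to itself. A toy illustration (not called by the backend):
    @staticmethod
    def _seqlens_expanded_demo():
        extend_seq_lens = [3, 2]  # new tokens per request
        seq_lens = [5, 7]         # prefix + new tokens per request
        return torch.cat(
            [
                torch.arange(kv_len - qo_len + 1, kv_len + 1, dtype=torch.int32)
                for qo_len, kv_len in zip(extend_seq_lens, seq_lens, strict=True)
            ]
        )  # tensor([3, 4, 5, 6, 7], dtype=torch.int32)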
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
            max_bs (int): Maximum batch size to support in CUDA graphs

        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
        self.decode_cuda_graph_metadata: Dict = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "cu_seqlens_q": torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=self.device
            ),
            "cu_seqlens_k": torch.zeros(
                max_bs + 1, dtype=torch.int32, device=self.device
            ),
            # fake page_table for sparse_prefill
            "page_table": torch.zeros(
                max_bs,
                self.max_context_len,
                dtype=torch.int32,
                device=self.device,
            ),
            "flashmla_metadata": (
                self._compute_flashmla_metadata(
                    cache_seqlens=torch.ones(
                        max_bs, dtype=torch.int32, device=self.device
                    ),
                    seq_len_q=1,  # TODO handle MTP which is not 1
                )
                if NSA_DECODE_IMPL == "flashmla_decode"
                else None
            ),
        }
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Initialize forward metadata for capturing CUDA graph."""
        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
        assert (
            spec_info is None
        ), "Speculative decoding is not supported for NSA backend now"

        # Normal Decode
        # Get sequence information
        cache_seqlens_int32 = seq_lens.to(torch.int32)
        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
        # Use max context length for seq_len_k
        page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
        max_seq_len_k = page_table_1.shape[1]

        # Precompute page table
        # Precompute cumulative sequence lengths
        # NOTE(dark): this is always arange, since we are decoding
        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]

        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
            cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
        )
        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
        real_page_table = self._transform_table_1_to_real(page_table_1)

        if NSA_DECODE_IMPL == "flashmla_decode":
            flashmla_metadata = self.decode_cuda_graph_metadata[
                "flashmla_metadata"
            ].slice(slice(0, bs + 1))
            flashmla_metadata.copy_(
                self._compute_flashmla_metadata(
                    cache_seqlens=nsa_cache_seqlens_int32,
                    seq_len_q=1,  # TODO handle MTP which is not 1
                )
            )
        else:
            flashmla_metadata = None

        metadata = NSAMetadata(
            page_size=self.real_page_size,
            cache_seqlens_int32=cache_seqlens_int32,
            max_seq_len_q=1,
            max_seq_len_k=max_seq_len_k,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            page_table_1=page_table_1,
            flashmla_metadata=flashmla_metadata,
            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
            nsa_seqlens_expanded=cache_seqlens_int32,
            real_page_table=real_page_table,
            nsa_extend_seq_lens_list=[1] * bs,
        )

        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_metadata = metadata
    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
        out_cache_loc: Optional[torch.Tensor] = None,
    ):
        """Initialize forward metadata for replaying CUDA graph."""
        assert seq_lens_cpu is not None
        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
        assert (
            spec_info is None
        ), "Speculative decoding is not supported for NSA backend now"
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]

        # Normal Decode
        metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
        max_len = int(seq_lens_cpu.max().item())

        cache_seqlens = seq_lens.to(torch.int32)
        metadata.cache_seqlens_int32.copy_(cache_seqlens)
        metadata.cu_seqlens_k[1:].copy_(
            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
        )
        page_indices = self.req_to_token[req_pool_indices, :max_len]
        metadata.page_table_1[:, :max_len].copy_(page_indices)

        assert (
            metadata.nsa_cache_seqlens_int32 is not None
            and metadata.nsa_cu_seqlens_k is not None
            and self.nsa_index_topk is not None
        )
        nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
        metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
        metadata.nsa_cu_seqlens_k[1:].copy_(
            torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
        )
        # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy

        assert self.real_page_size == metadata.page_size
        if self.real_page_size > 1:
            real_table = self._transform_table_1_to_real(page_indices)
            new_len = real_table.shape[1]
            metadata.real_page_table[:, :new_len].copy_(real_table)
        else:
            assert metadata.real_page_table is metadata.page_table_1

        if NSA_DECODE_IMPL == "flashmla_decode":
            metadata.flashmla_metadata.copy_(
                self._compute_flashmla_metadata(
                    cache_seqlens=nsa_cache_seqlens,
                    seq_len_q=1,  # TODO handle MTP which is not 1
                )
            )

        self.forward_metadata = metadata
    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert (
            not forward_batch.forward_mode.is_target_verify()
            and not forward_batch.forward_mode.is_draft_extend()
        ), "NSA backend doesn't support speculative decoding"
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
                    layer,
                    cache_loc,
                    k,
                    k_rope,
                )

        metadata = self.forward_metadata
        causal = not layer.is_cross_attention
        assert causal, "NSA is causal only"

        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}

        # Do absorbed multi-latent attention
        assert q_rope is not None
        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        # when store in fp8 and compute in fp8, no need to convert dtype
        if not (NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and NSA_KV_CACHE_STORE_FP8):
            kv_cache = kv_cache.to(q.dtype)
        if q_rope is not None:
            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
        else:
            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]

        # NOTE(dark): here, we use page size = 1
        if NSA_FUSE_TOPK:
            page_table_1 = topk_indices
        else:
            assert metadata.nsa_extend_seq_lens_list is not None
            page_table_1 = transform_index_page_table_prefill(
                page_table=metadata.page_table_1,
                topk_indices=topk_indices,
                extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
                page_size=1,
            )

        if NSA_PREFILL_IMPL == "tilelang":
            from sglang.srt.layers.attention.nsa.tilelang_kernel import (
                tilelang_sparse_fwd,
            )

            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_tilelang(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_PREFILL_IMPL == "flashmla_prefill":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_prefill(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_PREFILL_IMPL == "flashmla_decode":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_decode(
                q_all=q_all,
                kv_cache=kv_cache,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
                # TODO optimize args
                layer=layer,
                forward_batch=forward_batch,
                metadata=metadata,
                topk_indices=topk_indices,
                block_table=metadata.real_page_table,
            )
        elif NSA_PREFILL_IMPL == "fa3":
            return self._forward_fa3(
                q_rope=q_rope,
                kv_cache=kv_cache,
                v_head_dim=layer.v_head_dim,
                q_nope=q_nope,
                page_table=page_table_1,
                cache_seqlens=metadata.nsa_cache_seqlens_int32,
                cu_seqlens_q=metadata.nsa_cu_seqlens_q,
                cu_seqlens_k=metadata.nsa_cu_seqlens_k,
                max_seqlen_q=metadata.nsa_max_seqlen_q,
                sm_scale=layer.scaling,
                logit_cap=layer.logit_cap,
                page_size=1,
            )
        else:
            raise ValueError(f"Unsupported {NSA_PREFILL_IMPL=}")
    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
                    layer,
                    cache_loc,
                    k,
                    k_rope,
                )

        metadata = self.forward_metadata
        causal = not layer.is_cross_attention
        assert causal, "NSA is causal only"

        # Do absorbed multi-latent attention
        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        if q_rope is not None:
            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
        else:
            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]

        if NSA_FUSE_TOPK:
            page_table_1 = topk_indices
        else:
            page_table_1 = transform_index_page_table_decode(
                page_table=metadata.page_table_1,
                topk_indices=topk_indices,
                page_size=1,
            )

        if NSA_DECODE_IMPL == "flashmla_prefill":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_prefill(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_DECODE_IMPL == "flashmla_decode":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_decode(
                q_all=q_all,
                kv_cache=kv_cache,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
                # TODO optimize args
                layer=layer,
                forward_batch=forward_batch,
                metadata=metadata,
                topk_indices=topk_indices,
                block_table=metadata.real_page_table,
            )
        elif NSA_DECODE_IMPL == "tilelang":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_tilelang(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_DECODE_IMPL == "fa3":
            return self._forward_fa3(
                q_rope=q_rope,
                kv_cache=kv_cache,
                v_head_dim=layer.v_head_dim,
                q_nope=q_nope,
                page_table=page_table_1,
                cache_seqlens=metadata.nsa_cache_seqlens_int32,
                cu_seqlens_q=metadata.nsa_cu_seqlens_q,
                cu_seqlens_k=metadata.nsa_cu_seqlens_k,
                max_seqlen_q=metadata.nsa_max_seqlen_q,
                sm_scale=layer.scaling,
                logit_cap=layer.logit_cap,
                page_size=1,
            )
        else:
            assert False, f"Unsupported {NSA_DECODE_IMPL=}"
    def _forward_fa3(
        self,
        q_rope: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        q_nope: torch.Tensor,
        page_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        sm_scale: float,
        logit_cap: float,
        page_size: int,
    ) -> torch.Tensor:
        k_rope_cache = kv_cache[:, :, v_head_dim:]
        c_kv_cache = kv_cache[:, :, :v_head_dim]
        qk_rope_dim = k_rope_cache.shape[-1]
        k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
        c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
        o = flash_attn_with_kvcache(
            q=q_rope,
            k_cache=k_rope_cache,
            v_cache=c_kv_cache,
            qv=q_nope,
            page_table=page_table,
            cache_seqlens=cache_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k_new=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            softmax_scale=sm_scale,
            causal=True,
            softcap=logit_cap,
            return_softmax_lse=False,
            num_splits=self.num_splits,
        )
        return o  # type: ignore
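    # The slicing at the top of `_forward_fa3` reflects the MLA cache layout: each cached token
    # stores the compressed KV (v_head_dim values) followed by the rope part of the key, and the
    # FA3 call treats the rope slice as K, the compressed slice as V, and passes q_nope via `qv`.
    # A shape-only sketch; the 512/64 split is an assumption borrowed from DeepSeek MLA dims, not
    # read from this method:
    @staticmethod
    def _mla_cache_split_demo():
        num_slots, page_size, v_head_dim, rope_dim = 16, 1, 512, 64
        kv_cache = torch.randn(num_slots, page_size, v_head_dim + rope_dim)
        k_rope_cache = kv_cache[:, :, v_head_dim:].view(-1, page_size, 1, rope_dim)
        c_kv_cache = kv_cache[:, :, :v_head_dim].view(-1, page_size, 1, v_head_dim)
        return k_rope_cache.shape, c_kv_cache.shape  # ([16, 1, 1, 64], [16, 1, 1, 512])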
    def _forward_flashmla_prefill(
        self,
        q_all: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        page_table_1: torch.Tensor,
        sm_scale: float,
    ) -> torch.Tensor:
        from flash_mla import flash_mla_sparse_fwd

        o, _, _ = flash_mla_sparse_fwd(
            q=q_all,
            kv=kv_cache,
            indices=page_table_1.unsqueeze(1),
            sm_scale=sm_scale,
            d_v=v_head_dim,
        )
        return o
    def _forward_flashmla_decode(
        self,
        q_all: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        sm_scale: float,
        layer,
        forward_batch: ForwardBatch,
        metadata: NSAMetadata,
        topk_indices,
        block_table,
    ) -> torch.Tensor:
        from flash_mla import flash_mla_with_kvcache

        cache_seqlens = metadata.nsa_cache_seqlens_int32

        # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
        q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
        kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
        assert self.real_page_size == 64, "only page size 64 is supported"

        if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not NSA_KV_CACHE_STORE_FP8:
            # inefficiently quantize the whole cache
            kv_cache = quantize_k_cache(kv_cache)

        o, _ = flash_mla_with_kvcache(
            q=q_all,
            k_cache=kv_cache,
            cache_seqlens=cache_seqlens,
            head_dim_v=v_head_dim,
            tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
            num_splits=metadata.flashmla_metadata.num_splits,
            softmax_scale=sm_scale,
            # TODO improve
            indices=_compute_indices_in_kvcache(
                block_table=block_table,
                topk_indices=topk_indices.to(torch.int32),
                page_size=self.real_page_size,
                nsa_index_topk=self.nsa_index_topk,
            ),
            # doc says it is not used, but if pass in None then error
            block_table=torch.empty(
                (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
            ),
            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
        )
        # TODO shape correct?
        return o
    def _forward_tilelang(
        self,
        q_all: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        page_table_1: torch.Tensor,
        sm_scale: float,
    ) -> torch.Tensor:
        from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd

        return tilelang_sparse_fwd(
            q=q_all,
            kv=kv_cache,
            indices=page_table_1.unsqueeze(1),
            sm_scale=sm_scale,
            d_v=v_head_dim,
        )

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
        return 1

    def get_indexer_metadata(
        self, layer_id: int, forward_batch: ForwardBatch
    ) -> NSAIndexerMetadata:
        return NSAIndexerMetadata(attn_metadata=self.forward_metadata)

    def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
        from flash_mla import get_mla_metadata

        flashmla_metadata, num_splits = get_mla_metadata(
            cache_seqlens=cache_seqlens,
            # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
            # but the name looks like need seq_len_q?
            num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
            num_heads_k=1,
            num_heads_q=self.num_q_heads,
            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
            topk=self.nsa_index_topk,
        )
        return NSAFlashMLAMetadata(
            flashmla_metadata=flashmla_metadata,
            num_splits=num_splits,
        )
# TODO speedup
def _compute_indices_in_kvcache(block_table, topk_indices, page_size, nsa_index_topk):
    topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
    idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(1)
    block_idx = block_table[idx0, topk_indices_safe // page_size]
    offset = topk_indices_safe % page_size
    indices_in_kvcache = block_idx * page_size + offset

    # the kernel requires invalid entry to be -1
    assert indices_in_kvcache.shape == topk_indices.shape
    indices_in_kvcache[topk_indices == -1] = -1

    # return: (batch_size, seqlen_q_ori, topk)
    indices_in_kvcache = indices_in_kvcache[:, None, :]
    indices_in_kvcache = torch.nn.functional.pad(
        indices_in_kvcache,
        (0, nsa_index_topk - indices_in_kvcache.shape[-1]),
        "constant",
        -1,
    )
    assert indices_in_kvcache.shape[-1] == nsa_index_topk
    return indices_in_kvcache
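# A CPU sketch of the index math in `_compute_indices_in_kvcache`: a token index chosen by the
# top-k is mapped through the block table to its physical page and back to a flat KV-cache slot,
# while -1 entries are preserved for FlashMLA. A toy page size of 4 stands in for the backend's 64.
def _compute_indices_in_kvcache_demo():
    page_size = 4
    block_table = torch.tensor([[7, 3]])        # request 0 owns physical pages 7 and 3
    topk_indices = torch.tensor([[0, 5, -1]])   # logical token ids plus one padding slot
    safe = topk_indices.masked_fill(topk_indices == -1, 0)
    idx0 = torch.arange(block_table.size(0)).unsqueeze(1)
    slots = block_table[idx0, safe // page_size] * page_size + safe % page_size
    slots[topk_indices == -1] = -1
    return slots  # tensor([[28, 13, -1]])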