change / sglang · Commits · 7993ed8d

Commit 7993ed8d, authored Oct 03, 2025 by maxiao1
Adapt to DeepSeek V3.2 (适配deepseekv3.2)
parent 443a1b4a

Showing 8 changed files with 403 additions and 79 deletions (+403, -79)
python/sglang/srt/layers/attention/native_mla.py            +121   -0
python/sglang/srt/layers/attention/nsa/fallback_fp8.py      +135   -0
python/sglang/srt/layers/attention/nsa/nsa_indexer.py        +51  -24
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py    +51  -12
python/sglang/srt/layers/attention/nsa/transform_index.py     +1   -1
python/sglang/srt/layers/attention/nsa_backend.py            +25  -26
python/sglang/srt/layers/layernorm.py                        +15  -15
python/sglang/srt/layers/rotary_embedding.py                  +4   -1
python/sglang/srt/layers/attention/native_mla.py  (new file, 0 → 100644)

import math
from typing import Optional, Tuple, List

import torch


def cdiv(x: int, y: int):
    return (x + y - 1) // y


def native_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    s_q, _, d_qk = q.size()
    s_kv = kv.size(0)
    topk = indices.size(-1)

    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    indices = indices[:, 0, :]  # [s_q, topk]
    invalid_indices_mask = (indices < 0) | (indices >= s_kv)
    qs = q.float()  # [s_q, h_q, d_qk]
    kvs = kv[:, 0, :].float()  # [s_kv, d_qk]
    kvs = torch.index_select(
        kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
    ).view(s_q, topk, d_qk)  # [s_q, topk, d_qk]
    attn_score = qs @ kvs.transpose(1, 2)  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf'))
    attn_score *= sm_scale * math.log2(math.e)
    max_logits = torch.max(attn_score, dim=-1)[0]  # [s_q, h_q]
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    result = attn_score @ kvs[:, :, :d_v]
    return (max_logits, lse, result)


def native_mla_with_kvcache(
    q: torch.Tensor,  # [batch_size, s_q, h_q, d]
    blocked_k: torch.Tensor,  # [?, block_size, h_kv, d]
    block_table: torch.Tensor,  # [batch_size, ?]
    cache_seqlens: torch.Tensor,  # [batch_size]
    dv: int,
    is_causal: bool,
    indices: Optional[torch.Tensor] = None,  # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation in PyTorch
    """

    def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
        mask = torch.zeros(s_q, s_k, dtype=torch.bool)
        for i in range(s_q):
            cur_indices = indices[i]
            valid_indices = cur_indices[cur_indices != -1]
            mask[i, valid_indices] = True
        return mask

    def scaled_dot_product_attention(
        batch_idx: int,
        query: torch.Tensor,  # [h_q, s_q, d]
        kv: torch.Tensor,  # [h_kv, s_k, d]
        dv: int,
        is_causal,
        indices: Optional[torch.Tensor],  # [s_q, topk]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        h_q = query.size(0)
        h_kv = kv.size(0)
        s_q = query.shape[-2]
        s_k = kv.shape[-2]
        query = query.float()
        kv = kv.float()
        if h_kv != 1:
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        kv[kv != kv] = 0.0
        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]
        if (is_causal and query.size(1) > 1) or indices is not None:
            mask = torch.ones(s_q, s_k, dtype=torch.bool)
            if is_causal:
                assert indices is None
                mask = mask.tril(diagonal=s_k - s_q)
            if indices is not None:
                mask &= get_topk_attn_mask(s_q, s_k, indices)
            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
            attn_weight += attn_bias.to(q.dtype)
        attn_weight /= math.sqrt(query.size(-1))
        lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        output = attn_weight @ kv[..., :dv]  # [h_q, s_q, dv]

        # Correct for q tokens which have no attendable k
        lonely_q_mask = (lse == float("-inf"))
        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
        lse[lonely_q_mask] = float("+inf")
        return output, lse

    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)
    cache_seqlens_cpu = cache_seqlens.cpu()
    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)
    for i in range(b):
        cur_len = cache_seqlens_cpu[i].item()
        cur_num_blocks = cdiv(cur_len, block_size)
        cur_block_indices = block_table[i][0:cur_num_blocks]
        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
        cur_out, cur_lse = scaled_dot_product_attention(
            i,
            q[i].transpose(0, 1),
            cur_kv.transpose(0, 1),
            dv,
            is_causal,
            indices[i] if indices is not None else None,
        )
        out_ref[i] = cur_out.transpose(0, 1)
        lse_ref[i] = cur_lse
    out_ref = out_ref.to(torch.bfloat16)
    return out_ref, lse_ref
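
The two functions above are pure-PyTorch references, so they can be exercised directly. The snippet below is illustrative only (not part of the commit); the shapes and sm_scale value are assumptions chosen to mirror a typical MLA configuration.

# Illustrative only (not part of the commit): exercise the sparse reference
# path with small, assumed shapes.
import torch
from sglang.srt.layers.attention.native_mla import native_mla_sparse_fwd

s_q, h_q, d_qk, d_v, s_kv, topk = 4, 8, 576, 512, 64, 16
q = torch.randn(s_q, h_q, d_qk)
kv = torch.randn(s_kv, 1, d_qk)                    # one shared latent KV head (MLA)
indices = torch.randint(0, s_kv, (s_q, 1, topk))   # per-token top-k KV slots
max_logits, lse, out = native_mla_sparse_fwd(
    q, kv, indices, sm_scale=d_qk**-0.5, d_v=d_v
)
print(max_logits.shape, lse.shape, out.shape)      # [s_q, h_q], [s_q, h_q], [s_q, h_q, d_v]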
python/sglang/srt/layers/attention/nsa/fallback_fp8.py  (new file, 0 → 100644)

# fallback_fp8.py
# PyTorch fallback implementation for DeepGEMM-like fp8 logits ops
from sglang.srt.utils import ceil_div
import torch


@torch.no_grad()
def fallback_fp8_mqa_logits(
    q: torch.Tensor,
    kv: torch.Tensor,
    weights: torch.Tensor,
    ks: torch.Tensor,
    ke: torch.Tensor,
    cost_only: bool = False,
) -> torch.Tensor:
    """
    PyTorch fallback for fp8_mqa_logits. No real fp8 is used, just FP32.

    Args:
        q: (M, H, D) query
        kv: (N, D) key
        weights: (M, H)
        ks: (M,) int32
        ke: (M,) int32

    Returns:
        logits: (M, N) with -inf outside of the valid [ks, ke) range
    """
    seq_len_kv = kv.shape[0]

    if cost_only:
        start = ks.clamp(min=0, max=seq_len_kv)
        end = ke.clamp(min=0, max=seq_len_kv)
        count_ones_per_row = (end - start).clamp(min=0)
        return count_ones_per_row.sum()

    k = kv
    q = q.float()
    k = k.float()

    mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= ks[:, None]
    mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < ke[:, None]
    mask = mask_lo & mask_hi

    score = torch.einsum('mhd,nd->hmn', q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float('-inf'))
    # cost = mask.sum()
    return logits
    # Alternative per-row implementation (unused):
    # M, H, D = q.shape
    # N = k[0].shape[0]
    # logits = torch.full((M, N), float("-inf"), dtype=torch.float32, device=q.device)
    # # for i in range(M):
    # #     start = max(ks[i].item(), 0)
    # #     end = min(ke[i].item(), N)
    # #     if start >= end:
    # #         continue
    # #     qi = q[i]  # (H, D)
    # #     ki = k[start:end]  # (L, D)
    # #     sim = torch.matmul(qi, ki.T)  # (H, L)
    # #     weighted_sim = (sim.relu() * weights[i].unsqueeze(-1)).sum(dim=0)  # (L,)
    # #     logits[i, start:end] = weighted_sim
    # return logits


@torch.no_grad()
def fallback_fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """
    PyTorch fallback for fp8_paged_mqa_logits. No real fp8 is used, just FP32.

    Args:
        q: (B, N, H, D)
        kv_cache: (num_blocks, block_size, 1, D)
        weights: (B * N, H)
        context_lens: (B,)
        block_tables: (B, max_blocks)
        max_model_len: int

    Returns:
        logits: (B * N, max_model_len)
    """
    batch_size, next_n, heads, dim = q.size()
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float('-inf'),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(ceil_div(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device=q.device
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float('-inf'),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(
                k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')
            )
    return logits
    # Unreachable alternative loop implementation (dead code after the return above):
    B, N, H, D = q.shape
    block_size = kv_cache.shape[1]
    logits = torch.full(
        (B * N, max_model_len), float("-inf"), dtype=torch.float32, device=q.device
    )
    for i in range(B):
        ctx_len = context_lens[i].item()
        q_offsets = torch.arange(ctx_len - N, ctx_len, device=q.device)
        weight_slice = weights[i * N : (i + 1) * N, :].transpose(0, 1).contiguous()
        for br in range((ctx_len + block_size - 1) // block_size):
            blk_idx = block_tables[i, br].item()
            if blk_idx < 0:
                continue
            qx = q[i]  # (N, H, D)
            kx = kv_cache[blk_idx]  # (block_size, 1, D)
            kx = kx.squeeze(1)  # (block_size, D)
            k_offsets = torch.arange(
                br * block_size, (br + 1) * block_size, device=q.device
            )
            mask = (k_offsets[None, :] < ctx_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )  # (N, block_size)
            s = torch.where(
                mask[None, :, :],
                torch.einsum('nhd,ld->hnl', qx, kx),
                torch.full((H, N, block_size), float("-inf"), device=q.device),
            )
            s = s.relu() * weight_slice[..., None]
            logits_slice = s.sum(dim=0)  # (N, block_size)
            mask_block = k_offsets[None, :] <= q_offsets[:, None]
            logits[
                i * N : (i + 1) * N, br * block_size : (br + 1) * block_size
            ] = torch.where(mask_block, logits_slice, float("-inf"))
    return logits
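
Both fallbacks are dense FP32 replacements for the DeepGEMM FP8 logits kernels. The snippet below is illustrative only (not part of the commit), with assumed shapes; it places everything on CUDA because the dense fallback builds its range masks on 'cuda' internally.

# Illustrative only (not part of the commit): call the dense fallback with
# assumed shapes on a CUDA device.
import torch
from sglang.srt.layers.attention.nsa.fallback_fp8 import fallback_fp8_mqa_logits

M, H, D, N = 4, 16, 128, 32          # queries, heads, head dim, KV length
q = torch.randn(M, H, D, device="cuda")
kv = torch.randn(N, D, device="cuda")
weights = torch.rand(M, H, device="cuda")
ks = torch.zeros(M, dtype=torch.int32, device="cuda")        # window start per query
ke = torch.full((M,), N, dtype=torch.int32, device="cuda")   # window end per query
logits = fallback_fp8_mqa_logits(q, kv, weights, ks, ke)     # (M, N), -inf outside [ks, ke)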
python/sglang/srt/layers/attention/nsa/nsa_indexer.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from sglang.srt.layers.attention.nsa.fallback_fp8 import fallback_fp8_mqa_logits, fallback_fp8_paged_mqa_logits
 import torch
 import torch.nn.functional as F
 from einops import rearrange

@@ -14,7 +15,7 @@ from sglang.srt.utils import add_prefix, is_npu
 if not is_npu():
     from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
-    import deep_gemm
+    # import deep_gemm

 from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
 from sglang.srt.layers.dp_attention import get_attention_tp_group

@@ -27,14 +28,14 @@ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import add_prefix, align, is_cuda

-try:
-    import deep_gemm_v32
-except ImportError as e:
-    print("Error when importing deep_gemm_v32, try deep_gemm")
-    try:
-        import deep_gemm as deep_gemm_v32
-    except ImportError as e:
-        print("Error when importing deep_gemm, skip")
+# try:
+#     import deep_gemm_v32
+# except ImportError as e:
+#     print("Error when importing deep_gemm_v32, try deep_gemm")
+#     try:
+#         import deep_gemm as deep_gemm_v32
+#     except ImportError as e:
+#         print("Error when importing deep_gemm, skip")

 if TYPE_CHECKING:

@@ -81,16 +82,47 @@ class BaseIndexerMetadata(ABC):
         Don't assume it is the topk indices of the input logits.
         """

+def hadamard_transform_pytorch(x: torch.Tensor, scale: float) -> torch.Tensor:
+    """
+    A native PyTorch implementation of the Fast Hadamard Transform that mimics
+    the behavior of the custom CUDA kernel's call signature.
+
+    Args:
+        x (torch.Tensor): Input tensor of shape (*, N), where N is a power of 2.
+        scale (float): The normalization factor to multiply the result by.
+
+    Returns:
+        torch.Tensor: The Hadamard transformed tensor.
+    """
+    # Base case for recursion
+    if x.shape[-1] == 1:
+        return x
+
+    # Split the tensor into two halves
+    half_size = x.shape[-1] // 2
+    a = x[..., :half_size]
+    b = x[..., half_size:]
+
+    # Recursive calls
+    a_transformed = hadamard_transform_pytorch(a, scale=1.0)  # No scaling in intermediate steps
+    b_transformed = hadamard_transform_pytorch(b, scale=1.0)  # No scaling in intermediate steps
+
+    # Combine the results
+    combined = torch.cat(
+        [a_transformed + b_transformed, a_transformed - b_transformed], dim=-1
+    )
+
+    # Apply the scale only at the final step
+    return combined * scale
+
 def rotate_activation(x: torch.Tensor) -> torch.Tensor:
     assert x.dtype == torch.bfloat16
-    from fast_hadamard_transform import hadamard_transform
+    # from fast_hadamard_transform import hadamard_transform

     hidden_size = x.size(-1)
     assert (
         hidden_size & (hidden_size - 1)
     ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
-    return hadamard_transform(x, scale=hidden_size**-0.5)
+    return hadamard_transform_pytorch(x, scale=hidden_size**-0.5)

 class V32LayerNorm(nn.Module):

@@ -140,7 +172,7 @@ class Indexer(CustomOp):
         self.layer_id = layer_id
         self.alt_stream = alt_stream
         if not is_npu():
-            self.sm_count = deep_gemm.get_num_sms()
+            self.sm_count = torch.cuda.get_device_properties(0).multi_processor_count
             self.half_device_sm_count = align(self.sm_count // 2, 8)

         self.wq_b = ReplicatedLinear(

@@ -273,9 +305,7 @@
         k_rope, _ = torch.split(
             key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
         )
         q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
         query[..., : self.rope_head_dim] = q_rope
         key[..., : self.rope_head_dim] = k_rope

@@ -323,9 +353,9 @@
             blocksize = page_size
             seqlens_32 = metadata.get_seqlens_int32()
             # NOTE(dark): 132 is SM count on H200/B200, not magic number
-            schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
-                seqlens_32, blocksize, self.sm_count
-            )
+            # schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
+            #     seqlens_32, blocksize, self.sm_count
+            # )
             assert len(q_fp8.shape) == 3
             q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now

@@ -339,15 +369,13 @@
             assert len(weights.shape) == 3
             weights = weights.squeeze(2)
-            logits = deep_gemm_v32.fp8_paged_mqa_logits(
+            logits = fallback_fp8_paged_mqa_logits(
                 q_fp8,
                 kv_cache_fp8,
                 weights,
                 seqlens_32,
                 block_tables,
-                schedule_metadata,
                 max_seq_len,
-                clean_logits=False,
             )
             # NOTE(dark): logits should be cleaned in topk_transform

@@ -408,13 +436,12 @@
         seq_lens_expanded = metadata.get_seqlens_expanded()
         ke = ks + seq_lens_expanded
-        logits = deep_gemm_v32.fp8_mqa_logits(
+        logits = fallback_fp8_mqa_logits(
             q_fp8,
-            kv_fp8,
+            k_fp8,
             weights,
             ks,
-            ke,
-            clean_logits=False,
+            ke
         )
         assert logits.shape[0] == len(seq_lens_expanded)
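
The added hadamard_transform_pytorch replaces the fast_hadamard_transform CUDA kernel inside rotate_activation. The check below is illustrative only (not part of the commit) and assumes the module imports cleanly in your environment; it compares the recursive fallback against an explicitly built Sylvester Hadamard matrix.

# Illustrative only (not part of the commit): verify the recursive fallback
# against an explicit Hadamard matrix for a small power-of-two size.
import torch
from sglang.srt.layers.attention.nsa.nsa_indexer import hadamard_transform_pytorch

def hadamard_matrix(n: int) -> torch.Tensor:
    # Sylvester construction: H_{2n} = [[H_n, H_n], [H_n, -H_n]]
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

n = 128
x = torch.randn(4, n)
ref = x @ hadamard_matrix(n).t() * n ** -0.5
out = hadamard_transform_pytorch(x, scale=n ** -0.5)
print(torch.allclose(out, ref, atol=1e-4))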
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py

 from typing import Optional, Tuple

-import tilelang
-import tilelang.language as T
+# import tilelang
+# import tilelang.language as T
 import torch

-tilelang.set_log_level("WARNING")
+# tilelang.set_log_level("WARNING")

-pass_configs = {
-    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
-}
+# pass_configs = {
+#     tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+#     tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+#     tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
+# }

 BF16 = "bfloat16"
 FP8 = "float8_e4m3"
 FP32 = "float32"
+'''
 def fast_log2_ceil(x):
     bits_x = T.reinterpret("uint32", x)
     exp_x = (bits_x >> 23) & 0xFF

@@ -32,7 +32,6 @@ def fast_pow2(x):
 def fast_round_scale(amax, fp8_max_inv):
     return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))

-@tilelang.jit(pass_configs=pass_configs)
 def act_quant_kernel(
     N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False

@@ -83,7 +82,6 @@ def act_quant_kernel(
     return act_quant_kernel_

 def act_quant(
     x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -753,7 +751,6 @@ def sparse_attention_fwd_kernel_v2(
     return main

 def tilelang_sparse_fwd(
     q: torch.Tensor,
     kv: torch.Tensor,

@@ -772,3 +769,45 @@ def tilelang_sparse_fwd(
         num_heads, d_v, tail_dim, topk, sm_scale=sm_scale
     )
     return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0))  # type: ignore
+'''
+
+def act_quant(
+    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    PyTorch fallback for act_quant.
+    Block-wise FP8 E4M3 quantization.
+    """
+    if not x.is_contiguous():
+        x = x.contiguous()
+
+    N = x.size(-1)
+    assert N % block_size == 0, f"Last dim {N} must be divisible by block_size={block_size}"
+
+    # Reshape to blocks
+    x_2d = x.view(-1, N)
+    x_blocks = x_2d.view(-1, block_size)
+
+    # Compute absmax per block
+    amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-4)
+
+    # FP8 E4M3 max value is ~448
+    fp8_max = 448.0
+    scale = amax / fp8_max
+
+    if scale_fmt is not None:
+        # Simulate rounded scale (power-of-2 rounding)
+        scale = torch.round(scale * 256) / 256
+
+    # Quantize and clamp
+    y_blocks = torch.clamp(torch.round(x_blocks / scale), -fp8_max, fp8_max)
+
+    # Convert to FP8
+    q = y_blocks.view_as(x_2d).to(torch.float8_e4m3fn)
+
+    # Reshape scale
+    s = scale.view(x_2d.size(0), N // block_size).to(torch.float32)
+    s = s.view(*x.shape[:-1], N // block_size)
+
+    return q.view_as(x), s
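
The new act_quant is a plain PyTorch stand-in for the tilelang kernel. Below is an illustrative round-trip sketch (not part of the commit, shapes assumed) that quantizes a bfloat16 activation and dequantizes it to gauge the block-wise error.

# Illustrative only (not part of the commit): round-trip the block-wise FP8
# fallback and measure the reconstruction error.
import torch
from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant

x = torch.randn(4, 7168, dtype=torch.bfloat16)
q, s = act_quant(x, block_size=128)
print(q.dtype, q.shape, s.shape)           # float8_e4m3fn, (4, 7168), (4, 56)
# Dequantize: expand each per-block scale back over its 128 columns.
deq = q.float().view(4, 56, 128) * s.unsqueeze(-1)
print((deq.view_as(x) - x.float()).abs().max())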
python/sglang/srt/layers/attention/nsa/transform_index.py

@@ -105,7 +105,7 @@ def transform_index_page_table_decode_ref(
     torch.gather(
         page_table,
         dim=1,
-        index=topk_indices.clamp(min=0),
+        index=topk_indices.clamp(min=0).long(),
         out=result,
     )
     result[topk_indices < 0] = -1
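
For context on this one-line change (not part of the commit): torch.gather requires an int64 index tensor, which is what the added .long() cast provides when the top-k indices arrive as int32. A minimal sketch with assumed shapes:

# Illustrative only: torch.gather expects int64 indices, hence the .long() cast.
import torch

page_table = torch.arange(16).view(2, 8)
topk_indices = torch.tensor([[0, 3, -1], [5, -1, -1]], dtype=torch.int32)
result = torch.empty(2, 3, dtype=page_table.dtype)
torch.gather(page_table, dim=1, index=topk_indices.clamp(min=0).long(), out=result)
result[topk_indices < 0] = -1
print(result)   # gathered page ids, -1 where the top-k slot is unused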
python/sglang/srt/layers/attention/nsa_backend.py

@@ -10,7 +10,6 @@ from typing import (
     Tuple,
     TypeAlias,
     Union,
-    override,
 )

 import torch

@@ -101,19 +100,15 @@ class NSAMetadata:
 class NSAIndexerMetadata(BaseIndexerMetadata):
     attn_metadata: NSAMetadata

-    @override
     def get_seqlens_int32(self) -> torch.Tensor:
         return self.attn_metadata.cache_seqlens_int32

-    @override
     def get_page_table_64(self) -> torch.Tensor:
         return self.attn_metadata.real_page_table

-    @override
     def get_seqlens_expanded(self) -> torch.Tensor:
         return self.attn_metadata.nsa_seqlens_expanded

-    @override
     def topk_transform(
         self,
         logits: torch.Tensor,

@@ -524,21 +519,25 @@ class NativeSparseAttnBackend(AttentionBackend):
                 extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
                 page_size=1,
             )
-            if NSA_PREFILL_IMPL == "tilelang":
-                from sglang.srt.layers.attention.nsa.tilelang_kernel import (
-                    tilelang_sparse_fwd,
-                )
-                if q_rope is not None:
-                    q_all = torch.cat([q_nope, q_rope], dim=-1)
-                return self._forward_tilelang(
-                    q_all=q_all,
-                    kv_cache=kv_cache,
-                    page_table_1=page_table_1,
-                    sm_scale=layer.scaling,
-                    v_head_dim=layer.v_head_dim,
-                )
-            elif NSA_PREFILL_IMPL == "flashmla_prefill":
+            # if NSA_PREFILL_IMPL == "tilelang":
+            #     from sglang.srt.layers.attention.nsa.tilelang_kernel import (
+            #         tilelang_sparse_fwd,
+            #     )
+            #     if q_rope is not None:
+            #         q_all = torch.cat([q_nope, q_rope], dim=-1)
+            #     return self._forward_tilelang(
+            #         q_all=q_all,
+            #         kv_cache=kv_cache,
+            #         page_table_1=page_table_1,
+            #         sm_scale=layer.scaling,
+            #         v_head_dim=layer.v_head_dim,
+            #     )
+            # elif NSA_PREFILL_IMPL == "flashmla_prefill":
+            # Skip tilelang dependencies
+            if NSA_PREFILL_IMPL == "tilelang" or NSA_PREFILL_IMPL == "flashmla_prefill":
                 if q_rope is not None:
                     q_all = torch.cat([q_nope, q_rope], dim=-1)
                 return self._forward_flashmla_prefill(

@@ -733,9 +732,9 @@ class NativeSparseAttnBackend(AttentionBackend):
         page_table_1: torch.Tensor,
         sm_scale: float,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_sparse_fwd
-
-        o, _, _ = flash_mla_sparse_fwd(
+        # from flash_mla import flash_mla_sparse_fwd
+        from sglang.srt.layers.attention.native_mla import native_mla_sparse_fwd
+
+        _, _, o = native_mla_sparse_fwd(
             q=q_all,
             kv=kv_cache,
             indices=page_table_1.unsqueeze(1),

@@ -756,8 +755,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         topk_indices,
         block_table,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_with_kvcache
+        # from flash_mla import flash_mla_with_kvcache
+        from sglang.srt.layers.attention.native_mla import native_mla_with_kvcache

         cache_seqlens = metadata.nsa_cache_seqlens_int32
         # TODO the 2nd dim is seq_len_q, need to be >1 when MTP

@@ -769,7 +768,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             # inefficiently quantize the whole cache
             kv_cache = quantize_k_cache(kv_cache)

-        o, _ = flash_mla_with_kvcache(
+        o, _ = native_mla_with_kvcache(
             q=q_all,
             k_cache=kv_cache,
             cache_seqlens=cache_seqlens,
python/sglang/srt/layers/layernorm.py

@@ -136,21 +136,21 @@ class RMSNorm(CustomOp):
             # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
-            if _vllm_version < Version("0.9"):
-                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-                return x, residual
-            else:
-                residual_out = torch.empty_like(x)
-                output = torch.empty_like(x)
-                fused_add_rms_norm(
-                    output,
-                    x,
-                    residual_out,
-                    residual,
-                    self.weight.data,
-                    self.variance_epsilon,
-                )
-                return output, residual_out
+            # if _vllm_version < Version("0.9"):
+            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+            # else:
+            #     residual_out = torch.empty_like(x)
+            #     output = torch.empty_like(x)
+            #     fused_add_rms_norm(
+            #         output,
+            #         x,
+            #         residual_out,
+            #         residual,
+            #         self.weight.data,
+            #         self.variance_epsilon,
+            #     )
+            #     return output, residual_out
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out
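
This diff routes every residual case through the pre-0.9-style fused_add_rms_norm(x, residual, weight, eps) call. As a point of reference, the snippet below is an unfused sketch of what such a fused add + RMSNorm kernel typically computes (an assumption for illustration, not taken from this diff): the residual path accumulates x + residual and the output is the RMS-normalized sum scaled by weight.

# Illustrative only: unfused reference for the usual add+RMSNorm semantics
# (assumed contract: output = RMSNorm(x + residual) * weight,
#  new residual = x + residual).
import torch

def add_rms_norm_ref(x, residual, weight, eps):
    hidden = (x + residual).float()
    variance = hidden.pow(2).mean(dim=-1, keepdim=True)
    normed = hidden * torch.rsqrt(variance + eps)
    return (normed * weight.float()).to(x.dtype), hidden.to(x.dtype)

x = torch.randn(2, 8, dtype=torch.bfloat16)
residual = torch.randn(2, 8, dtype=torch.bfloat16)
weight = torch.ones(8, dtype=torch.bfloat16)
out, new_residual = add_rms_norm_ref(x, residual, weight, eps=1e-6)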
python/sglang/srt/layers/rotary_embedding.py

@@ -765,7 +765,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
         query_rot = query_rot * cos + rotate_fn(query_rot) * sin
-        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+        cos_for_key = cos[:, 0, ...]
+        sin_for_key = sin[:, 0, ...]
+        key_rot = key_rot * cos_for_key + rotate_fn(key_rot) * sin_for_key
+        # key_rot = key_rot * cos + rotate_fn(key_rot) * sin

         if self.rotary_dim < self.head_size:
             query = torch.cat((query_rot, query_pass), dim=-1)
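
For reference (a sketch, not the file's exact helpers): with is_neox_style, rotate_fn splits the rotary dimensions into two halves and rotates them, so x * cos + rotate_fn(x) * sin is the standard RoPE update that both the query line and the new key lines apply; the added [:, 0, ...] slices only change which cos/sin view is broadcast against the key.

# Illustrative sketch of a neox-style rotate helper (assumed, for reference).
import torch

def rotate_neox_ref(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)          # two halves of the rotary dims
    return torch.cat((-x2, x1), dim=-1)

rotary_dim = 64
x = torch.randn(4, rotary_dim)
cos = torch.randn(4, rotary_dim)         # stand-ins for the repeated cos/sin tables
sin = torch.randn(4, rotary_dim)
x_rotated = x * cos + rotate_neox_ref(x) * sin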