Commit 451d15c4 (unverified)
Authored Oct 10, 2025 by Binyao Jiang; committed via GitHub on Oct 10, 2025

[DPSKv3.2] Rewrite nsa tilelang act_quant kernel to triton (#11450)
Parent: c80a96da

Showing 3 changed files with 420 additions and 1 deletion:

  python/sglang/srt/layers/attention/nsa/nsa_indexer.py          +3    -1
  python/sglang/srt/layers/attention/nsa/triton_kernel.py        +136  -0
  test/srt/layers/attention/nsa/test_act_quant_triton.py         +281  -0
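Context for the change: the commit replaces the TileLang `act_quant` activation-quantization kernel with a Triton implementation for CUDA builds. The new `triton_kernel.act_quant` added below performs block-wise FP8 (e4m3) quantization: each row is split into groups of `block_size` columns, each group gets a scale of `amax / 448` (or a power-of-two scale when `scale_fmt` is set), and the scaled values are clamped to ±448 and stored as `torch.float8_e4m3fn`. For orientation, here is a minimal pure-PyTorch sketch of the default-path math; the helper name and shapes are illustrative, not part of the commit, and it assumes a PyTorch build with float8 support:

import torch

def act_quant_reference(x: torch.Tensor, block_size: int = 128):
    """Pure-PyTorch sketch of block-wise FP8 quantization (illustrative only)."""
    fp8_max = 448.0  # max representable magnitude for float8_e4m3fn
    M, N = x.shape
    # Group columns: (M, N // block_size, block_size)
    g = x.float().view(M, N // block_size, block_size)
    # Per-group scale = amax / fp8_max, with amax clamped away from zero
    scale = g.abs().amax(dim=-1).clamp_min(1e-4) / fp8_max  # (M, N // block_size)
    y = (g / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    return y.view(M, N).to(torch.float8_e4m3fn), scale

# Example: quantize a small bf16 activation tensor
x = torch.randn(4, 256, dtype=torch.bfloat16)
y, s = act_quant_reference(x)
print(y.dtype, s.shape)  # torch.float8_e4m3fn, torch.Size([4, 2])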
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -505,8 +505,10 @@ class Indexer(CustomOp):
         forward_batch: ForwardBatch,
         layer_id: int,
     ) -> Optional[torch.Tensor]:
-        if not is_npu():
+        if is_hip():
             from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
+        elif not is_npu():
+            from sglang.srt.layers.attention.nsa.triton_kernel import act_quant
         if TYPE_CHECKING:
             assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
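The net effect of this hunk is backend dispatch: ROCm (HIP) builds keep the TileLang `act_quant`, while other non-NPU builds (i.e. CUDA) now import the new Triton kernel. The resulting selection logic, shown in isolation for clarity with the surrounding indexer code trimmed:

# Import selection after this change (context trimmed for clarity)
if is_hip():
    # ROCm builds keep the TileLang act_quant kernel
    from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
elif not is_npu():
    # CUDA (non-HIP, non-NPU) builds use the new Triton act_quant kernel
    from sglang.srt.layers.attention.nsa.triton_kernel import act_quant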
python/sglang/srt/layers/attention/nsa/triton_kernel.py (new file, mode 100644)
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl


# Triton implementation
@triton.jit
def _act_quant_kernel(
    X_ptr,
    Y_ptr,
    S_ptr,
    M,
    N,
    group_size: tl.constexpr,
    round_scale: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Triton kernel for activation quantization.
    Each block processes BLOCK_M rows and group_size columns.
    """
    # Get block IDs
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # FP8 constants
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1.0 / fp8_max

    # Calculate row and column offsets
    row_start = pid_m * BLOCK_M
    col_start = pid_n * group_size

    # Create offset arrays
    rows = row_start + tl.arange(0, BLOCK_M)
    cols = col_start + tl.arange(0, BLOCK_N)

    # Mask for valid rows and columns
    row_mask = rows < M
    col_mask = cols < N
    mask = row_mask[:, None] & col_mask[None, :]

    # Load input data
    x_ptrs = X_ptr + rows[:, None] * N + cols[None, :]
    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)

    # Compute absolute max along columns (group_size dimension) for each row
    x_abs = tl.abs(x)
    amax = tl.max(x_abs, axis=1)  # Shape: (BLOCK_M,)

    # Clamp amax to avoid division by zero
    amax = tl.maximum(amax, 1e-4)

    # Compute scale
    if round_scale:
        # Fast round scale using bit manipulation approximation
        # This is a simplified version - the exact bit manipulation is harder in Triton
        # Using log2 + ceil + pow2 as approximation
        log_val = tl.log2(amax * fp8_max_inv)
        log_ceil = tl.ceil(log_val)
        scale = tl.exp2(log_ceil)
    else:
        scale = amax * fp8_max_inv

    # Quantize: y = clamp(x / scale, fp8_min, fp8_max)
    scale_broadcast = scale[:, None]
    y = x / scale_broadcast
    y = tl.minimum(tl.maximum(y, fp8_min), fp8_max)

    # Store quantized output
    y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :]
    tl.store(y_ptrs, y, mask=mask)

    # Store scales
    s_cols = pid_n
    s_ptrs = S_ptr + rows * (N // group_size) + s_cols
    s_mask = row_mask
    tl.store(s_ptrs, scale, mask=s_mask)


def act_quant(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization with Triton.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
        scale_fmt (Optional[str], optional): The format of the scale. Default is None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert (
        x.size(-1) % block_size == 0
    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"

    # Flatten all dims except last
    N = x.size(-1)
    x_flat = x.view(-1, N)
    M = x_flat.size(0)

    # Allocate output tensors
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    y_flat = y.view(-1, N)
    s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
    s_flat = s.view(-1, N // block_size)

    # Launch kernel
    BLOCK_M = 32
    BLOCK_N = block_size
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size))

    round_scale = scale_fmt is not None
    _act_quant_kernel[grid](
        x_flat,
        y_flat,
        s_flat,
        M,
        N,
        group_size=block_size,
        round_scale=round_scale,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_stages=0 if round_scale else 2,
    )
    return y, s
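The Triton `act_quant` keeps the same call signature as the TileLang version it replaces, so callers only need to swap the import. A minimal round-trip sketch (illustrative only; assumes a CUDA device and a PyTorch build with float8 support):

import torch
from sglang.srt.layers.attention.nsa.triton_kernel import act_quant

block_size = 128
x = torch.randn(64, 1024, dtype=torch.bfloat16, device="cuda")

y, s = act_quant(x, block_size=block_size)  # y: float8_e4m3fn, s: float32 scales
assert y.shape == x.shape and s.shape == (64, 1024 // block_size)

# Dequantize group-wise to gauge the quantization error (illustrative check only)
x_hat = (y.float().view(64, -1, block_size) * s.unsqueeze(-1)).view(64, 1024)
print((x.float() - x_hat).abs().max())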
test/srt/layers/attention/nsa/test_act_quant_triton.py (new file, mode 100644)
"""
Unit tests comparing TileLang and Triton implementations of activation quantization.
Tests both accuracy and performance.
"""
import
time
from
typing
import
Tuple
import
pytest
import
torch
from
sglang.srt.layers.attention.nsa.tilelang_kernel
import
act_quant
from
sglang.srt.layers.attention.nsa.triton_kernel
import
act_quant
as
act_quant_triton
def
benchmark_kernel
(
fn
,
x
:
torch
.
Tensor
,
block_size
:
int
,
scale_fmt
,
warmup
:
int
=
10
,
repeat
:
int
=
100
,
use_cuda_graph
:
bool
=
True
,
)
->
Tuple
[
float
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Benchmark a kernel function.
Args:
fn: Function to benchmark
x: Input tensor
block_size: Block size for quantization
scale_fmt: Scale format
warmup: Number of warmup iterations
repeat: Number of repeat iterations
use_cuda_graph: Whether to use CUDA graphs for more accurate timing
Returns:
Tuple of (avg_time_ms, quantized_output, scales)
"""
# Warmup
for
_
in
range
(
warmup
):
y
,
s
=
fn
(
x
,
block_size
=
block_size
,
scale_fmt
=
scale_fmt
)
if
not
x
.
is_cuda
or
not
use_cuda_graph
:
# Fallback to regular timing
if
x
.
is_cuda
:
torch
.
cuda
.
synchronize
()
start
=
time
.
perf_counter
()
for
_
in
range
(
repeat
):
y
,
s
=
fn
(
x
,
block_size
=
block_size
,
scale_fmt
=
scale_fmt
)
if
x
.
is_cuda
:
torch
.
cuda
.
synchronize
()
end
=
time
.
perf_counter
()
avg_time_ms
=
(
end
-
start
)
/
repeat
*
1000
return
avg_time_ms
,
y
,
s
# Use CUDA graph for more accurate timing
torch
.
cuda
.
synchronize
()
# Allocate output buffers
N
=
x
.
size
(
-
1
)
y
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
float8_e4m3fn
)
s
=
x
.
new_empty
(
*
x
.
size
()[:
-
1
],
N
//
block_size
,
dtype
=
torch
.
float32
)
# Capture CUDA graph
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
y_cap
,
s_cap
=
fn
(
x
,
block_size
=
block_size
,
scale_fmt
=
scale_fmt
)
# Warmup with graph
for
_
in
range
(
warmup
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
# Timing with CUDA graph
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
.
record
()
for
_
in
range
(
repeat
):
graph
.
replay
()
end_event
.
record
()
torch
.
cuda
.
synchronize
()
avg_time_ms
=
start_event
.
elapsed_time
(
end_event
)
/
repeat
return
avg_time_ms
,
y_cap
,
s_cap
def
check_accuracy
(
y_ref
:
torch
.
Tensor
,
s_ref
:
torch
.
Tensor
,
y_test
:
torch
.
Tensor
,
s_test
:
torch
.
Tensor
,
rtol
:
float
=
1e-2
,
atol
:
float
=
1e-2
,
)
->
Tuple
[
bool
,
dict
]:
"""
Check accuracy between reference and test outputs.
Args:
y_ref: Reference quantized output
s_ref: Reference scales
y_test: Test quantized output
s_test: Test scales
rtol: Relative tolerance
atol: Absolute tolerance
Returns:
Tuple of (passed, metrics_dict)
"""
# Convert FP8 to float for comparison
y_ref_float
=
y_ref
.
float
()
y_test_float
=
y_test
.
float
()
# Compute differences
y_diff
=
torch
.
abs
(
y_ref_float
-
y_test_float
)
s_diff
=
torch
.
abs
(
s_ref
-
s_test
)
# Compute metrics
y_max_diff
=
y_diff
.
max
().
item
()
y_mean_diff
=
y_diff
.
mean
().
item
()
s_max_diff
=
s_diff
.
max
().
item
()
s_mean_diff
=
s_diff
.
mean
().
item
()
# Check relative and absolute tolerance
y_close
=
torch
.
allclose
(
y_ref_float
,
y_test_float
,
rtol
=
rtol
,
atol
=
atol
)
s_close
=
torch
.
allclose
(
s_ref
,
s_test
,
rtol
=
rtol
,
atol
=
atol
)
# Compute percentage of matching elements
y_match_pct
=
(
y_ref_float
==
y_test_float
).
float
().
mean
().
item
()
*
100
metrics
=
{
"y_max_diff"
:
y_max_diff
,
"y_mean_diff"
:
y_mean_diff
,
"y_match_pct"
:
y_match_pct
,
"s_max_diff"
:
s_max_diff
,
"s_mean_diff"
:
s_mean_diff
,
"y_close"
:
y_close
,
"s_close"
:
s_close
,
}
passed
=
y_close
and
s_close
return
passed
,
metrics
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA not available"
)
def
test_act_quant_comprehensive_benchmark
(
scale_fmt
=
None
):
"""Comprehensive benchmark across multiple sizes with CUDA graphs."""
device
=
torch
.
device
(
"cuda"
)
dtype
=
torch
.
bfloat16
block_size
=
128
shapes
=
[
(
128
,
512
),
(
256
,
1024
),
(
512
,
2048
),
(
1024
,
4096
),
(
2048
,
8192
),
(
4096
,
16384
),
]
print
(
"
\n
"
+
"="
*
100
)
print
(
"Comprehensive Performance Benchmark with CUDA Graphs"
)
print
(
"="
*
100
)
print
(
f
"
{
'Shape'
:
<
20
}
{
'TileLang (ms)'
:
<
15
}
{
'Triton (ms)'
:
<
15
}
{
'Speedup'
:
<
10
}
{
'Status'
}
"
)
print
(
"-"
*
100
)
for
shape
in
shapes
:
torch
.
manual_seed
(
42
)
x
=
torch
.
randn
(
shape
,
dtype
=
dtype
,
device
=
device
)
try
:
# Benchmark both with CUDA graphs
time_tilelang
,
y_ref
,
s_ref
=
benchmark_kernel
(
act_quant
,
x
,
block_size
,
scale_fmt
,
warmup
=
5
,
repeat
=
50
,
use_cuda_graph
=
True
,
)
time_triton
,
y_triton
,
s_triton
=
benchmark_kernel
(
act_quant_triton
,
x
,
block_size
,
scale_fmt
,
warmup
=
5
,
repeat
=
50
,
use_cuda_graph
=
True
,
)
# Check accuracy
passed
,
_
=
check_accuracy
(
y_ref
,
s_ref
,
y_triton
,
s_triton
)
speedup
=
time_tilelang
/
time_triton
if
time_triton
>
0
else
0
status
=
"✓ PASS"
if
passed
else
"✗ FAIL"
print
(
f
"
{
str
(
shape
):
<
20
}
{
time_tilelang
:
<
15.4
f
}
{
time_triton
:
<
15.4
f
}
"
f
"
{
speedup
:
<
10.2
f
}
{
status
}
"
)
except
Exception
as
e
:
print
(
f
"
{
str
(
shape
):
<
20
}
ERROR:
{
str
(
e
)
}
"
)
print
(
"="
*
100
)
# Also run without CUDA graphs for comparison
print
(
"
\n
"
+
"="
*
100
)
print
(
"Performance Benchmark WITHOUT CUDA Graphs (for comparison)"
)
print
(
"="
*
100
)
print
(
f
"
{
'Shape'
:
<
20
}
{
'TileLang (ms)'
:
<
15
}
{
'Triton (ms)'
:
<
15
}
{
'Speedup'
:
<
10
}
{
'Status'
}
"
)
print
(
"-"
*
100
)
for
shape
in
shapes
:
torch
.
manual_seed
(
42
)
x
=
torch
.
randn
(
shape
,
dtype
=
dtype
,
device
=
device
)
try
:
# Benchmark both without CUDA graphs
time_tilelang
,
y_ref
,
s_ref
=
benchmark_kernel
(
act_quant
,
x
,
block_size
,
scale_fmt
,
warmup
=
5
,
repeat
=
50
,
use_cuda_graph
=
False
,
)
time_triton
,
y_triton
,
s_triton
=
benchmark_kernel
(
act_quant_triton
,
x
,
block_size
,
scale_fmt
,
warmup
=
5
,
repeat
=
50
,
use_cuda_graph
=
False
,
)
# Check accuracy
passed
,
_
=
check_accuracy
(
y_ref
,
s_ref
,
y_triton
,
s_triton
)
speedup
=
time_tilelang
/
time_triton
if
time_triton
>
0
else
0
status
=
"✓ PASS"
if
passed
else
"✗ FAIL"
print
(
f
"
{
str
(
shape
):
<
20
}
{
time_tilelang
:
<
15.4
f
}
{
time_triton
:
<
15.4
f
}
"
f
"
{
speedup
:
<
10.2
f
}
{
status
}
"
)
except
Exception
as
e
:
print
(
f
"
{
str
(
shape
):
<
20
}
ERROR:
{
str
(
e
)
}
"
)
print
(
"="
*
100
)
if
__name__
==
"__main__"
:
# Run comprehensive benchmark
if
torch
.
cuda
.
is_available
():
print
(
"
\n
"
+
"="
*
80
)
print
(
"Running Comprehensive Benchmark with scale_fmt=None"
)
print
(
"="
*
80
)
test_act_quant_comprehensive_benchmark
(
scale_fmt
=
None
)
print
(
"
\n
"
+
"="
*
80
)
print
(
"Running Comprehensive Benchmark with scale_fmt!=None"
)
print
(
"="
*
80
)
test_act_quant_comprehensive_benchmark
(
scale_fmt
=
"any"
)
else
:
print
(
"CUDA not available. Skipping tests."
)
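The new test also doubles as a standalone benchmark via its `__main__` block. A sketch of two ways one might run it from the repo root (assumes a CUDA device):

# Option 1: run as a pytest test; -s keeps the printed comparison tables visible
import pytest
pytest.main(["-s", "test/srt/layers/attention/nsa/test_act_quant_triton.py"])

# Option 2: run the file directly, which executes the __main__ block and
# benchmarks both scale_fmt=None and scale_fmt="any":
#   python test/srt/layers/attention/nsa/test_act_quant_triton.py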