sglang commit 22630ca2 (unverified)

Support sliding window in triton backend (#6509)

Authored May 30, 2025 by Jianan Ji; committed by GitHub on May 30, 2025.
Parent: d279d499

Showing 6 changed files with 350 additions and 13 deletions (+350 -13):
python/sglang/srt/layers/attention/triton_backend.py               +198  -5
python/sglang/srt/layers/attention/triton_ops/extend_attention.py   +12  -4
python/sglang/srt/model_executor/model_runner.py                     +0  -4
python/sglang/srt/models/gemma3_causal.py                            +7  -0
test/srt/run_suite.py                                                +1  -0
test/srt/test_triton_sliding_window.py                             +132  -0
python/sglang/srt/layers/attention/triton_backend.py
```diff
@@ -72,6 +72,65 @@ def get_num_kv_splits_triton(
         tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
 
 
+def update_sliding_window_buffer(
+    window_kv_indptr,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+    device,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_indices = torch.empty(
+        window_kv_indptr[-1], dtype=torch.int32, device=device
+    )
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+def update_sliding_window_buffer_cuda_graph(
+    window_kv_indptr,
+    window_kv_indices,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_lens
+
+
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
```
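To make the bookkeeping above concrete, here is a small illustrative sketch (not part of the commit) of the length/indptr arithmetic that `update_sliding_window_buffer` performs, using made-up batch values; the `create_flashinfer_kv_indices_triton` launch that gathers the actual KV indices is omitted.

```python
# Illustrative sketch only: hypothetical batch, plain PyTorch, no Triton.
import torch

sliding_window_size = 4                 # hypothetical window size
seq_lens = torch.tensor([2, 6, 9])      # hypothetical per-request KV lengths
bs = seq_lens.numel()

# Each request keeps at most sliding_window_size + 1 of its most recent KV entries.
window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))

# Prefix-sum the per-request window lengths into an indptr array.
window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)

# The window starts this many tokens into each request's KV cache.
window_kv_start_idx = seq_lens - window_kv_lens

print(window_kv_lens)       # tensor([2, 5, 5])
print(window_kv_indptr)     # tensor([ 0,  2,  7, 12], dtype=torch.int32)
print(window_kv_start_idx)  # tensor([0, 1, 4])
```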
```diff
@@ -83,6 +142,10 @@ class ForwardMetadata:
     qo_indptr: torch.Tensor
     custom_mask: torch.Tensor
     mask_indptr: torch.Tensor
+    # Sliding window
+    window_kv_indptr: torch.Tensor
+    window_kv_indices: torch.Tensor
+    window_num_kv_splits: torch.Tensor
 
 
 class TritonAttnBackend(AttentionBackend):
```
```diff
@@ -109,6 +172,13 @@ class TritonAttnBackend(AttentionBackend):
 
         max_bs = model_runner.req_to_token_pool.size
 
+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.model_config.is_encoder_decoder
+        ), "Sliding window and cross attention are not supported together"
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
         if kv_indptr_buf is None:
             self.kv_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
```
```diff
@@ -116,6 +186,18 @@ class TritonAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
 
+        # If sliding window is enabled, we might need two sets of buffers
+        # because of interleaved attention types (e.g. for Gemma3)
+        self.window_kv_indptr = None
+        if self.sliding_window_size is not None and self.sliding_window_size > 0:
+            if kv_indptr_buf is None:
+                self.window_kv_indptr = torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+            else:
+                # When provided a buffer, create a clone for the second buffer
+                self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
+
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
 
         if not self.skip_prefill:
```
```diff
@@ -191,6 +273,9 @@ class TritonAttnBackend(AttentionBackend):
 
         bs = forward_batch.batch_size
         kv_indptr = self.kv_indptr
+        window_kv_indptr = self.window_kv_indptr
+        window_kv_indices = None
+        window_num_kv_splits = None
         spec_info = forward_batch.spec_info
 
         if forward_batch.forward_mode.is_decode_or_idle():
```
```diff
@@ -209,6 +294,26 @@ class TritonAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
+                # Sliding window
+                if (
+                    self.sliding_window_size is not None
+                    and self.sliding_window_size > 0
+                ):
+                    window_kv_indptr, window_kv_indices, window_kv_lens = (
+                        update_sliding_window_buffer(
+                            self.window_kv_indptr,
+                            self.req_to_token,
+                            self.sliding_window_size,
+                            forward_batch.seq_lens,
+                            forward_batch.req_pool_indices,
+                            bs,
+                            self.device,
+                        )
+                    )
+                    window_num_kv_splits = torch.empty(
+                        (bs,), dtype=torch.int32, device=self.device
+                    )
+                    self.get_num_kv_splits(window_num_kv_splits, window_kv_lens)
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1
```
```diff
@@ -224,7 +329,6 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )
             num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
-
             self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
 
             qo_indptr = None
```
```diff
@@ -232,6 +336,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             max_extend_len = None
         elif forward_batch.forward_mode.is_target_verify():
+            # TODO: Support sliding window in spec inference
            bs = len(forward_batch.req_pool_indices)
            qo_indptr = torch.arange(
                0,
```
```diff
@@ -303,6 +408,17 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indices,
                 self.req_to_token.stride(0),
             )
+            # Sliding window
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
+                    self.window_kv_indptr,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    forward_batch.extend_prefix_lens,
+                    forward_batch.req_pool_indices,
+                    bs,
+                    self.device,
+                )
             qo_indptr = self.qo_indptr
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
```
```diff
@@ -324,6 +440,9 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr,
             custom_mask,
             mask_indptr,
+            window_kv_indptr,
+            window_kv_indices,
+            window_num_kv_splits,
         )
 
     def init_cuda_graph_state(
```
```diff
@@ -358,6 +477,20 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )
 
+        if self.sliding_window_size is not None and self.sliding_window_size > 0:
+            if kv_indices_buf is None:
+                self.cuda_graph_window_kv_indices = torch.zeros(
+                    (max_bs * self.sliding_window_size),
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+            else:
+                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
+
+            self.cuda_graph_window_num_kv_splits = torch.full(
+                (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+            )
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
```
```diff
@@ -369,6 +502,9 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         assert encoder_lens is None, "Not supported"
+        window_kv_indptr = self.window_kv_indptr
+        window_kv_indices = None
+        window_num_kv_splits = None
 
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
```
```diff
@@ -385,6 +521,21 @@ class TritonAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
+                if (
+                    self.sliding_window_size is not None
+                    and self.sliding_window_size > 0
+                ):
+                    window_kv_indices = self.cuda_graph_window_kv_indices
+                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                    window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                    )
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
```
```diff
@@ -468,6 +619,9 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr,
             custom_mask,
             mask_indptr,
+            window_kv_indptr,
+            window_kv_indices,
+            window_num_kv_splits,
         )
 
     def init_forward_metadata_replay_cuda_graph(
```
```diff
@@ -500,11 +654,31 @@ class TritonAttnBackend(AttentionBackend):
                     self.req_to_token.stride(0),
                 )
                 num_token = bs
+                if (
+                    self.sliding_window_size is not None
+                    and self.sliding_window_size > 0
+                ):
+                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                    window_kv_indices = self.cuda_graph_window_kv_indices
+                    _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices[:bs],
+                        bs,
+                    )
+                    self.get_num_kv_splits(
+                        window_num_kv_splits[:num_token], window_kv_lens[:bs]
+                    )
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
                 num_token = spec_info.kv_indptr.shape[0] - 1
             self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)
```
```diff
@@ -582,6 +756,17 @@ class TritonAttnBackend(AttentionBackend):
         if layer.attn_type == AttentionType.ENCODER_ONLY:
             causal = False
 
+        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
+            sliding_window_size = (
+                layer.sliding_window_size
+            )  # Needed for sliding window mask
+            kv_indptr = self.forward_metadata.window_kv_indptr
+            kv_indices = self.forward_metadata.window_kv_indices
+        else:
+            sliding_window_size = -1
+            kv_indptr = self.forward_metadata.kv_indptr
+            kv_indices = self.forward_metadata.kv_indices
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
```
```diff
@@ -590,14 +775,15 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             self.forward_metadata.qo_indptr,
-            self.forward_metadata.kv_indptr,
-            self.forward_metadata.kv_indices,
+            kv_indptr,
+            kv_indices,
             self.forward_metadata.custom_mask,
             causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
+            sliding_window_size,
         )
         return o
```
```diff
@@ -625,13 +811,20 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
+        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
+            kv_indptr = self.forward_metadata.window_kv_indptr
+            kv_indices = self.forward_metadata.window_kv_indices
+        else:
+            kv_indptr = self.forward_metadata.kv_indptr
+            kv_indices = self.forward_metadata.kv_indices
+
         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            self.forward_metadata.kv_indptr,
-            self.forward_metadata.kv_indices,
+            kv_indptr,
+            kv_indices,
             self.forward_metadata.attn_logits,
             self.forward_metadata.attn_lse,
             self.forward_metadata.num_kv_splits,
```
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
```diff
@@ -65,6 +65,7 @@ def _fwd_kernel(
     stride_buf_kh,
     stride_buf_vbs,
     stride_buf_vh,
+    SLIDING_WINDOW_SIZE: tl.constexpr,
     logit_cap: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
```
```diff
@@ -163,6 +164,7 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        final_mask = mask_m[:, None] & mask_n[None, :]
         if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
             custom_mask = tl.load(
                 mask_ptr
```
```diff
@@ -173,10 +175,14 @@ def _fwd_kernel(
                 mask=(mask_m[:, None] & mask_n[None, :]),
                 other=0,
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
-            qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+            final_mask &= custom_mask
+        if SLIDING_WINDOW_SIZE > 0:
+            # Add mask where q_id <= kv_id + sliding_window_size
+            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+            )
+            final_mask &= window_mask
+        qk = tl.where(final_mask, qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
```
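As an aside, the window predicate introduced above, `q_id <= kv_id + SLIDING_WINDOW_SIZE`, can be sanity-checked outside Triton. The following is an illustrative sketch with hypothetical positions, not part of the commit; causality is still enforced by the existing causal/custom masks, so this predicate only drops keys that fall behind the window.

```python
# Illustrative sketch only: toy query/key positions in a shared coordinate space.
import torch

SLIDING_WINDOW_SIZE = 2               # hypothetical window
q_ids = torch.arange(4)[:, None]      # query positions 0..3 (rows)
kv_ids = torch.arange(4)[None, :]     # key/value positions 0..3 (columns)

# A query may attend to a key only if the key is at most
# SLIDING_WINDOW_SIZE positions behind it: q_id <= kv_id + SLIDING_WINDOW_SIZE.
window_mask = q_ids <= kv_ids + SLIDING_WINDOW_SIZE
print(window_mask.int())
# tensor([[1, 1, 1, 1],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1],
#         [0, 1, 1, 1]])
```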
```diff
@@ -314,6 +320,7 @@ def extend_attention_fwd(
     sm_scale=None,
     logit_cap=0.0,
     skip_prefix_custom_mask=True,
+    sliding_window_size=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
```
```diff
@@ -412,6 +419,7 @@ def extend_attention_fwd(
         k_buffer.stride(1),
         v_buffer.stride(0),
         v_buffer.stride(1),
+        SLIDING_WINDOW_SIZE=sliding_window_size,
         logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
```
python/sglang/srt/model_executor/model_runner.py
```diff
@@ -1025,10 +1025,6 @@ class ModelRunner:
             return AiterAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
-            assert self.sliding_window_size is None, (
-                "Window attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
             assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
```
python/sglang/srt/models/gemma3_causal.py
```diff
@@ -277,6 +277,13 @@ class Gemma3Attention(nn.Module):
         k = k.permute(0, 2, 1, 3)
 
         attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+
+        # Compatible with triton backend which returns [1, s, h, head_dim]
+        if attn_output.dim() == 4 and attn_output.shape[0] == 1:
+            attn_output = attn_output.squeeze(0)
+            attn_output = attn_output.flatten(-2, -1)
+            # [s, h * head_dim]
+
         output, _ = self.o_proj(attn_output)
         return output
```
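For reference, a small sketch (not part of the commit, hypothetical sizes) of the shape normalization this hunk adds: when the backend returns a `[1, s, h, head_dim]` tensor, it is squeezed and flattened into the `[s, h * head_dim]` layout that `o_proj` expects.

```python
# Illustrative sketch only: hypothetical sizes, not the model's real dimensions.
import torch

s, h, head_dim = 5, 8, 64                          # hypothetical sizes
attn_output = torch.randn(1, s, h, head_dim)       # backend-style [1, s, h, head_dim]

if attn_output.dim() == 4 and attn_output.shape[0] == 1:
    attn_output = attn_output.squeeze(0)           # -> [s, h, head_dim]
    attn_output = attn_output.flatten(-2, -1)      # -> [s, h * head_dim]

print(attn_output.shape)                           # torch.Size([5, 512])
```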
test/srt/run_suite.py
```diff
@@ -78,6 +78,7 @@ suites = {
         TestFile("test_triton_attention_kernels.py", 4),
         TestFile("test_triton_attention_backend.py", 134),
         TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
+        TestFile("test_triton_sliding_window.py", 250),
         TestFile("test_update_weights_from_disk.py", 114),
         TestFile("test_update_weights_from_tensor.py", 48),
         TestFile("test_vertex_endpoint.py", 31),
```
test/srt/test_triton_sliding_window.py (new file, mode 100644)
```python
import time
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestSlidingWindowAttentionTriton(CustomTestCase):
    """Test sliding window attention functionality with triton backend."""

    @classmethod
    def setUpClass(cls):
        """Set up the test server with Gemma3 model and triton backend."""
        # Gemma3 model supports sliding window attention
        cls.model = "google/gemma-3-4b-it"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "triton",
            "--context-length",
            "8192",
            "--random-seed",
            "42",
        ]

        cls.short_context_prompt = "The capital of France is"
        # Test prompt longer than window size
        cls.long_context_prompt = (
            """
Once upon a time, there was a mountain. In the mountain, there was a temple. In the temple, there was an old monk telling a story. The story was:
"""
            * 100
        )
        cls.long_context_prompt += "\nNow, summarize the story in one sentence:"

    @classmethod
    def tearDownClass(cls):
        pass

    def _test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        print(f"MMLU metrics with sliding window: {metrics}")
        self.assertGreaterEqual(metrics["score"], 0.64)

    def _test_short_context_generation(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.short_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn("paris", result["text"].lower())
        print(f"Short context generation result: {result['text']}")

    def _test_long_context_generation(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.long_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertGreater(len(result["text"].strip()), 0)
        print(f"Long context generation result: {result['text'][:100]}...")

    def test_no_cuda_graph(self):
        self.no_cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args + ["--disable-cuda-graph"],
        )
        self._test_short_context_generation()
        self._test_long_context_generation()
        self._test_mmlu()
        kill_process_tree(self.no_cuda_graph_process.pid)
        time.sleep(5)

    def test_cuda_graph(self):
        self.cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args,
        )
        self._test_short_context_generation()
        self._test_long_context_generation()
        self._test_mmlu()
        kill_process_tree(self.cuda_graph_process.pid)
        time.sleep(5)


if __name__ == "__main__":
    unittest.main()
```