sglang · Commit e0ce171d (unverified)

Fix triton backend eagle illegal memory access (#9344)

Authored Aug 20, 2025 by Ke Bao; committed via GitHub Aug 19, 2025. Parent commit: fe43e889.

Showing 1 changed file with 9 additions and 8 deletions:
python/sglang/srt/layers/attention/triton_backend.py (+9 −8)
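Every hunk makes the same change: KV-cache index buffers are widened from `torch.int32` to `torch.int64`. That is consistent with the "illegal memory access" in the title: once the flat index space for the KV cache outgrows the int32 range, slot ids wrap to negative values and the Triton kernels address memory out of bounds. A back-of-envelope sketch of the overflow arithmetic, with illustrative sizes that are not from the commit:

```python
# Back-of-envelope sketch (illustrative sizes, not from the commit) of why
# int32 index buffers can overflow: a flat KV-indices buffer holds slot ids
# up to roughly batch_size * topk * max_context_len.
import torch

INT32_MAX = torch.iinfo(torch.int32).max  # 2_147_483_647

def may_overflow_int32(batch_size: int, topk: int, max_context_len: int) -> bool:
    # If the flat index space exceeds int32 range, slot ids wrap to negative
    # values and downstream kernels read/write out of bounds.
    return batch_size * topk * max_context_len > INT32_MAX

print(may_overflow_int32(256, 8, 1_048_576))  # True: 2**31 > INT32_MAX
```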
python/sglang/srt/layers/attention/triton_backend.py @ e0ce171d

```diff
@@ -172,7 +172,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
```
```diff
@@ -238,7 +238,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                kv_indptr[-1], dtype=torch.int32, device=self.device
+                kv_indptr[-1], dtype=torch.int64, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
```
```diff
@@ -289,6 +289,7 @@ class TritonAttnBackend(AttentionBackend):
                     self.req_to_token,
                 )
             )
+            kv_indices = kv_indices.to(torch.int64)
             mask_indptr = None
             # TODO(FIXME): This will trigger an invalid Eagle tree when using
             # `max(spec_info.accept_length_cpu)`.
```
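This hunk contains the commit's one pure addition: the `kv_indices` produced by the enclosing call (only its closing parentheses are visible in the hunk) is widened before use so its dtype matches the int64 buffers allocated elsewhere. A minimal sketch of what the cast does, with made-up tensor values:

```python
# Minimal sketch of the added cast (tensor values are made up): .to() keeps
# the slot ids but widens the storage to match the int64 buffers elsewhere.
import torch

kv_indices = torch.tensor([0, 5, 9], dtype=torch.int32)
kv_indices = kv_indices.to(torch.int64)
assert kv_indices.dtype == torch.int64 and kv_indices.tolist() == [0, 5, 9]
```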
```diff
@@ -304,7 +305,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
                 forward_batch.extend_prefix_lens.sum().item(),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
```
```diff
@@ -379,7 +380,7 @@ class TritonAttnBackend(AttentionBackend):
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
         else:
```
```diff
@@ -396,7 +397,7 @@ class TritonAttnBackend(AttentionBackend):
         if kv_indices_buf is None:
             self.cuda_graph_window_kv_indices = torch.zeros(
                 (max_num_tokens * self.sliding_window_size),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
         else:
```
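These two CUDA-graph index buffers are preallocated up front, so the wider dtype doubles their resident size. A rough sketch of that cost, assuming illustrative sizes:

```python
# Rough memory cost of the dtype change for the preallocated CUDA-graph
# index buffers (sizes are illustrative assumptions, not from the commit).
max_num_tokens, max_context_len = 4096, 32768

elems = max_num_tokens * max_context_len  # 134_217_728 slot ids
print(f"int32: {elems * 4 / 2**30:.1f} GiB")  # 0.5 GiB
print(f"int64: {elems * 8 / 2**30:.1f} GiB")  # 1.0 GiB
```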
```diff
@@ -888,7 +889,7 @@ class TritonMultiStepDraftBackend:
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
```
```diff
@@ -906,7 +907,7 @@ class TritonMultiStepDraftBackend:
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_num_tokens * self.max_context_len),
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
```
```diff
@@ -1015,7 +1016,7 @@ def update_sliding_window_buffer(
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]
     window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
+        window_kv_indptr[-1], dtype=torch.int64, device=device
     )
     window_kv_start_idx = seq_lens - window_kv_lens
     create_flashinfer_kv_indices_triton[(bs,)](
```
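All of the touched call sites share the same CSR-style layout: a `kv_indptr` array of cumulative per-request lengths plus a flat `kv_indices` buffer that `create_flashinfer_kv_indices_triton` fills. A self-contained sketch of that layout; note the commit widens only the flat buffer, while the indptr dtype is untouched:

```python
# CSR-style layout sketch used throughout this file: kv_indptr holds
# cumulative per-request KV lengths, kv_indices is the flat slot buffer.
# The commit widens only the flat buffer to int64; kv_indptr is unchanged.
import torch

seq_lens = torch.tensor([3, 5, 2])  # per-request KV lengths (made up)
bs = seq_lens.numel()

kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]

# After this commit the flat buffer is int64; a Triton kernel then fills
# request i's slice kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int64)
print(kv_indptr.tolist())  # [0, 3, 8, 10]
```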