Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e0ce171d
"vscode:/vscode.git/clone" did not exist on "b4cb5a8e4a4a68ed6e838c38530dadae125941e1"
Unverified
Commit
e0ce171d
authored
Aug 20, 2025
by
Ke Bao
Committed by
GitHub
Aug 19, 2025
Browse files
Fix triton backend eagle illegal memory access (#9344)
parent
fe43e889
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
8 deletions
+9
-8
python/sglang/srt/layers/attention/triton_backend.py
python/sglang/srt/layers/attention/triton_backend.py
+9
-8
No files found.
python/sglang/srt/layers/attention/triton_backend.py
View file @
e0ce171d
...
...
@@ -172,7 +172,7 @@ class TritonAttnBackend(AttentionBackend):
kv_indptr
[
1
:
bs
+
1
]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
,
dim
=
0
)
kv_indptr
=
kv_indptr
[:
bs
+
1
]
kv_indices
=
torch
.
empty
(
forward_batch
.
seq_lens_sum
,
dtype
=
torch
.
int
32
,
device
=
self
.
device
forward_batch
.
seq_lens_sum
,
dtype
=
torch
.
int
64
,
device
=
self
.
device
)
create_flashinfer_kv_indices_triton
[(
bs
,)](
self
.
req_to_token
,
...
...
@@ -238,7 +238,7 @@ class TritonAttnBackend(AttentionBackend):
kv_indptr
[
1
:
bs
+
1
]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
,
dim
=
0
)
kv_indptr
=
kv_indptr
[:
bs
+
1
]
kv_indices
=
torch
.
empty
(
kv_indptr
[
-
1
],
dtype
=
torch
.
int
32
,
device
=
self
.
device
kv_indptr
[
-
1
],
dtype
=
torch
.
int
64
,
device
=
self
.
device
)
create_flashinfer_kv_indices_triton
[(
bs
,)](
self
.
req_to_token
,
...
...
@@ -289,6 +289,7 @@ class TritonAttnBackend(AttentionBackend):
self
.
req_to_token
,
)
)
kv_indices
=
kv_indices
.
to
(
torch
.
int64
)
mask_indptr
=
None
# TODO(FIXME): This will trigger an invalid Eagle tree when using
# `max(spec_info.accept_length_cpu)`.
...
...
@@ -304,7 +305,7 @@ class TritonAttnBackend(AttentionBackend):
kv_indptr
=
kv_indptr
[:
bs
+
1
]
kv_indices
=
torch
.
empty
(
forward_batch
.
extend_prefix_lens
.
sum
().
item
(),
dtype
=
torch
.
int
32
,
dtype
=
torch
.
int
64
,
device
=
self
.
device
,
)
create_flashinfer_kv_indices_triton
[(
bs
,)](
...
...
@@ -379,7 +380,7 @@ class TritonAttnBackend(AttentionBackend):
if
kv_indices_buf
is
None
:
self
.
cuda_graph_kv_indices
=
torch
.
zeros
(
(
max_num_tokens
*
self
.
max_context_len
),
dtype
=
torch
.
int
32
,
dtype
=
torch
.
int
64
,
device
=
self
.
device
,
)
else
:
...
...
@@ -396,7 +397,7 @@ class TritonAttnBackend(AttentionBackend):
if
kv_indices_buf
is
None
:
self
.
cuda_graph_window_kv_indices
=
torch
.
zeros
(
(
max_num_tokens
*
self
.
sliding_window_size
),
dtype
=
torch
.
int
32
,
dtype
=
torch
.
int
64
,
device
=
self
.
device
,
)
else
:
...
...
@@ -888,7 +889,7 @@ class TritonMultiStepDraftBackend:
self
.
speculative_num_steps
,
forward_batch
.
batch_size
*
self
.
topk
*
self
.
max_context_len
,
),
dtype
=
torch
.
int
32
,
dtype
=
torch
.
int
64
,
device
=
self
.
device
,
)
...
...
@@ -906,7 +907,7 @@ class TritonMultiStepDraftBackend:
def
init_cuda_graph_state
(
self
,
max_bs
:
int
,
max_num_tokens
:
int
):
self
.
cuda_graph_kv_indices
=
torch
.
zeros
(
(
self
.
speculative_num_steps
,
max_num_tokens
*
self
.
max_context_len
),
dtype
=
torch
.
int
32
,
dtype
=
torch
.
int
64
,
device
=
self
.
device
,
)
for
i
in
range
(
self
.
speculative_num_steps
):
...
...
@@ -1015,7 +1016,7 @@ def update_sliding_window_buffer(
window_kv_indptr
[
1
:
bs
+
1
]
=
torch
.
cumsum
(
window_kv_lens
,
dim
=
0
)
window_kv_indptr
=
window_kv_indptr
[:
bs
+
1
]
window_kv_indices
=
torch
.
empty
(
window_kv_indptr
[
-
1
],
dtype
=
torch
.
int
32
,
device
=
device
window_kv_indptr
[
-
1
],
dtype
=
torch
.
int
64
,
device
=
device
)
window_kv_start_idx
=
seq_lens
-
window_kv_lens
create_flashinfer_kv_indices_triton
[(
bs
,)](
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment