sglang · Commits · Commit 7e6d5fc6 (Unverified)

Authored Feb 12, 2025 by Ke Bao; committed by GitHub on Feb 12, 2025

Support Eagle cuda graph for Triton backend (#3500)
parent cadd5dbe

Showing 2 changed files with 142 additions and 57 deletions (+142 −57):

  python/sglang/srt/layers/attention/triton_backend.py   +142 −55
  test/srt/test_eagle_infer.py                              +0  −2
python/sglang/srt/layers/attention/triton_backend.py (view file @ 7e6d5fc6)
@@ -38,6 +38,8 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        self.skip_prefill = skip_prefill
+
         max_bs = model_runner.req_to_token_pool.size
 
         if kv_indptr_buf is None:
@@ -48,6 +50,8 @@ class TritonAttnBackend(AttentionBackend):
             self.kv_indptr = kv_indptr_buf
 
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
 
-        self.qo_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+        if not self.skip_prefill:
+            self.qo_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
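
Aside (illustration only, not part of the diff): the constructor caches model_runner.req_to_token_pool.req_to_token, the table that maps a request's pool slot and token position to a location in the paged KV cache; the attention kernels consume a flattened kv_indices gathered from it per running request. A minimal pure-PyTorch sketch of that mapping, with toy values:

# Illustration only (not sglang code): what the cached req_to_token table represents.
import torch

max_bs, max_context_len = 4, 8
req_to_token = torch.arange(max_bs * max_context_len, dtype=torch.int32).reshape(
    max_bs, max_context_len
)  # toy pool: slot i owns consecutive cache locations

req_pool_indices = [2, 0]  # pool slots used by two running requests
seq_lens = [3, 5]          # cached tokens per request

kv_indices = torch.cat([req_to_token[r, :l] for r, l in zip(req_pool_indices, seq_lens)])
print(kv_indices.tolist())  # [16, 17, 18, 0, 1, 2, 3, 4]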
@@ -196,22 +200,29 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr,
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device=self.device
-        )
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.max_context_len),
-            dtype=torch.int32,
-            device=self.device,
-        )
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
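
Aside (illustration only, not part of the diff): cuda_graph_attn_logits is preallocated with a trailing dimension of self.v_head_dim + 1. Assuming the extra slot holds each KV split's log-sum-exp, as is typical for split-KV decode kernels, the reduction it enables can be sketched in plain PyTorch; this is not the Triton kernel itself.

# Illustration only: merging per-split partial outputs into exact softmax attention.
import torch

torch.manual_seed(0)
seq_len, head_dim, num_kv_splits = 10, 4, 3
q = torch.randn(head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)

scores = k @ q                          # (seq_len,)
ref = torch.softmax(scores, dim=0) @ v  # full-sequence reference

partials, lses = [], []
for chunk in torch.chunk(torch.arange(seq_len), num_kv_splits):
    s = scores[chunk]
    lses.append(torch.logsumexp(s, dim=0))          # the "+ 1" slot per split
    partials.append(torch.softmax(s, dim=0) @ v[chunk])

lses = torch.stack(lses)
global_lse = torch.logsumexp(lses, dim=0)
merged = sum(torch.exp(lse - global_lse) * p for lse, p in zip(lses, partials))
assert torch.allclose(merged, ref, atol=1e-5)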
@@ -224,9 +235,9 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         assert encoder_lens is None, "Not supported"
-        assert forward_mode.is_decode(), "Not supported"
-        assert spec_info is None, "Not supported"
 
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
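
Aside (illustration only, not part of the diff): a worked example of the cumsum above, which turns per-request sequence lengths into the CSR-style kv_indptr consumed by create_flashinfer_kv_indices_triton and the decode kernel (toy values).

# Illustration only: per-request lengths -> CSR-style offsets into kv_indices.
import torch

max_bs = 8
kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32)  # preallocated, reused every step
seq_lens = torch.tensor([3, 5, 2])
bs = seq_lens.numel()

kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
print(kv_indptr.tolist())  # [0, 3, 8, 10]
# Request i's KV-cache locations are kv_indices[kv_indptr[i] : kv_indptr[i + 1]].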
@@ -240,15 +251,55 @@ class TritonAttnBackend(AttentionBackend):
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
-        self.forward_metadata = (
-            self.cuda_graph_attn_logits,
-            None,
-            kv_indptr,
-            kv_indices,
-            None,
-            None,
-            None,
-        )
+            attn_logits = self.cuda_graph_attn_logits
+            max_extend_len = None
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
+        elif forward_mode.is_target_verify():
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
+            )
 
+        self.forward_metadata = (
+            attn_logits,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+            mask_indptr,
+        )
 
     def init_forward_metadata_replay_cuda_graph(
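
Aside (illustration only, not part of the diff): a worked example of the target-verify index layout built above. Each request contributes num_draft_tokens query tokens, so qo_indptr advances in fixed steps, and mask_indptr offsets each request's flattened num_draft_tokens × (seq_len + num_draft_tokens) tree-mask block inside custom_mask.

# Illustration only: target-verify index layout with toy values.
import torch

bs, num_draft_tokens = 3, 4
seq_lens = torch.tensor([3, 5, 2])

qo_indptr = torch.arange(
    0, (1 + bs) * num_draft_tokens, step=num_draft_tokens, dtype=torch.int32
)
print(qo_indptr.tolist())  # [0, 4, 8, 12] -> 4 draft tokens per request

seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)
mask_indptr = torch.zeros((bs + 1,), dtype=torch.int64)
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
print(mask_indptr.tolist())  # [0, 28, 64, 88] -> per-request offsets into custom_mask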
@@ -262,13 +313,13 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices[:bs],
+        if forward_mode.is_decode_or_idle():
+            # Update kv_indptr, kv_indices
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
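
Aside (illustration only, not part of the diff): the replay path writes the new batch's metadata into the same preallocated tensors that were captured (kv_indptr, cuda_graph_kv_indices, and friends), because a CUDA graph replays fixed device addresses. A minimal, self-contained sketch of that constraint, unrelated to sglang's actual kernels:

# Illustration only: graph replay reads the captured storage, so inputs must be
# updated in place rather than rebound to freshly allocated tensors.
import torch

if torch.cuda.is_available():
    static_in = torch.zeros(4, device="cuda")
    static_out = torch.empty(4, device="cuda")

    static_out.copy_(static_in * 2)  # warm-up so lazy CUDA init happens outside capture
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2)

    static_in.copy_(torch.arange(4.0, device="cuda"))  # in-place update of captured input
    g.replay()
    torch.cuda.synchronize()
    print(static_out.tolist())  # [0.0, 2.0, 4.0, 6.0]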
@@ -278,6 +329,41 @@ class TritonAttnBackend(AttentionBackend):
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+        elif forward_mode.is_target_verify():
+            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
+            )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
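
Aside (illustration only, not part of the diff): the target-verify replay above copies the variable-length spec_info tensors into the front of fixed-size, already-captured buffers; the unused tail is simply left as padding. The same pattern in isolation, with toy sizes:

# Illustration only: padded in-place copy into a graph-captured buffer.
import torch

max_tokens = 32
custom_mask_buf = torch.zeros(max_tokens, dtype=torch.uint8)     # captured once, fixed storage
new_mask = torch.tensor([1, 1, 0, 1, 1, 1], dtype=torch.uint8)   # differs on every replay

custom_mask_buf[: new_mask.shape[0]] = new_mask                  # in-place, no reallocation
print(custom_mask_buf[:8].tolist())  # [1, 1, 0, 1, 1, 1, 0, 0]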
@@ -407,6 +493,7 @@ class TritonMultiStepDraftBackend:
             )
         )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -450,7 +537,7 @@ class TritonMultiStepDraftBackend:
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
         )
 
         def call_fn(i, forward_batch):
@@ -468,7 +555,7 @@ class TritonMultiStepDraftBackend:
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
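
Aside (illustration only, not part of the diff): TritonMultiStepDraftBackend allocates one (speculative_num_steps, max_bs * max_context_len) index buffer and then initializes each draft step's backend in the loop above; the call arguments are elided in the diff, but presumably each backend receives a view of its own row. The slicing pattern, in isolation:

# Illustration only: one buffer carved into per-step views so every draft step's
# backend captures and replays against its own stable slice of memory.
import torch

speculative_num_steps, max_bs, max_context_len = 3, 4, 16
device = "cuda" if torch.cuda.is_available() else "cpu"

cuda_graph_kv_indices = torch.zeros(
    (speculative_num_steps, max_bs * max_context_len), dtype=torch.int32, device=device
)
per_step_views = [cuda_graph_kv_indices[i] for i in range(speculative_num_steps)]
assert per_step_views[1].data_ptr() == cuda_graph_kv_indices[1].data_ptr()  # shared storage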
test/srt/test_eagle_infer.py (view file @ 7e6d5fc6)
@@ -216,8 +216,6 @@ class TestEAGLEServerTriton(TestEAGLEServer):
                 "0.7",
                 "--attention-backend",
                 "triton",
-                # TODO: Support cuda graph
-                "--disable-cuda-graph",
             ],
         )
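
Aside (illustration only, not part of the diff): with --disable-cuda-graph removed, TestEAGLEServerTriton now exercises the Triton backend with CUDA graphs enabled. A hedged sketch of the kind of argument list involved; the flags are standard sglang server arguments, but the draft-model path and numeric values are placeholders, not the test's actual configuration.

# Hedged sketch (not the actual test code).
other_args = [
    "--speculative-algorithm", "EAGLE",
    "--speculative-draft-model-path", "<eagle-draft-model>",  # placeholder
    "--speculative-num-steps", "5",
    "--speculative-eagle-topk", "8",
    "--speculative-num-draft-tokens", "64",
    "--mem-fraction-static", "0.7",
    "--attention-backend", "triton",
    # "--disable-cuda-graph",  # removed by this commit
]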