change / sglang / Commits / 7e6d5fc6

Unverified commit 7e6d5fc6, authored Feb 12, 2025 by Ke Bao, committed by GitHub on Feb 12, 2025

Support Eagle cuda graph for Triton backend (#3500)
Parent: cadd5dbe
Showing 2 changed files with 142 additions and 57 deletions:

  python/sglang/srt/layers/attention/triton_backend.py  (+142, -55)
  test/srt/test_eagle_infer.py  (+0, -2)
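This commit extends the Triton attention backend so that EAGLE speculative decoding (draft decode and target verify) can run under CUDA graph capture and replay instead of requiring --disable-cuda-graph. For orientation, the underlying mechanism is PyTorch's CUDA graphs API; the sketch below shows only that generic capture/replay pattern, with made-up module and tensor names rather than anything from the diff.

import torch

# Illustrative only: the generic PyTorch CUDA graph capture/replay pattern this
# backend plugs into. All names below (model, static_input, ...) are invented
# for the sketch; they are not sglang identifiers.
if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    static_input = torch.zeros(8, 16, device="cuda")  # fixed-shape input buffer

    # Warm up on a side stream before capture, as the PyTorch docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = model(static_input)  # record the kernel launches once

    # Replay: refill the same buffers with new data, then rerun the recorded kernels.
    static_input.copy_(torch.randn(8, 16, device="cuda"))
    g.replay()
    print(static_output.sum().item())

Because the recorded kernels always read and write the same buffers, the backend below preallocates fixed-size index, logit, and mask tensors and only overwrites their contents at capture and replay time.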
python/sglang/srt/layers/attention/triton_backend.py (view file @ 7e6d5fc6)
@@ -38,6 +38,8 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
         self.skip_prefill = skip_prefill

         max_bs = model_runner.req_to_token_pool.size

         if kv_indptr_buf is None:
@@ -48,13 +50,15 @@ class TritonAttnBackend(AttentionBackend):
             self.kv_indptr = kv_indptr_buf

         self.req_to_token = model_runner.req_to_token_pool.req_to_token

-        self.qo_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
-
-        self.mask_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
-        )
+        if not self.skip_prefill:
+            self.qo_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+
+            self.mask_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+            )
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
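As a side note on the layout these buffers follow: kv_indptr (and likewise qo_indptr and mask_indptr) is a CSR-style offset array, where entry i gives the start of request i's slice in a flattened buffer and entry 0 stays 0. A minimal, self-contained sketch with invented sequence lengths, not code from the diff:

import torch

# Hypothetical example: three requests with KV lengths 5, 3, and 7.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
bs = seq_lens.numel()

max_bs = 8
kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32)

# Same pattern as the backend: prefix sums written into a preallocated buffer.
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
print(kv_indptr[: bs + 1].tolist())  # [0, 5, 8, 15]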
@@ -196,22 +200,29 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr,
         )

-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device=self.device
-        )
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.max_context_len),
-            dtype=torch.int32,
-            device=self.device,
-        )
+
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )

     def init_forward_metadata_capture_cuda_graph(
         self,
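The new kv_indices_buf argument lets a caller supply the index buffer instead of letting the backend allocate one; the TritonMultiStepDraftBackend changes later in this diff allocate a single (speculative_num_steps, max_bs * max_context_len) tensor and hand one row to each draft step. A rough sketch of that sharing pattern, using a stand-in class and invented sizes rather than the real backend:

import torch
from typing import Optional

class TinyBackend:
    """Stand-in for an attention backend; not sglang's class."""
    def init_cuda_graph_state(self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None):
        if kv_indices_buf is None:
            # Fall back to a private allocation when no shared buffer is given.
            self.cuda_graph_kv_indices = torch.zeros((max_bs * 64,), dtype=torch.int32)
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

num_steps, max_bs, max_context_len = 3, 4, 64
shared = torch.zeros((num_steps, max_bs * max_context_len), dtype=torch.int32)

backends = [TinyBackend() for _ in range(num_steps)]
for i, b in enumerate(backends):
    b.init_cuda_graph_state(max_bs, kv_indices_buf=shared[i])  # each step reuses one row

Sharing one allocation keeps all per-step index buffers contiguous and avoids repeating the max_bs * max_context_len allocation once per draft step.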
@@ -224,31 +235,71 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
-        assert encoder_lens is None, "Not supported"
-        assert forward_mode.is_decode(), "Not supported"
-        assert spec_info is None, "Not supported"
-
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            seq_lens,
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+            attn_logits = self.cuda_graph_attn_logits
+            max_extend_len = None
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
+        elif forward_mode.is_target_verify():
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
+            )

         self.forward_metadata = (
-            self.cuda_graph_attn_logits,
-            None,
+            attn_logits,
+            max_extend_len,
             kv_indptr,
             kv_indices,
-            None,
-            None,
-            None,
+            qo_indptr,
+            custom_mask,
+            mask_indptr,
         )

     def init_forward_metadata_replay_cuda_graph(
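In the target-verify branch every request contributes exactly num_draft_tokens query tokens, so qo_indptr is a plain arithmetic progression rather than a cumulative sum of variable lengths. A standalone illustration with invented values, not code from the diff:

import torch

# Hypothetical values: batch of 3 requests, 4 draft tokens verified per request.
bs, num_draft_tokens = 3, 4

qo_indptr = torch.arange(
    0,
    (1 + bs) * num_draft_tokens,
    step=num_draft_tokens,
    dtype=torch.int32,
)
print(qo_indptr.tolist())  # [0, 4, 8, 12]: request i owns query rows [4*i, 4*(i+1))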
@@ -262,22 +313,57 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+        if forward_mode.is_decode_or_idle():
+            # Update kv_indptr, kv_indices
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
+                    seq_lens[:bs],
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+        elif forward_mode.is_target_verify():
+            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
+            )

     def get_cuda_graph_seq_len_fill_value(self):
         return 1
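The replay path also rebuilds the flattened verify-mask bookkeeping: each request owns a block of num_draft_tokens * (seq_len + num_draft_tokens) mask entries, and mask_indptr records the running offset of those blocks inside cuda_graph_custom_mask. A toy example with invented lengths, not code from the diff:

import torch

# Hypothetical: 2 requests with prefix lengths 5 and 9, verifying 4 draft tokens each.
num_draft_tokens = 4
seq_lens = torch.tensor([5, 9], dtype=torch.int64)
bs = seq_lens.numel()

# Per-request mask size: num_draft_tokens rows by (seq_len + num_draft_tokens) columns, flattened.
seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)

mask_indptr = torch.zeros((bs + 1,), dtype=torch.int64)
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
print(seq_mask_len.tolist())  # [36, 52]
print(mask_indptr.tolist())   # [0, 36, 88]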
@@ -407,6 +493,7 @@ class TritonMultiStepDraftBackend:
                 )
             )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -450,7 +537,7 @@ class TritonMultiStepDraftBackend:
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
         )

         def call_fn(i, forward_batch):
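Swapping the hard-coded device="cuda" for self.device matters when the model runner is pinned to a specific GPU: the bare "cuda" string resolves to whichever device is current at allocation time. A small illustration of the difference; the device indices are hypothetical and the effect only shows up with at least two GPUs:

import torch

if torch.cuda.device_count() >= 2:
    backend_device = torch.device("cuda:1")  # hypothetical: the model runner lives on GPU 1

    with torch.cuda.device(0):
        a = torch.zeros(4, device="cuda")           # lands on the *current* device: cuda:0
        b = torch.zeros(4, device=backend_device)   # always lands on cuda:1

    print(a.device, b.device)  # cuda:0 cuda:1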
@@ -468,7 +555,7 @@ class TritonMultiStepDraftBackend:
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
test/srt/test_eagle_infer.py (view file @ 7e6d5fc6)
@@ -216,8 +216,6 @@ class TestEAGLEServerTriton(TestEAGLEServer):
                 "0.7",
                 "--attention-backend",
                 "triton",
-                # TODO: Support cuda graph
-                "--disable-cuda-graph",
             ],
         )
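With CUDA graph support now wired into the Triton backend, the Triton EAGLE server test no longer needs the temporary --disable-cuda-graph workaround, so the flag and its TODO comment are removed.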