sglang / Commits / 7e6d5fc6
Commit 7e6d5fc6 (unverified), authored Feb 12, 2025 by Ke Bao, committed via GitHub on Feb 12, 2025
Support Eagle cuda graph for Triton backend (#3500)
Parent: cadd5dbe

Showing 2 changed files with 142 additions and 57 deletions:

    python/sglang/srt/layers/attention/triton_backend.py  (+142, -55)
    test/srt/test_eagle_infer.py  (+0, -2)
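Throughout triton_backend.py, the paged KV cache is addressed through a CSR-style pair of tensors: kv_indptr holds per-request offsets built with a cumulative sum of sequence lengths, and kv_indices holds the flattened KV-pool slots those offsets point into (filled from req_to_token by create_flashinfer_kv_indices_triton). The following standalone sketch illustrates only that layout; it is not sglang code, and the pool slots are faked for the demo:

import torch

# Three requests with different numbers of cached tokens.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
bs = seq_lens.numel()

# CSR-style offsets: request i owns kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)

# Flattened KV-pool slot of every cached token. Faked here; in sglang this
# is what create_flashinfer_kv_indices_triton gathers from req_to_token.
kv_indices = torch.arange(int(kv_indptr[-1]), dtype=torch.int32)

for i in range(bs):
    start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
    print(f"request {i}: slots {kv_indices[start:end].tolist()}")
# request 0: slots [0, 1, 2]
# request 1: slots [3, 4, 5, 6, 7]
# request 2: slots [8, 9]

Because a whole batch collapses into two flat tensors, both can be preallocated at a fixed worst-case size, which is exactly what CUDA graph capture requires.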
python/sglang/srt/layers/attention/triton_backend.py
@@ -38,6 +38,8 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        self.skip_prefill = skip_prefill
+
         max_bs = model_runner.req_to_token_pool.size
 
         if kv_indptr_buf is None:
@@ -48,13 +50,15 @@ class TritonAttnBackend(AttentionBackend):
             self.kv_indptr = kv_indptr_buf
 
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
 
-        self.qo_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
-        self.mask_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
-        )
+        if not self.skip_prefill:
+            self.qo_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+            self.mask_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+            )
 
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
 
@@ -196,22 +200,29 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr,
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device=self.device
-        )
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.max_context_len),
-            dtype=torch.int32,
-            device=self.device,
-        )
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
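The reworked init_cuda_graph_state accepts an optional kv_indices_buf so a caller (here the multi-step draft backend, see the last hunks below) can share one preallocated buffer instead of each backend allocating its own. Everything is sized for the worst case (max_bs, max_context_len) because CUDA graph replay reruns the recorded kernels on fixed memory addresses: inputs must be copied into static tensors, never freshly allocated. A minimal capture/replay sketch of that constraint, independent of sglang:

import torch

# Standalone sketch (not sglang code): a captured CUDA graph replays the
# same kernels on the same addresses, so all inputs/outputs must live in
# preallocated "static" tensors that are only ever written in place.
if torch.cuda.is_available():
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    # Warm-up run before capture, as PyTorch recommends.
    static_out.copy_(static_in * 2)
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2)

    # Feed new data by overwriting the static buffer, never rebinding it.
    static_in.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
    g.replay()
    print(static_out)  # tensor([ 0.,  2.,  4.,  6.,  8., 10., 12., 14.], ...)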
@@ -224,31 +235,71 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         assert encoder_lens is None, "Not supported"
-        assert forward_mode.is_decode(), "Not supported"
-        assert spec_info is None, "Not supported"
 
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            seq_lens,
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+            attn_logits = self.cuda_graph_attn_logits
+            max_extend_len = None
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
+        elif forward_mode.is_target_verify():
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
+            )
 
         self.forward_metadata = (
-            self.cuda_graph_attn_logits,
-            None,
+            attn_logits,
+            max_extend_len,
             kv_indptr,
             kv_indices,
-            None,
-            None,
-            None,
+            qo_indptr,
+            custom_mask,
+            mask_indptr,
         )
 
     def init_forward_metadata_replay_cuda_graph(
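In the target-verify capture branch above, every request verifies exactly num_draft_tokens draft tokens, so qo_indptr is a plain arithmetic progression, while mask_indptr accumulates the size of each request's flattened custom attention mask (num_draft_tokens rows by seq_len + num_draft_tokens columns). A worked example with made-up numbers, standalone and CPU-only:

import torch

bs, num_draft_tokens = 3, 4
seq_lens = torch.tensor([7, 5, 9])

# Each request contributes exactly num_draft_tokens query tokens, so the
# query offsets form an arithmetic progression: 0, 4, 8, 12.
qo_indptr = torch.arange(
    0, (1 + bs) * num_draft_tokens, step=num_draft_tokens, dtype=torch.int32
)
print(qo_indptr)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)

# Request i's mask covers num_draft_tokens rows of (seq_len_i +
# num_draft_tokens) columns, flattened; mask_indptr is the running offset.
seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)
mask_indptr = torch.zeros(bs + 1, dtype=torch.int64)
mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)
print(mask_indptr)  # tensor([  0,  44,  80, 132])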
@@ -262,22 +313,57 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+        if forward_mode.is_decode_or_idle():
+            # Update kv_indptr, kv_indices
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
+                    seq_lens[:bs],
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+        elif forward_mode.is_target_verify():
+            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
+            )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
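The replay path mirrors capture but never rebinds the captured buffers: fresh metadata is copied into prefix slices of the max-sized tensors (for example, kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr), which keeps the addresses the graph recorded valid. A toy version of the pattern, with the spec-info tensor as a stand-in:

import torch

# Buffer captured at its worst-case size (here 16 entries for max_bs).
kv_indptr_buf = torch.zeros(16, dtype=torch.int32)

# At replay the live batch is usually smaller; the new values are copied
# into a prefix slice in place rather than replacing the tensor.
new_kv_indptr = torch.tensor([0, 3, 8, 10], dtype=torch.int32)  # stand-in
kv_indptr_buf[: new_kv_indptr.shape[0]] = new_kv_indptr

# The storage, and thus the address the CUDA graph recorded, is unchanged.
assert kv_indptr_buf[:4].data_ptr() == kv_indptr_buf.data_ptr()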
@@ -407,6 +493,7 @@ class TritonMultiStepDraftBackend:
             )
         )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -450,7 +537,7 @@ class TritonMultiStepDraftBackend:
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
         )
 
         def call_fn(i, forward_batch):
@@ -468,7 +555,7 @@ class TritonMultiStepDraftBackend:
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
        )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
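In TritonMultiStepDraftBackend the per-step kv-indices buffers now come from a single (speculative_num_steps, max_bs * max_context_len) allocation; presumably each row is what gets handed to the corresponding per-step backend through the new kv_indices_buf argument, so every draft step's CUDA graph addresses one shared preallocated pool. A small sketch of that slicing, with illustrative shapes rather than sglang's real configuration:

import torch

speculative_num_steps, max_bs, max_context_len = 3, 4, 8

cuda_graph_kv_indices = torch.zeros(
    (speculative_num_steps, max_bs * max_context_len), dtype=torch.int32
)

# Row i is a view, not a copy: handing it to step i's backend as its
# kv_indices_buf means all steps share the one underlying allocation.
per_step_bufs = [cuda_graph_kv_indices[i] for i in range(speculative_num_steps)]
assert per_step_bufs[0].data_ptr() == cuda_graph_kv_indices.data_ptr()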
test/srt/test_eagle_infer.py
@@ -216,8 +216,6 @@ class TestEAGLEServerTriton(TestEAGLEServer):
                 "0.7",
                 "--attention-backend",
                 "triton",
-                # TODO: Support cuda graph
-                "--disable-cuda-graph",
             ],
         )