Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
384d85ba
Unverified
Commit
384d85ba
authored
Oct 24, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 24, 2024
Browse files
Re-introduce `get_cuda_graph_seq_len_fill_value` (#1783)
parent
60597219
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
19 additions
and
2 deletions
+19
-2
python/sglang/srt/layers/attention/__init__.py
python/sglang/srt/layers/attention/__init__.py
+4
-0
python/sglang/srt/layers/attention/double_sparsity_backend.py
...on/sglang/srt/layers/attention/double_sparsity_backend.py
+3
-0
python/sglang/srt/layers/attention/flashinfer_backend.py
python/sglang/srt/layers/attention/flashinfer_backend.py
+3
-0
python/sglang/srt/layers/attention/triton_backend.py
python/sglang/srt/layers/attention/triton_backend.py
+3
-0
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+6
-2
No files found.
python/sglang/srt/layers/attention/__init__.py
View file @
384d85ba
...
@@ -41,6 +41,10 @@ class AttentionBackend(ABC):
...
@@ -41,6 +41,10 @@ class AttentionBackend(ABC):
"""Init the metadata for a forward pass for replying a cuda graph."""
"""Init the metadata for a forward pass for replying a cuda graph."""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
get_cuda_graph_seq_len_fill_value
(
self
):
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
raise
NotImplementedError
()
def
forward
(
def
forward
(
self
,
self
,
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
...
...
python/sglang/srt/layers/attention/double_sparsity_backend.py
View file @
384d85ba
...
@@ -161,6 +161,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
...
@@ -161,6 +161,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
self
.
cuda_graph_start_loc
.
zero_
()
self
.
cuda_graph_start_loc
.
zero_
()
self
.
cuda_graph_start_loc
[
1
:
bs
]
=
torch
.
cumsum
(
seq_lens
[:
bs
-
1
],
dim
=
0
)
self
.
cuda_graph_start_loc
[
1
:
bs
]
=
torch
.
cumsum
(
seq_lens
[:
bs
-
1
],
dim
=
0
)
def
get_cuda_graph_seq_len_fill_value
(
self
):
return
1
def
forward_extend
(
def
forward_extend
(
self
,
q
,
k
,
v
,
layer
:
RadixAttention
,
forward_batch
:
ForwardBatch
self
,
q
,
k
,
v
,
layer
:
RadixAttention
,
forward_batch
:
ForwardBatch
):
):
...
...
python/sglang/srt/layers/attention/flashinfer_backend.py
View file @
384d85ba
...
@@ -210,6 +210,9 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -210,6 +210,9 @@ class FlashInferAttnBackend(AttentionBackend):
encoder_lens
=
encoder_lens
[:
bs
]
if
encoder_lens
is
not
None
else
None
,
encoder_lens
=
encoder_lens
[:
bs
]
if
encoder_lens
is
not
None
else
None
,
)
)
def
get_cuda_graph_seq_len_fill_value
(
self
):
return
0
def
forward_extend
(
def
forward_extend
(
self
,
q
,
k
,
v
,
layer
:
RadixAttention
,
forward_batch
:
ForwardBatch
self
,
q
,
k
,
v
,
layer
:
RadixAttention
,
forward_batch
:
ForwardBatch
):
):
...
...
python/sglang/srt/layers/attention/triton_backend.py
View file @
384d85ba
...
@@ -108,6 +108,9 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -108,6 +108,9 @@ class TritonAttnBackend(AttentionBackend):
self
.
cuda_graph_start_loc
.
zero_
()
self
.
cuda_graph_start_loc
.
zero_
()
self
.
cuda_graph_start_loc
[
1
:
bs
]
=
torch
.
cumsum
(
seq_lens
[:
bs
-
1
],
dim
=
0
)
self
.
cuda_graph_start_loc
[
1
:
bs
]
=
torch
.
cumsum
(
seq_lens
[:
bs
-
1
],
dim
=
0
)
def
get_cuda_graph_seq_len_fill_value
(
self
):
return
1
def
forward_extend
(
def
forward_extend
(
self
,
q
,
k
,
v
,
layer
:
RadixAttention
,
forward_batch
:
ForwardBatch
self
,
q
,
k
,
v
,
layer
:
RadixAttention
,
forward_batch
:
ForwardBatch
):
):
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
384d85ba
...
@@ -134,7 +134,11 @@ class CudaGraphRunner:
...
@@ -134,7 +134,11 @@ class CudaGraphRunner:
self
.
max_bs
=
max
(
self
.
capture_bs
)
self
.
max_bs
=
max
(
self
.
capture_bs
)
self
.
model_runner
.
attn_backend
.
init_cuda_graph_state
(
self
.
max_bs
)
self
.
model_runner
.
attn_backend
.
init_cuda_graph_state
(
self
.
max_bs
)
self
.
seq_len_fill_value
=
1
self
.
seq_len_fill_value
=
(
self
.
model_runner
.
attn_backend
.
get_cuda_graph_seq_len_fill_value
()
)
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self
.
encoder_len_fill_value
=
0
self
.
encoder_len_fill_value
=
0
if
self
.
use_torch_compile
:
if
self
.
use_torch_compile
:
...
@@ -287,7 +291,7 @@ class CudaGraphRunner:
...
@@ -287,7 +291,7 @@ class CudaGraphRunner:
index
=
bisect
.
bisect_left
(
self
.
capture_bs
,
raw_bs
)
index
=
bisect
.
bisect_left
(
self
.
capture_bs
,
raw_bs
)
bs
=
self
.
capture_bs
[
index
]
bs
=
self
.
capture_bs
[
index
]
if
bs
!=
raw_bs
:
if
bs
!=
raw_bs
:
self
.
seq_lens
.
fill_
(
self
.
seq_len_fill_value
)
self
.
seq_lens
.
fill_
(
1
)
self
.
out_cache_loc
.
zero_
()
self
.
out_cache_loc
.
zero_
()
# Common inputs
# Common inputs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment