change / sglang / Commits / e62d60fe

Commit e62d60fe (unverified)
Authored Mar 30, 2025 by Baizhou Zhang; committed via GitHub on Mar 30, 2025
Parent: 032f8faa

[Fix] avoid stream sync and torch compile in prefill for fa3 backend (#4932)

Showing 7 changed files with 30 additions and 35 deletions (+30 −35)
python/sglang/srt/layers/attention/flashattention_backend.py    +1  −1
python/sglang/srt/layers/attention/flashinfer_mla_backend.py    +1  −1
python/sglang/srt/layers/attention/flashmla_backend.py          +1  −1
python/sglang/srt/managers/schedule_batch.py                    +12 −13
python/sglang/srt/model_executor/cuda_graph_runner.py           +2  −2
python/sglang/srt/model_executor/forward_batch_info.py          +8  −12
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py  +5  −5
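In short: the batch previously carried decode-only `decode_seq_lens` / `decode_seq_lens_cpu` copies of the sequence lengths, so the fa3 prefill path had to call `.max().item()` on a GPU tensor, and decode positions went through a `@torch.compile`-wrapped helper. This commit replaces the decode-only fields with a general `seq_lens_cpu` and inlines the clamp. The cost being avoided: reading a scalar out of a CUDA tensor forces the host to wait for the stream. A minimal sketch of the effect (illustrative only, not from the patch; requires a CUDA device):

```python
import torch

seq_lens = torch.randint(1, 4096, (64,), device="cuda")

# Implicit stream sync: .item() must wait for all queued GPU work
# before the value can be read on the host.
max_len = seq_lens.max().item()

# With a CPU copy made once when the batch is built, every later
# read of the lengths is pure host-side work.
seq_lens_cpu = seq_lens.cpu()        # one explicit D2H transfer
max_len = seq_lens_cpu.max().item()  # no further GPU involvement
```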
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -79,7 +79,7 @@ class FlashAttentionBackend(AttentionBackend):
             torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
         )
         # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+        metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
         # Precompute page table
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
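For context, `req_to_token` maps each request slot to the KV-cache token indices it owns; slicing it to the batch's `max_seq_len_k` yields the page table the attention kernel reads. A toy illustration of that indexing (shapes and values are made up for the example):

```python
import torch

num_slots, max_context = 8, 16
req_to_token = torch.arange(num_slots * max_context).reshape(num_slots, max_context)

req_pool_indices = torch.tensor([3, 5])    # two requests in this batch
seq_lens_cpu = torch.tensor([4, 7])
max_seq_len_k = seq_lens_cpu.max().item()  # host-side read, no GPU sync

# Page table: one row of token indices per request, truncated to the batch max.
page_table = req_to_token[req_pool_indices, :max_seq_len_k]
print(page_table.shape)  # torch.Size([2, 7])
```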
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -797,7 +797,7 @@ class FlashInferMLAMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
-                seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
             )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
python/sglang/srt/layers/attention/flashmla_backend.py

@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 max_seqlen_pad = triton.cdiv(
-                    forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
                 )
                 block_kv_indices = torch.full(
                     (bs, max_seqlen_pad),
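FlashMLA addresses the KV cache in fixed-size pages, so the block table must be as wide as the longest sequence rounded up to whole pages; with the lengths already on the CPU, that max is read without touching the GPU stream. A standalone sketch of the sizing (the PAGE_SIZE value is assumed; plain ceiling division stands in for `triton.cdiv`):

```python
import torch

PAGE_SIZE = 64  # tokens per KV page (assumed for the example)
seq_lens_cpu = torch.tensor([100, 257, 33])
bs = seq_lens_cpu.numel()

# Equivalent of triton.cdiv: ceiling division on plain Python ints.
max_seqlen_pad = -(-seq_lens_cpu.max().item() // PAGE_SIZE)  # ceil(257 / 64) = 5

# Block table placeholder, padded to the widest request; -1 marks unused slots.
block_kv_indices = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)
print(block_kv_indices.shape)  # torch.Size([3, 5])
```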
python/sglang/srt/managers/schedule_batch.py

@@ -1398,21 +1398,22 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if (
-                global_server_args_dict["enable_flashinfer_mla"]
-                or global_server_args_dict["enable_flashmla"]
-                or global_server_args_dict["attention_backend"] == "fa3"
-            ):
-                decode_seq_lens = self.seq_lens.cpu()
-            else:
-                decode_seq_lens = None
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
-            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
+        # Create seq_lens_cpu when needed
+        if (
+            global_server_args_dict["enable_flashinfer_mla"]
+            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
+        ):
+            seq_lens_cpu = self.seq_lens.cpu()
+        else:
+            seq_lens_cpu = None
+
         if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]

@@ -1435,7 +1436,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
-            decode_seq_lens=decode_seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,

@@ -1496,6 +1497,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
+    seq_lens_cpu: Optional[torch.Tensor]
     # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor

@@ -1512,9 +1514,6 @@ class ModelWorkerBatch:
     global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool
 
-    # For decode
-    decode_seq_lens: Optional[torch.Tensor]
-
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
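The net effect in `get_model_worker_batch`: the CPU copy of the lengths is no longer decode-only, so prefill batches for the fa3, FlashMLA, and FlashInfer-MLA backends also receive one, and the copy happens exactly once per batch. A self-contained sketch of the gating (the real code checks three `global_server_args_dict` flags; the single string argument here is a simplification):

```python
import torch

def make_seq_lens_cpu(seq_lens: torch.Tensor, attention_backend: str):
    # Copy lengths to the host only for backends that consume them there.
    if attention_backend in ("fa3", "flashmla", "flashinfer_mla"):
        return seq_lens.cpu()  # one D2H transfer per batch, prefill or decode
    return None

seq_lens = torch.tensor([12, 40, 7])  # stand-in for the batch's GPU tensor
seq_lens_cpu = make_seq_lens_cpu(seq_lens, "fa3")
assert seq_lens_cpu is not None
print(seq_lens_cpu.max().item())  # 40, read without a stream sync
```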
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -491,10 +491,10 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        if forward_batch.decode_seq_lens_cpu is not None:
+        if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
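CUDA graphs replay at a fixed captured batch size, so when the live batch is smaller, the runner pads the persistent host-side buffer: fill everything with a harmless length of 1, then overwrite the first `raw_bs` entries with the real values. A standalone sketch of that padding:

```python
import torch

captured_bs, raw_bs = 8, 3  # graph batch size vs. live batch size
seq_lens_cpu_buf = torch.ones(captured_bs, dtype=torch.int32)

live_seq_lens_cpu = torch.tensor([17, 5, 92], dtype=torch.int32)

if raw_bs != captured_bs:
    seq_lens_cpu_buf.fill_(1)  # padding rows act as length-1 dummies
seq_lens_cpu_buf[:raw_bs].copy_(live_seq_lens_cpu)
print(seq_lens_cpu_buf.tolist())  # [17, 5, 92, 1, 1, 1, 1, 1]
```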
python/sglang/srt/model_executor/forward_batch_info.py

@@ -39,7 +39,6 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

@@ -148,6 +147,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
+
+    # Optional seq_lens on cpu
+    seq_lens_cpu: Optional[torch.Tensor] = None
 
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None

@@ -162,9 +164,6 @@ class ForwardBatch:
     # Position information
     positions: torch.Tensor = None
 
-    # For decode
-    decode_seq_lens_cpu: Optional[torch.Tensor] = None
-
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None

@@ -293,12 +292,14 @@ class ForwardBatch:
         ):
             ret.positions = ret.spec_info.positions
 
+        # Get seq_lens_cpu if needed
+        if ret.seq_lens_cpu is None:
+            ret.seq_lens_cpu = batch.seq_lens_cpu
+
         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
-                ret.positions = clamp_position(batch.seq_lens)
-            if ret.decode_seq_lens_cpu is None:
-                ret.decode_seq_lens_cpu = batch.decode_seq_lens
+                ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32

@@ -518,8 +519,3 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
-
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
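The deleted `clamp_position` helper computed each decoding request's next-token position, `seq_len - 1`, clamped at 0 so zero-length entries stay valid; wrapping a single clamp in `@torch.compile` bought little and could trigger compilation from the prefill path. The inlined computation is identical:

```python
import torch

seq_lens = torch.tensor([5, 1, 0])  # the 0-length entry exercises the clamp

# Identical to the removed clamp_position(seq_lens), now written inline.
positions = torch.clamp(seq_lens - 1, min=0).to(torch.int64)
print(positions.tolist())  # [4, 0, 0]
```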
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -214,10 +214,10 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:num_tokens]
 
         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+        if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
             self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
-            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs

@@ -233,7 +233,7 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:raw_num_token]
         forward_batch.seq_lens = self.seq_lens[:raw_bs]
         forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
-        if forward_batch.decode_seq_lens_cpu is not None:
-            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+        if forward_batch.seq_lens_cpu is not None:
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
 
         return out