sglang · Commits · 1b9175cb

Unverified commit 1b9175cb, authored Mar 27, 2025 by Stefan He, committed by GitHub on Mar 27, 2025.

[FA3 Attn Backend] Remove Unnecessary Device Sync for FA3 (#4745)

Co-authored-by: Yubo Wang <yubowang2019@gmail.com>

Parent: 92bb49a7
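The commit title refers to a host/device synchronization point: reading a scalar out of a CUDA tensor (for example via `.max().item()`) forces the CPU to wait until every kernel queued before it has finished. The diffs below avoid that by taking maxima and branch decisions from values the scheduler already keeps on the host. A minimal sketch of the idea, not taken from the repository (the tensor names are illustrative):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Sequence lengths kept both on the device (for kernels) and on the host.
seq_lens = torch.tensor([17, 5, 42], device=device)
seq_lens_cpu = seq_lens.cpu()  # maintained ahead of time by the scheduler

# Before: .max().item() on the device tensor stalls the host until the GPU
# has drained its queue, because the scalar has to be copied back.
max_seq_len_before = seq_lens.max().item()

# After: the same number read from the host-side copy, with no device sync.
max_seq_len_after = seq_lens_cpu.max().item()

assert max_seq_len_before == max_seq_len_after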
Showing 2 changed files with 17 additions and 10 deletions:

  python/sglang/srt/layers/attention/flashattention_backend.py   +16  -10
  python/sglang/srt/managers/schedule_batch.py                     +1   -0
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -29,11 +29,11 @@ class FlashAttentionMetadata:
     cu_seqlens_q: torch.Tensor = None
     cu_seqlens_k: torch.Tensor = None
-    max_seq_len_q: int = 0
     max_seq_len_k: int = 0
     window_size: tuple = (-1, -1)
     page_table: torch.Tensor = None
     cache_seqlens_int32: torch.Tensor = None
+    max_seq_len_q: int = 0

 class FlashAttentionBackend(AttentionBackend):
@@ -63,7 +63,6 @@ class FlashAttentionBackend(AttentionBackend):
         # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()
-        extend_seq_lens = forward_batch.extend_seq_lens
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
@@ -85,15 +84,16 @@ class FlashAttentionBackend(AttentionBackend):
                 0, batch_size + 1, dtype=torch.int32, device=device
             )
         else:
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
             # Precompute cumulative sequence lengths
-            if not extend_no_prefix:
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
-            metadata.max_seq_len_q = seqlens_in_batch.max().item()
+                metadata.max_seq_len_q = metadata.max_seq_len_k
         self.forward_metadata = metadata

     def forward_extend(
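In the rewritten prefill branch above, the control flow (`any(...)`) and the query-side maximum come from `extend_prefix_lens_cpu` / `extend_seq_lens_cpu`, which are host-side values, while the padded cumulative sum still runs on the GPU where the attention kernels need it. A standalone sketch of that branch, assuming a hypothetical `ForwardBatchStub` in place of sglang's real ForwardBatch:

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ForwardBatchStub:
    # Host-side mirrors (*_cpu) plus the device tensor the kernels consume.
    extend_seq_lens: torch.Tensor
    extend_seq_lens_cpu: List[int]
    extend_prefix_lens_cpu: List[int]


def query_metadata(fb: ForwardBatchStub, cu_seqlens_k: torch.Tensor, max_seq_len_k: int):
    """Mirrors the rewritten branch: host-side lists decide the branch and the
    maximum, so nothing is copied back from the device on this path."""
    if any(fb.extend_prefix_lens_cpu):
        cu_seqlens_q = torch.nn.functional.pad(
            torch.cumsum(fb.extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
        )
        max_seq_len_q = max(fb.extend_seq_lens_cpu)
    else:
        # No prefix to extend: the query layout matches the key/value layout.
        cu_seqlens_q = cu_seqlens_k
        max_seq_len_q = max_seq_len_k
    return cu_seqlens_q, max_seq_len_q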
@@ -274,20 +274,26 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        seqlens_in_batch = seq_lens[:bs]
         metadata = self.decode_cuda_graph_metadata[bs]
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+
+        # For CPU operations
+        max_len = seq_lens_cpu[:bs].max().item()
+        metadata.max_seq_len_k = max_len
+
+        # For GPU operations
+        seq_lens_in_batch = seq_lens[:bs]
+        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
         metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
         )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+
         # Only zero out the part out of max_len_k
         metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
         # Then do the copy
         metadata.page_table[:, : metadata.max_seq_len_k].copy_(
             self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
         )
         self.forward_decode_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
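The CUDA-graph replay path above now separates host work from device work: the only `.item()` call is on `seq_lens_cpu`, a tensor that already lives on the CPU, while the int32 cast and the padded cumsum stay on the GPU. The snippet below is a rough, self-contained way to see why the old read pattern stalls; it is not from the repository, and the timings it prints are illustrative only:

import time

import torch

if torch.cuda.is_available():
    seq_lens = torch.randint(1, 4096, (256,), device="cuda")
    seq_lens_cpu = seq_lens.cpu()  # taken before the device gets busy

    # Queue enough work that the device is still busy when the host reads.
    x = torch.randn(4096, 4096, device="cuda")
    for _ in range(10):
        x = x @ x

    t0 = time.perf_counter()
    _ = seq_lens.max().item()      # old pattern: waits for the matmuls above
    t1 = time.perf_counter()
    _ = seq_lens_cpu.max().item()  # new pattern: host-only, returns immediately
    t2 = time.perf_counter()

    print(f"device read: {t1 - t0:.4f}s, host read: {t2 - t1:.4f}s")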
python/sglang/srt/managers/schedule_batch.py

@@ -1376,6 +1376,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             global_server_args_dict["enable_flashinfer_mla"]
             or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             decode_seq_lens = self.seq_lens.cpu()
         else:
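The one-line change in schedule_batch.py is the other half of the same idea: with `attention_backend == "fa3"`, the batch now also precomputes `decode_seq_lens` on the CPU, presumably feeding the host-side sequence lengths the FA3 backend reads above. A minimal sketch of the gate, with a hypothetical stand-in for `global_server_args_dict`:

import torch

# Hypothetical stand-in for sglang's global_server_args_dict.
server_args = {
    "enable_flashinfer_mla": False,
    "enable_flashmla": False,
    "attention_backend": "fa3",
}

seq_lens = torch.tensor([12, 7, 31])  # lives on the GPU in the real server

if (
    server_args["enable_flashinfer_mla"]
    or server_args["enable_flashmla"]
    or server_args["attention_backend"] == "fa3"
):
    # Backends that read lengths on the host get a CPU copy up front.
    decode_seq_lens = seq_lens.cpu()
else:
    decode_seq_lens = seq_lens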