change / sglang / Commits / 79961afa

Commit 79961afa (unverified)
Authored May 07, 2025 by Minglei Zhu; committed by GitHub on May 07, 2025

optimize pad operations in fa3 to accelerate 100+ us (#6077)
Parent: cfca4e0e

Showing 1 changed file with 17 additions and 39 deletions.

python/sglang/srt/layers/attention/flashattention_backend.py (+17, -39)
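Every hunk in this commit applies the same transformation: instead of materializing a padded temporary with torch.nn.functional.pad(torch.cumsum(...), (1, 0)) and overwriting the whole cu_seqlens_k buffer, the new code writes the cumulative sum directly into the tail slice cu_seqlens_k[1:], leaving the leading zero in place. Below is a minimal standalone sketch of the before/after pattern; the buffer names and sizes are invented for illustration, this is not the backend code itself.

    import torch

    # Hypothetical stand-ins for the backend's preallocated metadata buffers.
    batch_size = 4
    cache_seqlens_int32 = torch.tensor([5, 3, 7, 2], dtype=torch.int32)
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32)

    # Old pattern: pad materializes a new (batch_size + 1)-element tensor
    # every step, and the whole buffer is then overwritten.
    padded = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
    )
    cu_seqlens_k.copy_(padded)

    # New pattern: write the cumsum straight into the tail slice. Element 0
    # is already 0, so no padded temporary is needed.
    cu_seqlens_k[1:].copy_(
        torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32)
    )

    assert torch.equal(cu_seqlens_k, padded)  # same result, one less allocation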
@@ -1525,12 +1525,9 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                 self.speculative_step_id + 1
             )
-            metadata.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-            )
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(
+                    metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                )
+            )
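The sliced copy is safe only because element 0 of cu_seqlens_k never needs refreshing: in a CUDA-graph backend these metadata buffers are allocated once at capture time and reused on every replay, and a cumulative-length buffer starts (and stays) with a leading zero. A sketch of that invariant, with hypothetical names and sizes; the real allocation lives in the backend's capture path, which this diff does not touch.

    import torch

    # Hypothetical capture-time allocation: the buffer is created zero-filled,
    # so index 0 already holds the leading 0 that cu_seqlens layouts require.
    max_bs = 8
    cu_seqlens_k = torch.zeros(max_bs + 1, dtype=torch.int32)

    # Replay-time updates touch only cu_seqlens_k[1:], so element 0 stays 0
    # across every replay.
    for step_lens in ([4, 4, 2, 1], [3, 5, 2, 2]):
        cache_seqlens = torch.tensor(
            step_lens + [0] * (max_bs - len(step_lens)), dtype=torch.int32
        )
        cu_seqlens_k[1:].copy_(
            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
        )
        assert cu_seqlens_k[0] == 0  # the invariant the sliced copy relies on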
@@ -1554,12 +1551,9 @@ class FlashAttentionBackend(AttentionBackend):
             # metadata.max_seq_len_q = self.topk, already set in capture
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             # metadata.cu_seqlens_q already set in capture
-            metadata.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-            )
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(
+                    metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                )
+            )
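The "100+ us" in the commit title is plausible because each replaced call site paid for an extra tensor allocation and kernel launch per replay step: F.pad builds a brand-new output tensor just to prepend one zero. To gauge the saving yourself, a rough micro-benchmark along these lines can be used; it assumes a CUDA device, the numbers vary by hardware, and it is not taken from the PR.

    import time

    import torch

    # Hypothetical shapes: 256 requests, lengths up to 4096.
    cache_seqlens = torch.randint(1, 4096, (256,), dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(257, dtype=torch.int32, device="cuda")

    def timed_us(fn, iters=1000):
        # Average wall-clock time per call, in microseconds.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters * 1e6

    # Old pattern: cumsum, pad into a fresh tensor, then a full-buffer copy.
    pad_us = timed_us(
        lambda: cu_seqlens.copy_(
            torch.nn.functional.pad(
                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
            )
        )
    )

    # New pattern: cumsum written straight into the tail slice.
    slice_us = timed_us(
        lambda: cu_seqlens[1:].copy_(
            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
        )
    )

    print(f"pad+copy: {pad_us:.1f} us/call, sliced copy: {slice_us:.1f} us/call")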
@@ -1616,13 +1610,8 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = (
                 seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
             )
-            metadata.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-            )
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
             max_seq_pages = (
                 metadata.max_seq_len_k + self.page_size - 1
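The unchanged context at the end of this hunk computes max_seq_pages with the standard integer ceiling-division idiom: adding page_size - 1 before floor division rounds a token count up to a whole number of KV-cache pages. A quick illustration with made-up values:

    # Ceiling division: number of fixed-size pages needed for a token count.
    # page_size and the lengths below are arbitrary example values.
    page_size = 16
    for max_seq_len_k in (1, 16, 17, 100):
        max_seq_pages = (max_seq_len_k + page_size - 1) // page_size
        print(f"{max_seq_len_k} tokens -> {max_seq_pages} page(s)")
    # Prints: 1 -> 1, 16 -> 1, 17 -> 2, 100 -> 7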
@@ -1641,13 +1630,8 @@ class FlashAttentionBackend(AttentionBackend):
             # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             # metadata.cu_seqlens_q already set in capture
-            metadata.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-            )
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
             page_table = self.req_to_token[
                 req_pool_indices, : metadata.max_seq_len_k
@@ -1705,14 +1689,11 @@ class FlashAttentionBackend(AttentionBackend):
             metadata_expand.cache_seqlens_int32.copy_(
                 mask.sum(dim=1).to(torch.int32)
             )
-            metadata_expand.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata_expand.cache_seqlens_int32,
-                        dim=0,
-                        dtype=torch.int32,
-                    ),
-                    (1, 0),
-                )
-            )
+            metadata_expand.cu_seqlens_k[1:].copy_(
+                torch.cumsum(
+                    metadata_expand.cache_seqlens_int32,
+                    dim=0,
+                    dtype=torch.int32,
+                )
+            )
             metadata_expand.max_seq_len_k = (
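This hunk differs slightly from the others: the per-row key lengths for metadata_expand come from a boolean mask (mask.sum(dim=1)) rather than from seq_lens, and the same sliced-cumsum trick then turns them into cumulative offsets. A small self-contained sketch of that flow; the mask shape and values are invented.

    import torch

    # Hypothetical mask: one row per expanded query, True where a key
    # position is valid for that row.
    mask = torch.tensor(
        [
            [True, True, False, False],
            [True, True, True, False],
            [True, False, False, False],
        ]
    )

    cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)  # tensor([2, 3, 1])

    cu_seqlens_k = torch.zeros(mask.shape[0] + 1, dtype=torch.int32)
    cu_seqlens_k[1:].copy_(
        torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32)
    )
    # cu_seqlens_k == [0, 2, 5, 6]: row i's keys occupy offsets
    # cu_seqlens_k[i] : cu_seqlens_k[i + 1] in the packed layout.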
@@ -1723,11 +1704,8 @@ class FlashAttentionBackend(AttentionBackend):
             # Only support encoder size 1 for now
             metadata.encoder_max_seq_len_k = encoder_lens[0]
             metadata.encoder_lens_int32.copy_(encoder_lens[:1])
-            metadata.encoder_cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.encoder_lens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-            )
+            metadata.encoder_cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
+            )
             metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(