sglang · Commits · 93470a14

Commit 93470a14 (Unverified)
Authored Apr 07, 2025 by Stefan He, committed via GitHub on Apr 07, 2025

Refactor and Optimize FA3 Code (#5090)

Co-authored-by: Qingquan Song <ustcsqq@gmail.com>

Parent: db452760
Changes: 1

Showing 1 changed file with 94 additions and 142 deletions.

python/sglang/srt/layers/attention/flashattention_backend.py  (+94, -142)
 from __future__ import annotations

-import numpy as np
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-
 """
 Support different attention backends.
 Now there are three backends: FlashInfer, Triton and FlashAttention.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

+import numpy as np
 import torch

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
...
@@ -30,22 +22,25 @@ from sgl_kernel.flash_attn import flash_attn_with_kvcache
 @dataclass
 class FlashAttentionMetadata:
     """Metadata to be init once in the model forward pass,
     each layer's forward pass can reuse the metadata.
-    """

-    # Cumulative sequence lengths for query
-    cu_seqlens_q: torch.Tensor = None
-    # Cumulative sequence lengths for key
-    cu_seqlens_k: torch.Tensor = None
+    For each init metadata function, we will try set up them in below order
+    """
+
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
     max_seq_len_q: int = 0
     # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor = None
     # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
-    # Sequence lengths for the forward batch
-    cache_seqlens_int32: torch.Tensor = None


 @dataclass
 class LocalAttentionMetadata:
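A note on the cu_ fields reordered above: cu_seqlens_q and cu_seqlens_k follow the FlashAttention varlen convention, where entry i is the total number of query/key tokens in sequences 0..i-1, so entry 0 is always 0 and the last entry is the batch total. A minimal sketch of how the backend builds them (values are illustrative, not from this commit):

import torch

# Hypothetical per-request KV lengths for a batch of 3.
cache_seqlens = torch.tensor([5, 2, 7], dtype=torch.int32)

# Exclusive prefix sum via cumsum + left-pad with one zero.
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens_k.tolist())  # [0, 5, 7, 14]

# Plain decode has one query token per sequence, so cu_seqlens_q
# is simply 0..batch_size.
cu_seqlens_q = torch.arange(0, 3 + 1, dtype=torch.int32)
print(cu_seqlens_q.tolist())  # [0, 1, 2, 3]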
...
@@ -270,9 +265,9 @@ class FlashAttentionBackend(AttentionBackend):
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
-        step_id=0,
     ):
         super().__init__()
...
@@ -293,14 +288,12 @@ class FlashAttentionBackend(AttentionBackend):
         ) and (not global_server_args_dict["disable_mla"])

         self.skip_prefill = skip_prefill
-        self.topk = topk
-        self.step_id = step_id
+
+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            topk <= 1
+        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
+        self.topk = 1
         self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id

         # Local attention settings
         self.attention_chunk_size = (
...
@@ -310,71 +303,59 @@ class FlashAttentionBackend(AttentionBackend):
         )
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        """Initialize forward metadata to cache repetitive calculations."""
+        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
         if forward_batch.forward_mode.is_decode():
-            # Skip Prefill or Draft Decode
+            # Draft Decode
+            # Note: Draft Decode will be ran on the Draft Worker
             if forward_batch.spec_info is not None:
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
-                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.step_id + 1
-                )
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                cache_loc = forward_batch.out_cache_loc.view(
-                    self.speculative_num_steps, -1
-                ).T
-                for idx, single_seq_len in enumerate(seq_lens_with_decode):
-                    real_bsz_start_idx = idx
-                    real_bsz_end_idx = idx + 1
-                    metadata.page_table[
-                        real_bsz_start_idx:real_bsz_end_idx,
-                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
-                    ] = cache_loc[
-                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
-                    ]
             else:
-                # Normal Decode without Spec Decoding
+                # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
-                )
         elif forward_batch.forward_mode.is_target_verify():
-            draft_token_num = forward_batch.spec_info.draft_token_num
+            # Note: Target Verify will be ran on the Target Worker
             metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens + draft_token_num
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
             ).to(torch.int32)
-            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
             metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item() + draft_token_num
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
             )
             metadata.cu_seqlens_q = torch.arange(
                 0,
-                batch_size * draft_token_num + 1,
-                draft_token_num,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
                 dtype=torch.int32,
                 device=device,
             )
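Why a plain torch.arange is a valid cu_seqlens_q in the target-verify branch: every request verifies exactly speculative_num_draft_tokens query tokens, so the cumulative lengths form an arithmetic progression. A small check with hypothetical sizes (not from this commit):

import torch

batch_size, draft_tokens = 4, 3  # hypothetical values

cu_seqlens_q = torch.arange(
    0, batch_size * draft_tokens + 1, draft_tokens, dtype=torch.int32
)
print(cu_seqlens_q.tolist())  # [0, 3, 6, 9, 12]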
...
@@ -387,31 +368,27 @@ class FlashAttentionBackend(AttentionBackend):
             ]
         elif forward_batch.forward_mode.is_extend_or_draft_extend():
+            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
-            # Precompute maximum sequence length
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-            # Precompute page table
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
-            # Precompute cumulative sequence lengths
+
             if (
                 any(forward_batch.extend_prefix_lens_cpu)
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
             ):
                 extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
-                metadata.cu_seqlens_q = metadata.cu_seqlens_k

         # Setup local attention if enabled
         if (
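In the extend path the query side covers only the new (uncached) tokens while the key side covers prefix plus new tokens, which is why cu_seqlens_q is built from extend_seq_lens rather than the full sequence lengths. An illustration with made-up lengths (not from this commit):

import torch
import torch.nn.functional as F

seq_lens = torch.tensor([10, 6], dtype=torch.int32)        # prefix + new tokens in KV
extend_seq_lens = torch.tensor([4, 6], dtype=torch.int32)  # new tokens only

cu_seqlens_k = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_q = F.pad(torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0))

print(cu_seqlens_k.tolist())  # [0, 10, 16]
print(cu_seqlens_q.tolist())  # [0, 4, 10]
# With no cached prefix anywhere, the two coincide, which is why the
# else-branch above simply reuses cu_seqlens_k for cu_seqlens_q.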
...
@@ -458,7 +435,7 @@ class FlashAttentionBackend(AttentionBackend):
             )
             metadata.local_attn_metadata = local_metadata

-        # Precompute strided indices
+        # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
                 0, metadata.page_table.shape[1], self.page_size, device=self.device
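The strided indices select one entry per page from the token-granular req_to_token mapping; dividing the selected token slots by page_size then yields page numbers for the FA3 page table. A sketch, assuming (as this backend does) that a page is page_size consecutive token slots; the values are hypothetical:

import torch

page_size = 4
req_to_token = torch.arange(100, 112)  # hypothetical mapping, 12 token slots

strided_indices = torch.arange(0, req_to_token.shape[0], page_size)
page_table_row = req_to_token[strided_indices] // page_size
print(page_table_row.tolist())  # [25, 26, 27]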
...
@@ -498,7 +475,7 @@ class FlashAttentionBackend(AttentionBackend):
                 v,
             )

-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata

         # Calculate window size (can be moved to metadata if layer properties don't change)
...
@@ -606,8 +583,6 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention using precomputed metadata."""
-        # Save KV cache if needed
         if k is not None:
             assert v is not None
             if save_kv_cache:
...
@@ -628,7 +603,7 @@ class FlashAttentionBackend(AttentionBackend):
                 v,
             )

-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata

         # Calculate window size (can be moved to metadata if layer properties don't change)
...
@@ -639,12 +614,9 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )

-        page_table = metadata.page_table
-
         if not self.use_mla:
             # Do multi-head attention
-            # Get KV cache
             kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
             key_cache, value_cache = kv_cache[0], kv_cache[1]
             key_cache = key_cache.view(
...
@@ -654,13 +626,12 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
-            # Pre-reshape query tensor
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
                 k_cache=key_cache,
                 v_cache=value_cache,
-                page_table=page_table,
+                page_table=metadata.page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
...
@@ -696,7 +667,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_cache=k_rope_cache,
                 v_cache=c_kv_cache,
                 qv=q_nope,
-                page_table=page_table,
+                page_table=metadata.page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
...
@@ -719,7 +690,13 @@ class FlashAttentionBackend(AttentionBackend):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata = {
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
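Background for this hunk: a captured CUDA graph replays fixed kernels on fixed storage, so no tensor may be allocated at replay time. The backend therefore preallocates maximum-size buffers once and later refreshes slices of them in place. The pattern, reduced to its core (sizes hypothetical):

import torch

max_bs = 8
cache_seqlens_buf = torch.zeros(max_bs, dtype=torch.int32)  # allocated once

bs = 3  # actual batch at replay time
cache_seqlens_buf[:bs].copy_(torch.tensor([5, 2, 7], dtype=torch.int32))
print(cache_seqlens_buf.tolist())  # [5, 2, 7, 0, 0, 0, 0, 0]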
...
@@ -735,30 +712,22 @@ class FlashAttentionBackend(AttentionBackend):
"strided_indices"
:
torch
.
arange
(
"strided_indices"
:
torch
.
arange
(
0
,
self
.
max_context_len
,
self
.
page_size
,
device
=
self
.
device
0
,
self
.
max_context_len
,
self
.
page_size
,
device
=
self
.
device
),
),
}
self
.
target_verify_metadata
=
{
"cache_seqlens"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"cache_seqlens"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"cu_seqlens_q"
:
torch
.
arange
(
"cu_seqlens_q"
:
torch
.
zeros
(
0
,
max_bs
+
1
28
,
dtype
=
torch
.
int32
,
device
=
self
.
device
max_bs
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
),
"cu_seqlens_k"
:
torch
.
zeros
(
"cu_seqlens_k"
:
torch
.
zeros
(
max_bs
+
1
28
,
dtype
=
torch
.
int32
,
device
=
self
.
device
max_bs
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
),
}
self
.
target_verify_metadata
=
{
"page_table"
:
torch
.
zeros
(
"page_table"
:
torch
.
zeros
(
max_bs
,
max_bs
,
(
self
.
max_context_len
+
self
.
page_size
-
1
)
//
self
.
page_size
,
(
self
.
max_context_len
+
self
.
page_size
-
1
)
//
self
.
page_size
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
device
=
self
.
device
,
),
),
"cache_seqlens"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"cu_seqlens_q"
:
torch
.
zeros
(
max_bs
+
128
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"cu_seqlens_k"
:
torch
.
zeros
(
max_bs
+
128
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"max_seqlen_q"
:
0
,
"strided_indices"
:
torch
.
arange
(
"strided_indices"
:
torch
.
arange
(
0
,
self
.
max_context_len
,
self
.
page_size
,
device
=
self
.
device
0
,
self
.
max_context_len
,
self
.
page_size
,
device
=
self
.
device
),
),
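On the max_bs + 128 to max_bs + 1 change: a cu_seqlens tensor for bs sequences has exactly bs + 1 entries (a leading zero plus one running total per sequence), so max_bs + 1 is the tight upper bound; the previous + 128 appears to have been conservative padding. A quick check (hypothetical lengths):

import torch
import torch.nn.functional as F

lens = torch.tensor([4, 4, 4], dtype=torch.int32)  # bs = 3
cu = F.pad(torch.cumsum(lens, dim=0, dtype=torch.int32), (1, 0))
assert cu.numel() == lens.numel() + 1  # bs + 1 entries suffice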
...
@@ -780,24 +749,21 @@ class FlashAttentionBackend(AttentionBackend):
         if forward_mode.is_decode():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cu_seqlens_q = torch.arange(
-                    0, bs + 1, dtype=torch.int32, device=device
-                )
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                     "cache_seqlens"
                 ][:bs]
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                     : bs + 1
                 ]
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
                 ][req_pool_indices, :]
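The draft-decode capture path now slices the preallocated arange buffer instead of materializing a fresh arange each capture: during decode every sequence contributes one query token, so cu_seqlens_q is always 0..bs regardless of batch content. A sketch (sizes hypothetical):

import torch

max_bs = 8
cu_seqlens_q_buf = torch.arange(0, max_bs + 1, dtype=torch.int32)  # built once

bs = 3
cu_seqlens_q = cu_seqlens_q_buf[: bs + 1]  # a view; no allocation at capture
print(cu_seqlens_q.tolist())  # [0, 1, 2, 3]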
...
@@ -822,37 +788,30 @@ class FlashAttentionBackend(AttentionBackend):
             )
             self.decode_cuda_graph_metadata[bs] = metadata
         elif forward_mode.is_target_verify():
-            draft_token_num = spec_info.draft_token_num
             metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                 :bs
             ]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + draft_token_num).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )
-            metadata.max_seq_len_q = draft_token_num
-            metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
-            metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
-                : bs * draft_token_num + 1
-            ]
-            metadata.cu_seqlens_q.copy_(
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            )
-            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
-            cu_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-            )
-            metadata.cu_seqlens_k = cu_k
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
             metadata.page_table = self.target_verify_metadata["page_table"][
                 req_pool_indices, :
             ]
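A review note on the capture-time split above: values that are identical on every replay of the graph (cu_seqlens_q and max_seq_len_q, since the draft token count is constant) can be created once at capture, while per-batch values (cache_seqlens, cu_seqlens_k, page_table) must come from preallocated buffers that replay overwrites in place. Reduced to a sketch with hypothetical values:

import torch

bs, draft_tokens = 4, 3
# Static under the graph: identical on every replay.
cu_seqlens_q = torch.arange(0, bs * draft_tokens + 1, draft_tokens, dtype=torch.int32)

# Dynamic: same storage every replay, values refreshed in place.
cache_seqlens = torch.zeros(bs, dtype=torch.int32)
cache_seqlens.copy_(torch.tensor([9, 11, 8, 10], dtype=torch.int32))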
...
@@ -874,24 +833,21 @@ class FlashAttentionBackend(AttentionBackend):
         out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        device = seq_lens.device
         seq_lens = seq_lens[:bs]
-        req_pool_indices = req_pool_indices[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
+        req_pool_indices = req_pool_indices[:bs]
         if forward_mode.is_decode():
             metadata = self.decode_cuda_graph_metadata[bs]
             if spec_info is not None:
                 # Draft Decode
-                max_len = seq_lens_cpu.max().item()
-                metadata.max_seq_len_k = max_len + (self.step_id + 1)
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.step_id + 1)).to(torch.int32)
+                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                 )
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_k.copy_(
                     torch.nn.functional.pad(
                         torch.cumsum(
...
@@ -929,22 +885,13 @@ class FlashAttentionBackend(AttentionBackend):
         elif forward_mode.is_target_verify():
             metadata = self.target_verify_metadata[bs]
-            draft_token_num = spec_info.draft_token_num
-            metadata.cu_seqlens_q.copy_(
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            )
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + draft_token_num).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )
-            metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
             metadata.cu_seqlens_k.copy_(
                 torch.nn.functional.pad(
                     torch.cumsum(
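At replay the key-side cumulative lengths do depend on the live seq_lens, so they are recomputed with the same pad(cumsum(...)) recipe but written into the captured buffer with copy_ rather than rebound. Illustration (hypothetical values):

import torch
import torch.nn.functional as F

cu_seqlens_k_buf = torch.zeros(5, dtype=torch.int32)  # captured buffer, bs = 4
cache_seqlens = torch.tensor([9, 11, 8, 10], dtype=torch.int32)

cu_seqlens_k_buf.copy_(
    F.pad(torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0))
)
print(cu_seqlens_k_buf.tolist())  # [0, 9, 20, 28, 38]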
...
@@ -972,14 +919,19 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps

+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            self.topk == 1
+        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
+
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
+                    speculative_step_id=i,
                     topk=self.topk,
                     speculative_num_steps=self.speculative_num_steps,
-                    step_id=i,
                 )
             )
...