change / sglang · Commits · b7cd7430

Commit b7cd7430 (unverified)
[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)

Authored Aug 07, 2025 by PGFLMG; committed by GitHub Aug 06, 2025
Parent: a69b6370

Showing 15 changed files with 2121 additions and 4 deletions (+2121, -4)
examples/runtime/engine/offline_batch_inference_qwen_1m.py                   +74    -0
python/sglang/srt/configs/model_config.py                                    +28    -0
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py              +3     -0
python/sglang/srt/hf_transformers_utils.py                                   +30    -3
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py      +1700  -0
python/sglang/srt/layers/rotary_embedding.py                                 +225   -1
python/sglang/srt/managers/schedule_batch.py                                 +16    -0
python/sglang/srt/model_executor/cuda_graph_runner.py                        +1     -0
python/sglang/srt/model_executor/forward_batch_info.py                       +4     -0
python/sglang/srt/model_executor/model_runner.py                             +6     -0
python/sglang/srt/models/qwen2.py                                            +6     -0
python/sglang/srt/models/qwen2_moe.py                                        +6     -0
python/sglang/srt/models/qwen3_moe.py                                        +6     -0
python/sglang/srt/server_args.py                                             +15    -0
python/sglang/srt/two_batch_overlap.py                                       +1     -0
examples/runtime/engine/offline_batch_inference_qwen_1m.py (new file, 0 → 100644)
"""
Usage:
python3 offline_batch_inference.py
"""
from
urllib.request
import
urlopen
import
sglang
as
sgl
def
load_prompt
()
->
str
:
# Test cases with various lengths can be found at:
#
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
with
urlopen
(
"https://qianwen-res.oss-cn-beijing.aliyuncs.com"
"/Qwen2.5-1M/test-data/64k.txt"
,
timeout
=
5
,
)
as
response
:
prompt
=
response
.
read
().
decode
(
"utf-8"
)
return
prompt
# Processing the prompt.
def
process_requests
(
llm
:
sgl
.
Engine
,
prompts
:
list
[
str
])
->
None
:
# Create a sampling params object.
sampling_params
=
{
"temperature"
:
0.7
,
"top_p"
:
0.8
,
"top_k"
:
20
,
"repetition_penalty"
:
1.05
,
"max_new_tokens"
:
256
,
}
# Generate texts from the prompts.
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
for
output
in
outputs
:
prompt_token_ids
=
output
[
"meta_info"
][
"prompt_tokens"
]
generated_text
=
output
[
"text"
]
print
(
f
"Prompt length:
{
prompt_token_ids
}
, "
f
"Generated text:
{
generated_text
!
r
}
"
)
# Create an LLM.
def
initialize_engine
()
->
sgl
.
Engine
:
llm
=
sgl
.
Engine
(
model_path
=
"Qwen/Qwen2.5-7B-Instruct-1M"
,
context_length
=
1048576
,
page_size
=
256
,
attention_backend
=
"dual_chunk_flash_attn"
,
tp_size
=
4
,
disable_radix_cache
=
True
,
enable_mixed_chunk
=
False
,
enable_torch_compile
=
False
,
chunked_prefill_size
=
131072
,
mem_fraction_static
=
0.6
,
log_level
=
"DEBUG"
,
)
return
llm
def
main
():
llm
=
initialize_engine
()
prompt
=
load_prompt
()
process_requests
(
llm
,
[
prompt
])
if
__name__
==
"__main__"
:
main
()
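For smaller-scale smoke tests, the same engine arguments can be scaled down. The sketch below is a hedged variation of initialize_engine() above: only parameters that already appear in the committed example are changed, and the reduced values (single GPU, 128K context, smaller prefill chunk) are assumptions rather than settings shipped in this commit.

import sglang as sgl

def initialize_engine_small() -> sgl.Engine:
    # A minimal sketch, assuming one GPU and the 64k test prompt; the reduced
    # values below are illustrative assumptions, not tested settings.
    return sgl.Engine(
        model_path="Qwen/Qwen2.5-7B-Instruct-1M",
        context_length=131072,
        page_size=256,
        attention_backend="dual_chunk_flash_attn",
        tp_size=1,
        disable_radix_cache=True,
        chunked_prefill_size=32768,
        mem_fraction_static=0.6,
    )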
python/sglang/srt/configs/model_config.py
...
@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
    get_context_length,
    get_generation_config,
    get_hf_text_config,
    get_sparse_attention_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
...
@@ -270,6 +271,9 @@ class ModelConfig:
        # Verify quantization
        self._verify_quantization()

        # Verify dual-chunk attention config
        self._verify_dual_chunk_attention_config()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()
...
@@ -297,6 +301,13 @@ class ModelConfig:
            **kwargs,
        )

    def get_total_num_attention_heads(self) -> int:
        return self.num_attention_heads

    def get_num_attention_heads(self, tensor_parallel_size) -> int:
        total_num_attention_heads = self.num_attention_heads
        return max(1, total_num_attention_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
...
@@ -484,6 +495,23 @@ class ModelConfig:
            self.quantization,
        )

    def _verify_dual_chunk_attention_config(self) -> None:
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            sparse_attn_config = get_sparse_attention_config(self.model_path)
            if not sparse_attn_config:
                return
            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
                sparse_attn_config
            )
            if (
                "sparse_attention_enabled"
                not in self.hf_config.dual_chunk_attention_config
            ):
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"
                ] = True

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
...
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
...
@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.orig_seq_lens = torch.tensor(
            seq_lens, dtype=torch.int32, device=self.device
        )
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
...
python/sglang/srt/hf_transformers_utils.py
...
@@ -14,10 +14,11 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
 import json
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union

 import torch
 from huggingface_hub import snapshot_download
...
@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
     AutoConfig.register(name, cls)


-def download_from_hf(model_path: str):
+def download_from_hf(
+    model_path: str,
+    allow_patterns: Optional[Union[str, list]] = None,
+):
     if os.path.exists(model_path):
         return model_path

-    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
+    if not allow_patterns:
+        allow_patterns = ["*.json", "*.bin", "*.model"]
+    return snapshot_download(model_path, allow_patterns=allow_patterns)


 def get_hf_text_config(config: PretrainedConfig):
...
@@ -171,6 +178,26 @@ def get_generation_config(
     return None


 # Qwen-1M related
 def get_sparse_attention_config(
     model: str,
     sparse_attention_config_filename: str = "sparse_attention_config.json",
 ) -> Dict[str, Any]:
     is_local = os.path.isdir(model)
     if not is_local:
         # Download the config files.
         model = download_from_hf(model, allow_patterns=["*.json"])

     config_file = os.path.join(model, sparse_attention_config_filename)
     if not os.path.exists(config_file):
         return {}

     # Load the sparse attention config.
     with open(config_file) as f:
         config = json.load(f)
     return config


 # Models don't use the same configuration key for determining the maximum
 # context length.  Store them here so we can sanely check them.
 # NOTE: The ordering here is important. Some models have two of these and we
...
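For orientation, the attention backend added below reads the following keys from hf_config.dual_chunk_attention_config and from the per-layer sparse attention config loaded by get_sparse_attention_config. The sketch only collects key names that appear in this commit; every concrete value is an assumed placeholder, not taken from any shipped checkpoint.

# Hypothetical shape of the config consumed by DualChunkFlashAttentionBackend
# (key names from this commit; the numbers are placeholders).
dual_chunk_attention_config = {
    "chunk_size": 8192,                          # backend default
    "local_size": 1024,                          # backend default
    "original_max_position_embeddings": 262144,  # assumed value
    "sparse_attention_threshold": 32768,         # backend default
    "sparse_attention_last_q": 64,               # backend default
    "sparse_attention_enabled": True,            # set by _verify_dual_chunk_attention_config
    # Loaded from sparse_attention_config.json: one dict per layer, mapping a
    # global head index to (type, vertical_size, slash_size, ...) entries.
    "sparse_attention_config": [
        {"0": ["vertical_and_slash", 1000, 6096, 1]},  # layer 0, assumed numbers
        # ...
    ],
}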
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py (new file, 0 → 100644)
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with Dual chunk flash attention and sparse attention.
"""
import functools
import logging
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.sparse_flash_attn import (
    convert_vertical_slash_indexes,
    convert_vertical_slash_indexes_mergehead,
    sparse_attn_func,
)

from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner

logger = logging.getLogger(__name__)
@dataclass
class DualChunkFlashAttentionMetadata:
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens; None if it is a decoding.
    seq_lens: Optional[List[int]] = None
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor] = None
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_seq_len: int = None
    # (batch_size,). The orig sequence length per sequence.
    orig_seq_lens: Optional[List[int]] = None
    # orig_seq_lens stored as a tensor.
    orig_seq_lens_tensor: Optional[torch.Tensor] = None

    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Length scaling factor
    scaling_factor: Optional[torch.Tensor] = None

    # (batch_size,). Sequence lengths for intra attention.
    seq_lens_intra: Optional[torch.Tensor] = None
    # Max sequence length for intra attention.
    max_seq_len_intra: Optional[int] = None
    # (batch_size, num_blocks). Block table for intra attention.
    block_tables_intra: Optional[torch.Tensor] = None

    # (batch_size,). Sequence lengths for succ attention.
    seq_lens_succ: Optional[torch.Tensor] = None
    # Max sequence length for succ attention.
    max_seq_len_succ: Optional[int] = None
    # (batch_size, num_blocks). Block table for succ attention.
    block_tables_succ: Optional[torch.Tensor] = None

    # (batch_size,). Sequence lengths for inter attention.
    seq_lens_inter: Optional[torch.Tensor] = None
    # Max sequence length for inter attention.
    max_seq_len_inter: Optional[int] = None
class DualChunkFlashAttentionBackend(AttentionBackend):

    def __init__(
        self,
        model_runner: "ModelRunner",
    ) -> None:
        self.forward_metadata: FlashAttentionMetadata = None

        self.device = model_runner.device
        self.max_context_len = model_runner.model_config.context_len

        self.num_heads = model_runner.model_config.get_num_attention_heads(
            model_runner.server_args.tp_size
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            model_runner.server_args.tp_size
        )
        self.head_size = model_runner.model_config.head_dim

        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.kv_cache_dtype = model_runner.kv_cache_dtype
        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
        self.page_size = model_runner.page_size

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dual_chunk_attention_config = getattr(
            model_runner.model_config.hf_config, "dual_chunk_attention_config", None
        )
        assert dual_chunk_attention_config is not None
        self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192)
        self.local_size = dual_chunk_attention_config.get("local_size", 1024)
        self.original_max_position_embeddings = dual_chunk_attention_config.get(
            "original_max_position_embeddings", 0
        )
        self.sparse_attention_config = dual_chunk_attention_config.get(
            "sparse_attention_config", None
        )
        if not self.sparse_attention_config:
            logger.warning_once(
                "Sparse attention will not be enabled as "
                "sparse attention config is not provided."
            )
        self.sparse_attention_enabled = dual_chunk_attention_config.get(
            "sparse_attention_enabled", self.sparse_attention_config is not None
        )
        self.sparse_attention_threshold = dual_chunk_attention_config.get(
            "sparse_attention_threshold", 32768
        )
        self.sparse_attention_last_q = dual_chunk_attention_config.get(
            "sparse_attention_last_q", 64
        )
        self.dual_chunk_attention_config = dual_chunk_attention_config

        if self.sparse_attention_enabled:
            self.arange = torch.arange(self.sparse_attention_last_q, device="cuda")
            self.last_q_mask = (
                self.arange[None, None, :, None] >= self.arange[None, None, None, :]
            )

    @functools.lru_cache()
    def get_sparse_attention_config(self, layer_idx) -> List[Dict[str, Any]]:
        layer_sparse_attention_config = {
            int(i): j for i, j in self.sparse_attention_config[layer_idx].items()
        }
        start_head = self.num_heads * get_tensor_model_parallel_rank()
        end_head = start_head + self.num_heads
        return [layer_sparse_attention_config[i] for i in range(start_head, end_head)]
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
        forward_mode: ForwardMode = forward_batch.forward_mode
        assert forward_mode.is_prefill() or forward_mode.is_decode()
        batch_size = forward_batch.batch_size

        metadata = DualChunkFlashAttentionMetadata()
        metadata.seq_lens_tensor = forward_batch.seq_lens.to(torch.int32)
        metadata.seq_lens = forward_batch.seq_lens.tolist()
        metadata.max_seq_len = forward_batch.seq_lens.max().item()
        metadata.orig_seq_lens_tensor = forward_batch.orig_seq_lens
        metadata.orig_seq_lens = forward_batch.orig_seq_lens.tolist()

        metadata.block_tables = forward_batch.req_to_token_pool.req_to_token[
            forward_batch.req_pool_indices, : metadata.max_seq_len
        ]
        # Convert the block table to a strided format.
        if self.page_size > 1:
            strided_indices = torch.arange(
                0, metadata.block_tables.shape[1], self.page_size, device=self.device
            )
            metadata.block_tables = (
                metadata.block_tables[:, strided_indices] // self.page_size
            )

        metadata.query_start_loc = torch.zeros(
            batch_size + 1, dtype=torch.int32, device=metadata.seq_lens_tensor.device
        )
        if forward_mode.is_prefill():
            metadata.query_start_loc[1:] = torch.cumsum(
                forward_batch.extend_seq_lens.to(torch.int32), dim=0, dtype=torch.int32
            )
        else:
            metadata.query_start_loc[1:] = torch.cumsum(
                torch.arange(
                    batch_size,
                    dtype=metadata.query_start_loc.dtype,
                    device=metadata.query_start_loc.device,
                ),
                dim=0,
                dtype=torch.int32,
            )
        metadata.seq_start_loc = torch.zeros(
            batch_size + 1, dtype=torch.int32, device=metadata.seq_lens_tensor.device
        )
        metadata.seq_start_loc[1:] = torch.cumsum(
            metadata.seq_lens_tensor, dim=0, dtype=torch.int32
        )

        if self.original_max_position_embeddings > 0:
            if forward_mode.is_prefill():
                metadata.scaling_factor = (
                    0.1
                    * torch.log(
                        metadata.orig_seq_lens_tensor
                        / self.original_max_position_embeddings
                    )
                    + 1.0
                ).clip(min=1)
            else:
                metadata.scaling_factor = (
                    0.1
                    * torch.log(
                        metadata.orig_seq_lens_tensor
                        / self.original_max_position_embeddings
                    )
                    + 1.0
                ).clip(min=1)

        if forward_mode.is_decode():
            cache_seq_lens = metadata.orig_seq_lens_tensor

            chunk_len = self.chunk_size - self.local_size
            chunk_num_curr = (cache_seq_lens - 1) // chunk_len

            seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len
            max_seq_len_intra = seq_lens_intra.max().item()
            metadata.seq_lens_intra = seq_lens_intra
            metadata.max_seq_len_intra = max_seq_len_intra

            block_tables_intra = torch.zeros(
                batch_size,
                (max_seq_len_intra - 1) // self.page_size + 1,
                dtype=metadata.block_tables.dtype,
                device=metadata.block_tables.device,
            )
            for i in range(batch_size):
                st = chunk_num_curr[i] * chunk_len // self.page_size
                ed = min(
                    st + (max_seq_len_intra - 1) // self.page_size + 1,
                    (cache_seq_lens[i] - 1) // self.page_size + 1,
                )
                block_tables_intra[i, : ed - st] = metadata.block_tables[i, st:ed]
            metadata.block_tables_intra = block_tables_intra

            metadata.seq_lens_succ = (
                chunk_num_curr - (chunk_num_curr - 1).clip(min=0)
            ) * chunk_len
            metadata.max_seq_len_succ = metadata.seq_lens_succ.max().item()
            if metadata.max_seq_len_succ:
                block_tables_succ = torch.zeros(
                    batch_size,
                    (metadata.max_seq_len_succ - 1) // self.page_size + 1,
                    dtype=metadata.block_tables.dtype,
                    device=metadata.block_tables.device,
                )
                for i in range(batch_size):
                    start = (
                        (chunk_num_curr[i] - 1).clip(min=0)
                        * chunk_len
                        // self.page_size
                    )
                    end = min(
                        start + (metadata.max_seq_len_succ - 1) // self.page_size + 1,
                        (cache_seq_lens[i] - 1) // self.page_size + 1,
                    )
                    block_tables_succ[i, : end - start] = metadata.block_tables[
                        i, start:end
                    ]
                metadata.block_tables_succ = block_tables_succ

            metadata.seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
            metadata.max_seq_len_inter = metadata.seq_lens_inter.max().item()

        self.forward_metadata = metadata
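    # Worked example of the decode-time split above (illustrative numbers,
    # using the backend defaults chunk_size=8192, local_size=1024 and an
    # assumed cache length of 20,000 tokens):
    #   chunk_len      = 8192 - 1024                = 7168
    #   chunk_num_curr = (20000 - 1) // 7168        = 2
    #   seq_len_intra  = 20000 - 2 * 7168           = 5664
    #   seq_len_succ   = (2 - max(2 - 1, 0)) * 7168 = 7168
    #   seq_len_inter  = max(2 - 1, 0) * 7168       = 7168
    # and 5664 + 7168 + 7168 == 20000, i.e. the intra/succ/inter branches
    # together cover the whole KV cache exactly once.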
    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # Use precomputed metadata across all layers
        metadata = self.forward_metadata

        (
            query,
            query_succ,
            query_inter,
            query_succ_critical,
            query_inter_critical,
        ) = torch.split(q, q.shape[-1] // 5, dim=-1)

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        query_succ = query_succ.view(-1, self.num_heads, self.head_size)
        query_inter = query_inter.view(-1, self.num_heads, self.head_size)
        query_succ_critical = query_succ_critical.view(
            -1, self.num_heads, self.head_size
        )
        query_inter_critical = query_inter_critical.view(
            -1, self.num_heads, self.head_size
        )
        key = k.view(-1, self.num_kv_heads, self.head_size)
        value = v.view(-1, self.num_kv_heads, self.head_size)

        # apply DCA scaling
        if self.original_max_position_embeddings > 0:
            assert metadata.scaling_factor is not None
            assert metadata.query_start_loc is not None
            assert metadata.orig_seq_lens is not None
            current_start = 0
            query_start_loc_cpu = metadata.query_start_loc.cpu()
            for i in range(len(metadata.orig_seq_lens)):
                current_end = (
                    current_start
                    + (query_start_loc_cpu[i + 1] - query_start_loc_cpu[i]).item()
                )
                key[current_start:current_end].mul_(metadata.scaling_factor[i])
                current_start = current_end
            assert current_end <= self.max_context_len

        # Do multi-head attention
        key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
            layer.layer_id
        )
        key_cache = key_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
        )
        value_cache = value_cache.view(
            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
        )
        if key is not None and value is not None:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    forward_batch.out_cache_loc,
                    key,
                    value,
                    layer.k_scale,
                    layer.v_scale,
                )

        if not save_kv_cache:
            # profile run
            o = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=metadata.seq_start_loc,
                cu_seqlens_k=metadata.seq_start_loc,
                max_seqlen_q=metadata.max_seq_len,
                max_seqlen_k=metadata.max_seq_len,
                softmax_scale=layer.scaling,
                causal=True,
            )
        else:
            # prefill/chunked-prefill
            # get per layer sparse attention config
            if self.sparse_attention_enabled:
                self.layer_sparse_attention_config = self.get_sparse_attention_config(
                    layer.layer_id
                )
            assert metadata.orig_seq_lens is not None
            o = self._dual_chunk_flash_attn_prefill(
                q=query,
                q_succ=query_succ,
                q_inter=query_inter,
                q_succ_critical=query_succ_critical,
                q_inter_critical=query_inter_critical,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=metadata.query_start_loc,
                cu_seqlens_k=metadata.seq_start_loc,
                orig_seq_lens=metadata.orig_seq_lens,
                scaling_factor=metadata.scaling_factor,
                softmax_scale=layer.scaling,
                causal=True,
                window_size=(-1, -1),
                block_table=metadata.block_tables,
                chunk_size=self.chunk_size,
                local_size=self.local_size,
            )
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
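    # Note on the 5-way split in forward_extend/forward_decode: the incoming q
    # packs five query variants along the last dimension (regular, successive-
    # chunk, inter-chunk, and the two "critical" variants used for sparse head
    # selection), so q.shape[-1] == 5 * num_heads * head_size. Toy check with
    # assumed sizes (num_heads=2, head_size=4, 3 tokens), illustrative only:
    #
    #   q = torch.randn(3, 5 * 2 * 4)
    #   parts = torch.split(q, q.shape[-1] // 5, dim=-1)
    #   assert len(parts) == 5 and all(p.shape == (3, 8) for p in parts)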
    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ) -> torch.Tensor:
        # Use precomputed metadata across all layers
        metadata = self.forward_metadata

        (
            query,
            query_succ,
            query_inter,
            query_succ_critical,
            query_inter_critical,
        ) = torch.split(q, q.shape[-1] // 5, dim=-1)

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        query_succ = query_succ.view(-1, self.num_heads, self.head_size)
        query_inter = query_inter.view(-1, self.num_heads, self.head_size)
        query_succ_critical = query_succ_critical.view(
            -1, self.num_heads, self.head_size
        )
        query_inter_critical = query_inter_critical.view(
            -1, self.num_heads, self.head_size
        )
        key = k.view(-1, self.num_kv_heads, self.head_size)
        value = v.view(-1, self.num_kv_heads, self.head_size)

        key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
            layer.layer_id
        )
        key_cache = key_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
        )
        value_cache = value_cache.view(
            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
        )
        if key is not None and value is not None:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    forward_batch.out_cache_loc,
                    key,
                    value,
                    layer.k_scale,
                    layer.v_scale,
                )
            # apply DCA scaling
            if self.original_max_position_embeddings > 0:
                assert metadata.scaling_factor is not None
                scaling_factor = metadata.scaling_factor
                key.mul_(scaling_factor.unsqueeze(-1).unsqueeze(-1))

        o = self._dual_chunk_flash_attn_decoding(
            query.unsqueeze(1),
            query_succ.unsqueeze(1),
            query_inter.unsqueeze(1),
            key_cache,
            value_cache,
            block_table=metadata.block_tables,
            cache_seqlens=metadata.seq_lens_tensor,
            softmax_scale=layer.scaling,
            causal=True,
            chunk_size=self.chunk_size,
            local_size=self.local_size,
            original_max_position_embeddings=self.original_max_position_embeddings,
            decode_meta=metadata,
        ).squeeze(1)
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
    def init_cuda_graph_state(self, max_bs: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
            max_bs (int): Maximum batch size to support in CUDA graphs

        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
        self.decode_metadata = {
            "seq_lens_tensor": torch.zeros(
                max_bs, dtype=torch.int32, device=self.device
            ),
            "orig_seq_lens_tensor": torch.zeros(
                max_bs, dtype=torch.int32, device=self.device
            ),
            "scaling_factor": torch.zeros(
                max_bs, dtype=torch.float32, device=self.device
            ),
            "block_tables": torch.zeros(
                max_bs,
                (self.max_context_len - 1) // self.page_size + 1,
                dtype=torch.int32,
                device=self.device,
            ),
            "block_tables_intra": torch.zeros(
                max_bs,
                (self.max_context_len - 1) // self.page_size + 1,
                dtype=torch.int32,
                device=self.device,
            ),
            "seq_lens_intra": torch.zeros(
                max_bs, dtype=torch.int32, device=self.device
            ),
            "block_tables_succ": torch.zeros(
                max_bs,
                (self.max_context_len - 1) // self.page_size + 1,
                dtype=torch.int32,
                device=self.device,
            ),
            "seq_lens_succ": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "seq_lens_inter": torch.zeros(
                max_bs, dtype=torch.int32, device=self.device
            ),
        }
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[None],
    ):
        metadata = DualChunkFlashAttentionMetadata()
        if forward_mode.is_decode_or_idle():
            if self.original_max_position_embeddings > 0:
                metadata.scaling_factor = self.decode_metadata["scaling_factor"][:bs]
            metadata.seq_lens_tensor = self.decode_metadata["seq_lens_tensor"][:bs]
            metadata.orig_seq_lens_tensor = self.decode_metadata[
                "orig_seq_lens_tensor"
            ][:bs]
            metadata.max_seq_len = self.max_context_len
            metadata.block_tables = self.decode_metadata["block_tables"][
                req_pool_indices, :
            ]
            # intra
            metadata.max_seq_len_intra = self.max_context_len
            metadata.seq_lens_intra = self.decode_metadata["seq_lens_intra"][:bs]
            metadata.block_tables_intra = self.decode_metadata["block_tables_intra"][
                :bs, :
            ]
            # succ
            metadata.seq_lens_succ = self.decode_metadata["seq_lens_succ"][:bs]
            metadata.max_seq_len_succ = self.max_context_len
            metadata.block_tables_succ = self.decode_metadata["block_tables_succ"][
                :bs, :
            ]
            metadata.seq_lens_inter = self.decode_metadata["seq_lens_inter"][:bs]
            metadata.max_seq_len_inter = self.max_context_len
        self.decode_metadata[bs] = metadata
        self.forward_metadata = metadata
    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[None],
        seq_lens_cpu: Optional[torch.Tensor],
        out_cache_loc: torch.Tensor = None,
    ):
        """Initialize forward metadata for replaying CUDA graph."""
        assert forward_mode.is_decode()
        seq_lens = seq_lens[:bs]
        req_pool_indices = req_pool_indices[:bs]

        metadata = self.decode_metadata[bs]
        metadata.seq_lens_tensor.copy_(seq_lens.to(torch.int32))
        metadata.seq_lens = seq_lens.tolist()
        metadata.max_seq_len = seq_lens.max().item()
        metadata.orig_seq_lens_tensor.copy_(seq_lens)
        metadata.orig_seq_lens = seq_lens.tolist()
        block_tables = self.req_to_token[req_pool_indices, : metadata.max_seq_len]
        # Convert the block table to a strided format.
        if self.page_size > 1:
            strided_indices = torch.arange(
                0, block_tables.shape[1], self.page_size, device=self.device
            )
            block_tables = block_tables[:, strided_indices] // self.page_size
        metadata.block_tables.fill_(0)
        metadata.block_tables[: block_tables.shape[0], : block_tables.shape[1]].copy_(
            block_tables
        )

        if self.original_max_position_embeddings > 0:
            scaling_factor = (
                0.1
                * torch.log(
                    metadata.orig_seq_lens_tensor
                    / self.original_max_position_embeddings
                )
                + 1.0
            ).clip(min=1)
            metadata.scaling_factor.copy_(scaling_factor)

        cache_seq_lens = metadata.orig_seq_lens_tensor
        chunk_len = self.chunk_size - self.local_size
        chunk_num_curr = (cache_seq_lens - 1) // chunk_len

        seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len
        max_seq_len_intra = seq_lens_intra.max().item()
        metadata.seq_lens_intra.copy_(seq_lens_intra)
        metadata.max_seq_len_intra = max_seq_len_intra

        metadata.block_tables_intra.fill_(0)
        for i in range(bs):
            st = chunk_num_curr[i] * chunk_len // self.page_size
            ed = min(
                st + (max_seq_len_intra - 1) // self.page_size + 1,
                (cache_seq_lens[i] - 1) // self.page_size + 1,
            )
            metadata.block_tables_intra[i, : ed - st] = metadata.block_tables[i, st:ed]

        seq_lens_succ = (chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len
        metadata.seq_lens_succ.copy_(seq_lens_succ)
        metadata.max_seq_len_succ = metadata.seq_lens_succ.max().item()
        if metadata.max_seq_len_succ:
            metadata.block_tables_succ.fill_(0)
            for i in range(bs):
                start = (
                    (chunk_num_curr[i] - 1).clip(min=0) * chunk_len // self.page_size
                )
                end = min(
                    start + (metadata.max_seq_len_succ - 1) // self.page_size + 1,
                    (cache_seq_lens[i] - 1) // self.page_size + 1,
                )
                metadata.block_tables_succ[i, : end - start] = metadata.block_tables[
                    i, start:end
                ]

        seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
        metadata.seq_lens_inter.copy_(seq_lens_inter)
        metadata.max_seq_len_inter = metadata.seq_lens_inter.max().item()

        self.forward_metadata = metadata
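    # The scaling_factor computed above (and in init_forward_metadata) follows
    #   scaling_factor = max(1, 0.1 * ln(orig_seq_len / original_max_position_embeddings) + 1)
    # Quick numeric check with an assumed original_max_position_embeddings of
    # 262144 (illustrative only, not a value taken from this commit):
    #   seq_len =    32768 -> 0.1 * ln(1/8) + 1 = 0.792 -> clipped to 1.0
    #   seq_len =   262144 -> 0.1 * ln(1)   + 1 = 1.0
    #   seq_len =  1048576 -> 0.1 * ln(4)   + 1 ≈ 1.139
    # so the scaling only kicks in once a sequence exceeds the original
    # training length.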
    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
        return 1
    def _dual_chunk_flash_attn_prefill(
        self,
        q,
        q_succ,
        q_inter,
        q_succ_critical,
        q_inter_critical,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        orig_seq_lens: List[int],
        scaling_factor: torch.Tensor,
        softmax_scale: float,
        causal: Optional[bool] = True,
        window_size: Tuple[int, int] = (-1, -1),
        block_table: Optional[torch.Tensor] = None,
        chunk_size: int = 8192,
        local_size: int = 1024,
    ):
        if not causal:
            raise ValueError("Dual Chunk Attention does not support causal=False")
        if window_size != (-1, -1):
            raise ValueError("Dual Chunk Attention does not support window_size")

        cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist()
        cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist()
        all_outputs = []

        for i in range(0, len(cu_seqlens_q_cpu) - 1):
            qs = cu_seqlens_q_cpu[i]
            qe = cu_seqlens_q_cpu[i : i + 2][-1]
            ks = cu_seqlens_k_cpu[i]
            ke = cu_seqlens_k_cpu[i : i + 2][-1]

            current_q = q[qs:qe]
            current_q_succ = q_succ[qs:qe]
            current_q_inter = q_inter[qs:qe]
            current_q_succ_critical = q_succ_critical[qs:qe]
            current_q_inter_critical = q_inter_critical[qs:qe]

            if block_table is None:
                current_k = k[ks:ke]
                current_v = v[ks:ke]
                current_block_table = None
                current_orig_seq_len = orig_seq_lens[i]
            else:
                current_block_table = block_table[i]
                current_orig_seq_len = orig_seq_lens[i]
                current_k = k
                current_v = v
            sparse_attn_enabled = (
                self.sparse_attention_enabled
                and current_orig_seq_len > self.sparse_attention_threshold
            )

            if current_q.shape[0] == 0:
                continue

            if current_k.shape[0] == 0:
                all_outputs.append(
                    torch.zeros(
                        (current_q.shape[0], current_q.shape[1], v.shape[2]),
                        device=q.device,
                        dtype=q.dtype,
                    )
                )
                continue

            current_output = torch.empty_like(current_q)
            group_size = int(current_q.size(-2) / current_k.size(-2))

            if sparse_attn_enabled:
                num_device_q_heads = current_q.size(-2)
                heads_vertical_size = torch.empty(
                    size=(num_device_q_heads,), dtype=torch.int32
                )
                heads_slash_size = torch.empty(
                    size=(num_device_q_heads,), dtype=torch.int32
                )
                for head_id in range(current_q.size(-2)):
                    (
                        ty,
                        vertical_size,
                        slash_size,
                        _,
                    ) = self.layer_sparse_attention_config[head_id]
                    assert ty == "vertical_and_slash", "only support slash mode"

                    if vertical_size == 30:
                        vertical_size += 100
                    heads_vertical_size[head_id] = vertical_size
                    heads_slash_size[head_id] = slash_size

                current_output = self._dual_chunk_flash_attn_prefill_func(
                    current_q,  # allheads
                    current_q_succ,
                    current_q_inter,
                    current_q_succ_critical,
                    current_q_inter_critical,
                    current_k,
                    current_v,
                    current_block_table,
                    softmax_scale,
                    chunk_size,
                    local_size,
                    scaling_factor[i].item(),
                    ke - ks,
                    sparse_attn_enabled=sparse_attn_enabled,
                    heads_vertical_size=heads_vertical_size,
                    heads_slash_size=heads_slash_size,
                    group_size=group_size,
                )
            else:
                for head_id in range(current_q.size(-2)):
                    # (seq_len, num_heads, head_size)
                    current_q_head = current_q[:, head_id, :].unsqueeze(1)
                    current_q_succ_head = current_q_succ[:, head_id, :].unsqueeze(1)
                    current_q_inter_head = current_q_inter[:, head_id, :].unsqueeze(1)
                    current_q_succ_head_critical = current_q_succ_critical[
                        :, head_id, :
                    ].unsqueeze(1)
                    current_q_inter_head_critical = current_q_inter_critical[
                        :, head_id, :
                    ].unsqueeze(1)
                    if block_table is not None:
                        current_k_head = current_k[
                            ..., head_id // group_size, :
                        ].unsqueeze(2)
                        current_v_head = current_v[
                            ..., head_id // group_size, :
                        ].unsqueeze(2)
                    else:
                        current_k_head = current_k[:, head_id, :].unsqueeze(1)
                        current_v_head = current_v[:, head_id, :].unsqueeze(1)

                    current_out = self._dual_chunk_flash_attn_prefill_func(
                        current_q_head,
                        current_q_succ_head,
                        current_q_inter_head,
                        current_q_succ_head_critical,
                        current_q_inter_head_critical,
                        current_k_head,
                        current_v_head,
                        current_block_table,
                        softmax_scale,
                        chunk_size,
                        local_size,
                        scaling_factor[i].item(),
                        ke - ks,
                        sparse_attn_enabled=sparse_attn_enabled,
                    )
                    current_output[:, head_id : head_id + 1, :] = current_out
            all_outputs.append(current_output)
        return torch.cat(all_outputs, dim=0)
    def _dual_chunk_flash_attn_prefill_func(
        self,
        q,
        q_succ,
        q_inter,
        q_succ_critical,
        q_inter_critical,
        k,
        v,
        block_table,
        softmax_scale: float,
        chunk_size: int,
        local_size: int,
        scaling_factor: float,
        k_length: int,
        sparse_attn_enabled: Optional[bool] = True,
        heads_vertical_size=None,
        heads_slash_size=None,
        group_size=None,
    ):
        flash_results = []
        chunk_len = chunk_size - local_size

        if block_table is not None:
            block_size = v.shape[1]
            if chunk_len % block_size != 0:
                raise ValueError("chunk_len must be divisible by block_size.")
        else:
            block_size = 1

        if self.original_max_position_embeddings > 0:
            softmax_scale = softmax_scale * scaling_factor

        begin = k_length - q.shape[0]
        while begin < k_length:
            flash_per_chunk = []

            prev_chunk_end_pos = (begin // chunk_len) * chunk_len
            next_chunk_end_pos = prev_chunk_end_pos + chunk_len
            end = min(next_chunk_end_pos, k_length)
            qbegin = begin - (k_length - q.shape[0])
            qend = end - (k_length - q.shape[0])

            qk_chunks = []
            q_states_intra = q[qbegin:qend]
            # choose critical token
            if block_table is not None:
                block_tables_intra = _get_block(
                    block_table, block_size, prev_chunk_end_pos, end
                )
                k_states_intra = k[block_tables_intra].view(-1, *k.shape[-2:])[
                    : (end - prev_chunk_end_pos)
                ]
                v_states_intra = v[block_tables_intra].view(-1, *v.shape[-2:])[
                    : (end - prev_chunk_end_pos)
                ]
            else:
                block_tables_intra = None
                k_states_intra = k[prev_chunk_end_pos:end]
                v_states_intra = v[prev_chunk_end_pos:end]

            if sparse_attn_enabled:
                last_q_size = min(qend - qbegin, self.sparse_attention_last_q)
                _, num_device_k_heads, head_dim = k_states_intra.shape
                k_states_intra = (
                    k_states_intra.unsqueeze(2)
                    .repeat(1, 1, group_size, 1)
                    .reshape(-1, num_device_k_heads * group_size, head_dim)
                )
                v_states_intra = (
                    v_states_intra.unsqueeze(2)
                    .repeat(1, 1, group_size, 1)
                    .reshape(-1, num_device_k_heads * group_size, head_dim)
                )
                qk_chunks.append(
                    (q_states_intra.transpose(0, 1)[:, -last_q_size:] * softmax_scale)
                    @ k_states_intra.permute(1, 2, 0)
                )

            if prev_chunk_end_pos - chunk_len >= 0:
                q_states_succ = q_succ[qbegin:qend]
                q_states_succ_critical = q_succ_critical[qbegin:qend]
                if block_table is not None:
                    block_tables_succ = _get_block(
                        block_table,
                        block_size,
                        prev_chunk_end_pos - chunk_len,
                        prev_chunk_end_pos,
                    )
                    k_states_succ = k[block_tables_succ].view(-1, *k.shape[-2:])[
                        :chunk_len
                    ]
                    v_states_succ = v[block_tables_succ].view(-1, *v.shape[-2:])[
                        :chunk_len
                    ]
                else:
                    k_states_succ = k[prev_chunk_end_pos - chunk_len : prev_chunk_end_pos]
                    v_states_succ = v[prev_chunk_end_pos - chunk_len : prev_chunk_end_pos]

                if sparse_attn_enabled:
                    k_states_succ = (
                        k_states_succ.unsqueeze(2)
                        .repeat(1, 1, group_size, 1)
                        .reshape(-1, num_device_k_heads * group_size, head_dim)
                    )
                    v_states_succ = (
                        v_states_succ.unsqueeze(2)
                        .repeat(1, 1, group_size, 1)
                        .reshape(-1, num_device_k_heads * group_size, head_dim)
                    )
                    qk_chunks.append(
                        (
                            q_states_succ_critical.transpose(0, 1)[:, -last_q_size:]
                            * softmax_scale
                        )
                        @ k_states_succ.permute(1, 2, 0)
                    )

            if prev_chunk_end_pos - chunk_len * 2 >= 0:
                q_states_inter = q_inter[qbegin:qend]
                q_states_inter_critical = q_inter_critical[qbegin:qend]
                if block_table is not None:
                    block_tables_inter = _get_block(
                        block_table, block_size, 0, prev_chunk_end_pos - chunk_len
                    )
                    k_states_inter = k[block_tables_inter].view(-1, *k.shape[-2:])[
                        : (prev_chunk_end_pos - chunk_len)
                    ]
                    v_states_inter = v[block_tables_inter].view(-1, *v.shape[-2:])[
                        : (prev_chunk_end_pos - chunk_len)
                    ]
                else:
                    k_states_inter = k[: prev_chunk_end_pos - chunk_len]
                    v_states_inter = v[: prev_chunk_end_pos - chunk_len]

                if sparse_attn_enabled:
                    k_states_inter = (
                        k_states_inter.unsqueeze(2)
                        .repeat(1, 1, group_size, 1)
                        .reshape(-1, num_device_k_heads * group_size, head_dim)
                    )
                    v_states_inter = (
                        v_states_inter.unsqueeze(2)
                        .repeat(1, 1, group_size, 1)
                        .reshape(-1, num_device_k_heads * group_size, head_dim)
                    )
                    qk_chunks.append(
                        (
                            q_states_inter_critical.transpose(0, 1)[:, -last_q_size:]
                            * softmax_scale
                        )
                        @ k_states_inter.permute(1, 2, 0)
                    )

            if sparse_attn_enabled:
                reversed_qk = qk_chunks[::-1]
                qk = torch.cat(reversed_qk, dim=-1)

                qk[:, :, -last_q_size:] = torch.where(
                    self.last_q_mask[..., -last_q_size:, -last_q_size:].to(qk.device),
                    qk[:, :, -last_q_size:],
                    -torch.inf,
                )
                qk = F.softmax(qk, dim=-1, dtype=torch.float32)

                vertical = qk.sum(-2, keepdim=True)
                vertical[..., :30] = torch.inf

                # Avoid sorting by using the min/max ints to fill the indexer
                # buffers.
                int32_max = torch.iinfo(torch.int32).max
                int32_min = torch.iinfo(torch.int32).min
                n_heads = qk.size()[0]
                max_slash_topk = torch.max(heads_slash_size).item()
                max_vertical_topk = torch.max(heads_vertical_size).item()
                # store each head's slash topk, vertical topk
                vertical = vertical.reshape((n_heads, -1))
                # prevent out of range when prompt size < max_vertical_topk
                max_vertical_topk = min(vertical.shape[-1], max_vertical_topk)
                vertical_topk_buffer = torch.topk(vertical, max_vertical_topk, -1).indices
                slash_topk_buffer = torch.empty(
                    size=(n_heads, max_slash_topk), dtype=torch.int64, device=qk.device
                )
                for head_i in range(n_heads):
                    # (nqheads=1, lastq, k_len)
                    head_score = qk[head_i : head_i + 1, :, :]
                    slash_scores = _sum_all_diagonal_matrix(head_score)
                    if head_score.size(1) != 1:
                        # drop right up corner
                        slash_scores = slash_scores[..., : -last_q_size + 1]
                    slash_scores[..., -100:] = torch.inf

                    head_slash_size = heads_slash_size[head_i]
                    head_slash_size = min(head_slash_size, vertical.size(-1))
                    slash_topk = torch.topk(slash_scores, head_slash_size, -1).indices
                    # (nheads, max_topk)
                    slash_topk_buffer[head_i, :head_slash_size] = slash_topk

                    # reset heads topk
                    heads_slash_size[head_i] = head_slash_size
                    heads_vertical_size[head_i] = min(
                        heads_vertical_size[head_i], max_vertical_topk
                    )

                # store
                vertical_buffer = torch.full(
                    (n_heads, max_vertical_topk),
                    int32_max,
                    dtype=torch.int64,
                    device=q.device,
                )
                slash_buffer = torch.full(
                    (n_heads, max_slash_topk),
                    int32_min,
                    dtype=torch.int64,
                    device=q.device,
                )
                succ_vertical_buffer = torch.full(
                    (n_heads, max_vertical_topk),
                    int32_max,
                    dtype=torch.int64,
                    device=q.device,
                )
                succ_slash_buffer = torch.full(
                    (n_heads, max_slash_topk),
                    int32_min,
                    dtype=torch.int64,
                    device=q.device,
                )
                inter_vertical_buffer = torch.full(
                    (n_heads, max_vertical_topk),
                    int32_max,
                    dtype=torch.int64,
                    device=q.device,
                )
                inter_slash_buffer = torch.full(
                    (n_heads, max_slash_topk),
                    int32_min,
                    dtype=torch.int64,
                    device=q.device,
                )

                vertical_size_buffer = torch.empty(
                    size=(n_heads,), dtype=torch.int32, device=q.device
                )
                slash_sizes_buffer = torch.empty(
                    size=(n_heads,), dtype=torch.int32, device=q.device
                )
                succ_vertical_size_buffer = torch.empty(
                    size=(n_heads,), dtype=torch.int32, device=q.device
                )
                succ_slash_sizes_buffer = torch.empty(
                    size=(n_heads,), dtype=torch.int32, device=q.device
                )
                inter_vertical_size_buffer = torch.empty(
                    size=(n_heads,), dtype=torch.int32, device=q.device
                )
                inter_slash_sizes_buffer = torch.empty(
                    size=(n_heads,), dtype=torch.int32, device=q.device
                )

                for head_i in range(n_heads):
                    vertical_topk = vertical_topk_buffer[
                        head_i, : heads_vertical_size[head_i]
                    ]
                    # intra
                    intra_vertical_indices = (
                        vertical_topk[vertical_topk >= prev_chunk_end_pos]
                        - prev_chunk_end_pos
                    )
                    if intra_vertical_indices.nelement() == 0:
                        intra_vertical_indices = torch.cat(
                            [
                                intra_vertical_indices,
                                torch.arange(
                                    0,
                                    k_states_intra.size(0),
                                    max(1, k_states_intra.size(0) / 5),
                                    dtype=torch.int32,
                                    device=intra_vertical_indices.device,
                                ),
                            ]
                        )
                    slash_topk = slash_topk_buffer[head_i, : heads_slash_size[head_i]]
                    intra_slash_indices = (qk.size(-1) - 1) - slash_topk[
                        slash_topk >= prev_chunk_end_pos
                    ]
                    # fill buffer
                    v_count = intra_vertical_indices.nelement()
                    s_count = intra_slash_indices.nelement()
                    vertical_size_buffer[head_i] = v_count
                    slash_sizes_buffer[head_i] = s_count
                    vertical_buffer[head_i, :v_count].copy_(intra_vertical_indices)
                    slash_buffer[head_i, :s_count].copy_(intra_slash_indices)
                    # succ
                    if prev_chunk_end_pos - chunk_len >= 0:
                        succ_vertical_indices = vertical_topk[
                            (vertical_topk < prev_chunk_end_pos)
                            & (vertical_topk >= prev_chunk_end_pos - chunk_len)
                        ] - (prev_chunk_end_pos - chunk_len)
                        # TODO: support no vertical
                        if succ_vertical_indices.nelement() == 0:
                            succ_vertical_indices = torch.cat(
                                [
                                    succ_vertical_indices,
                                    torch.arange(
                                        0,
                                        k_states_succ.size(0),
                                        max(1, k_states_succ.size(0) / 5),
                                        dtype=torch.int32,
                                        device=intra_vertical_indices.device,
                                    ),
                                ]
                            )
                        succ_slash_indices = (
                            prev_chunk_end_pos + (qend - qbegin) - 1
                        ) - slash_topk[
                            (
                                (slash_topk >= (prev_chunk_end_pos - chunk_len))
                                & (slash_topk < (prev_chunk_end_pos + (qend - qbegin)))
                            )
                        ]
                        if succ_slash_indices.nelement() == 0:
                            succ_slash_indices = torch.cat(
                                [
                                    succ_slash_indices,
                                    torch.arange(
                                        0,
                                        k_states_succ.size(0),
                                        max(1, k_states_succ.size(0) / 5),
                                        dtype=torch.int32,
                                        device=intra_vertical_indices.device,
                                    ),
                                ]
                            )
                        # fill buffer
                        v_count = succ_vertical_indices.nelement()
                        s_count = succ_slash_indices.nelement()
                        succ_vertical_size_buffer[head_i] = v_count
                        succ_slash_sizes_buffer[head_i] = s_count
                        succ_vertical_buffer[head_i, :v_count].copy_(
                            succ_vertical_indices
                        )
                        succ_slash_buffer[head_i, :s_count].copy_(succ_slash_indices)

                    if prev_chunk_end_pos - 2 * chunk_len >= 0:
                        inter_vertical_indices = vertical_topk[
                            vertical_topk < prev_chunk_end_pos - chunk_len
                        ]
                        if inter_vertical_indices.nelement() == 0:
                            inter_vertical_indices = torch.cat(
                                [
                                    inter_vertical_indices,
                                    torch.arange(
                                        0,
                                        k_states_inter.size(0),
                                        max(1, k_states_inter.size(0) / 5),
                                        dtype=torch.int32,
                                        device=intra_vertical_indices.device,
                                    ),
                                ]
                            )
                        inter_slash_indices = (
                            prev_chunk_end_pos - chunk_len + (qend - qbegin) - 1
                        ) - slash_topk[
                            slash_topk
                            < (prev_chunk_end_pos - chunk_len + (qend - qbegin))
                        ]
                        if inter_slash_indices.nelement() == 0:
                            inter_slash_indices = torch.cat(
                                [
                                    inter_slash_indices,
                                    torch.arange(
                                        0,
                                        k_states_inter.size(0),
                                        max(1, k_states_inter.size(0) / 5),
                                        dtype=torch.int32,
                                        device=intra_vertical_indices.device,
                                    ),
                                ]
                            )
                        # fill buffer
                        v_count = inter_vertical_indices.nelement()
                        s_count = inter_slash_indices.nelement()
                        inter_vertical_size_buffer[head_i] = v_count
                        inter_slash_sizes_buffer[head_i] = s_count
                        inter_vertical_buffer[head_i, :v_count].copy_(
                            inter_vertical_indices
                        )
                        inter_slash_buffer[head_i, :s_count].copy_(inter_slash_indices)
            else:
                intra_vertical_indices, intra_slash_indices = None, None
                succ_vertical_indices, succ_slash_indices = None, None
                inter_vertical_indices, inter_slash_indices = None, None

            if sparse_attn_enabled:
                flash_result = self._do_flash_attn(
                    q_states_intra,
                    k_states_intra,
                    v_states_intra,
                    softmax_scale=softmax_scale,
                    causal=True,
                    stage="intra",
                    vertical_indices=vertical_buffer,
                    slash_indices=slash_buffer,
                    vertical_indices_count=vertical_size_buffer,
                    slash_indices_count=slash_sizes_buffer,
                    mergehead_softmax_scale=softmax_scale,
                    sparse_attn_enabled=sparse_attn_enabled,
                )
            else:
                flash_result = self._do_flash_attn(
                    q_states_intra,
                    k_states_intra,
                    v_states_intra,
                    softmax_scale=softmax_scale,
                    causal=True,
                    stage="intra",
                    vertical_indices=intra_vertical_indices,
                    slash_indices=intra_slash_indices,
                    sparse_attn_enabled=sparse_attn_enabled,
                )
            flash_per_chunk.append(flash_result)

            if prev_chunk_end_pos - chunk_len >= 0:
                if sparse_attn_enabled:
                    flash_result = self._do_flash_attn(
                        q_states_succ,
                        k_states_succ,
                        v_states_succ,
                        softmax_scale=softmax_scale,
                        causal=False,
                        stage="succ",
                        vertical_indices=succ_vertical_buffer,
                        slash_indices=succ_slash_buffer,
                        vertical_indices_count=succ_vertical_size_buffer,
                        slash_indices_count=succ_slash_sizes_buffer,
                        mergehead_softmax_scale=softmax_scale,
                        sparse_attn_enabled=sparse_attn_enabled,
                    )
                else:
                    flash_result = self._do_flash_attn(
                        q_states_succ,
                        k_states_succ,
                        v_states_succ,
                        softmax_scale=softmax_scale,
                        causal=False,
                        stage="succ",
                        vertical_indices=succ_vertical_indices,
                        slash_indices=succ_slash_indices,
                        sparse_attn_enabled=sparse_attn_enabled,
                    )
                flash_per_chunk.append(flash_result)

            if prev_chunk_end_pos - chunk_len * 2 >= 0:
                if sparse_attn_enabled:
                    flash_result = self._do_flash_attn(
                        q_states_inter,
                        k_states_inter,
                        v_states_inter,
                        softmax_scale=softmax_scale,
                        causal=False,
                        stage="inter",
                        vertical_indices=inter_vertical_buffer,
                        slash_indices=inter_slash_buffer,
                        vertical_indices_count=inter_vertical_size_buffer,
                        slash_indices_count=inter_slash_sizes_buffer,
                        mergehead_softmax_scale=softmax_scale,
                        sparse_attn_enabled=sparse_attn_enabled,
                    )
                else:
                    flash_result = self._do_flash_attn(
                        q_states_inter,
                        k_states_inter,
                        v_states_inter,
                        softmax_scale=softmax_scale,
                        causal=False,
                        stage="inter",
                        vertical_indices=inter_vertical_indices,
                        slash_indices=inter_slash_indices,
                        sparse_attn_enabled=sparse_attn_enabled,
                    )
                flash_per_chunk.append(flash_result)

            flash_results.append(flash_per_chunk)
            begin = end

        attn_output = self._merge_attn_outputs(flash_results)
        del flash_results
        return attn_output
    def _do_flash_attn(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        softmax_scale: float,
        causal: bool = True,
        max_seqlen_k: Optional[int] = None,
        stage: str = "intra",
        vertical_indices: Optional[torch.Tensor] = None,
        slash_indices: Optional[torch.Tensor] = None,
        vertical_indices_count: Optional[torch.Tensor] = None,
        slash_indices_count: Optional[torch.Tensor] = None,
        mergehead_softmax_scale: Optional[float] = None,
        sparse_attn_enabled: Optional[bool] = False,
    ):
        if max_seqlen_k is None:
            max_seqlen_k = key_states.shape[0]

        q_len = query_states.shape[0]
        q_heads = query_states.shape[1]
        h_dim = query_states.shape[-1]

        if sparse_attn_enabled:
            assert slash_indices is not None
            if stage == "intra":
                assert causal
            else:
                assert not causal

            query_states = query_states.unsqueeze(0).transpose(1, 2)
            key_states = key_states.unsqueeze(0).transpose(1, 2)
            value_states = value_states.unsqueeze(0).transpose(1, 2)

            q = query_states
            k = key_states
            v = value_states

            if vertical_indices_count is not None and slash_indices_count is not None:
                assert mergehead_softmax_scale is not None

                res, s_lse = _vertical_slash_sparse_attention(
                    q,
                    k,
                    v,
                    vertical_indices,
                    slash_indices,
                    mergehead_softmax_scale,
                    causal=causal,
                    stage=stage,
                    vertical_indices_count=vertical_indices_count,
                    slash_indices_count=slash_indices_count,
                )
                res = res.view(q_heads, q_len, h_dim).transpose(0, 1)  # (qlen,nhead,h_dim)
                s_lse = (
                    s_lse.view(q_heads, q_len, 1).squeeze(-1).unsqueeze(0).float()
                )  # (1, nhead,qlen)
            else:
                res, s_lse = _vertical_slash_sparse_attention(
                    q,
                    k,
                    v,
                    vertical_indices,
                    slash_indices,
                    softmax_scale,
                    causal=causal,
                    stage=stage,
                )
                res = res.view(q_len, q_heads, h_dim)
                s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
            return res, s_lse

        output, softmax_lse, *rest = flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            softmax_scale=softmax_scale,
            cu_seqlens_q=torch.tensor(
                [0, query_states.shape[0]],
                dtype=torch.int32,
                device=query_states.device,
            ),
            max_seqlen_q=query_states.shape[0],
            cu_seqlens_k=torch.tensor(
                [0, max_seqlen_k], dtype=torch.int32, device=query_states.device
            ),
            max_seqlen_k=max_seqlen_k,
            causal=causal,
            return_softmax_lse=True,
        )
        softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
        return output, softmax_lse
    def _merge_attn_outputs(
        self,
        flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]],
        return_lse: Optional[bool] = False,
    ) -> torch.Tensor:
        attn_outputs_all = []
        logits_all = []

        for flash_per_chunk in flash_results:
            if len(flash_per_chunk) == 1:
                attn_outputs_all.append(flash_per_chunk[0][0])
                if return_lse:
                    logits_all.append(flash_per_chunk[0][1])
                continue

            attn_outputs = torch.stack(
                [flash_attn_output[0] for flash_attn_output in flash_per_chunk]
            )
            logits = torch.stack(
                [flash_attn_output[1] for flash_attn_output in flash_per_chunk]
            )
            logits = logits.to(torch.float32)

            if return_lse:
                max_val = torch.max(logits, dim=0).values
                diff = torch.abs(logits[0] - logits[1])
                log_sum_exp = max_val + torch.log1p(torch.exp(-diff))
                logits_all.append(log_sum_exp)

            max_logits = torch.max(logits, dim=0).values
            stable_logits = logits - max_logits.unsqueeze(0)
            lse_s = torch.exp(stable_logits).detach()
            lse_sum = torch.sum(lse_s, dim=0)
            lse_s /= lse_sum
            attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1)
            attn_outputs_all.append(attn_outputs.sum(dim=0))

        if return_lse:
            return (torch.cat(attn_outputs_all, dim=0), torch.cat(logits_all, dim=-1))
        else:
            return torch.cat(attn_outputs_all, dim=0)
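    # _merge_attn_outputs combines per-chunk partial results with a standard
    # log-sum-exp merge: each partial output is weighted by
    # exp(lse_i - max_lse), normalized over the parts. Self-contained sketch of
    # the same rule with toy shapes (2 parts, 1 head, 3 queries, head_dim 4);
    # illustrative only, independent of this backend:
    #
    #   outs = torch.randn(2, 1, 3, 4)
    #   lses = torch.randn(2, 1, 3)
    #   w = torch.exp(lses - lses.max(dim=0).values.unsqueeze(0))
    #   w = w / w.sum(dim=0)
    #   merged = (outs * w.unsqueeze(-1)).sum(dim=0)   # shape (1, 3, 4)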
    def _dual_chunk_flash_attn_decoding(
        self,
        query: torch.Tensor,
        query_succ: torch.Tensor,
        query_inter: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        causal: bool,
        chunk_size: int,
        local_size: int,
        original_max_position_embeddings: int,
        decode_meta: DualChunkFlashAttentionMetadata,
    ):
        if not causal:
            raise ValueError("Dual Chunk Attention does not support causal=False")

        block_size = value_cache.shape[1]
        chunk_len = chunk_size - local_size
        if chunk_len % block_size != 0:
            raise ValueError("chunk_len must be divisible by block_size.")
        if original_max_position_embeddings > 0:
            assert decode_meta.scaling_factor is not None
            scaling_factor = decode_meta.scaling_factor
            query = (query * scaling_factor.view(-1, 1, 1, 1)).to(
                query.dtype
            )  # possible for numerical issue, need to fused in the kernel
            query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to(query.dtype)
            query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to(
                query.dtype
            )
        outputs_list = []
        softmax_lses_list = []

        # intra-attention
        intra_output, intra_softmax_lse = (
            self._dual_chunk_flash_attn_decoding_with_exp_sums(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables_intra,
                decode_meta.seq_lens_intra,
                softmax_scale,
                causal=False,
            )
        )
        outputs_list.append(intra_output)
        softmax_lses_list.append(intra_softmax_lse)

        # succ-attention
        if decode_meta.max_seq_len_succ:
            succ_output, succ_softmax_lse = (
                self._dual_chunk_flash_attn_decoding_with_exp_sums(
                    query_succ,
                    key_cache,
                    value_cache,
                    decode_meta.block_tables_succ,
                    decode_meta.seq_lens_succ,
                    softmax_scale,
                    causal=False,
                )
            )
            outputs_list.append(succ_output)
            softmax_lses_list.append(succ_softmax_lse)

        # inter-attention
        if decode_meta.max_seq_len_inter:
            inter_output, inter_softmax_lse = (
                self._dual_chunk_flash_attn_decoding_with_exp_sums(
                    query_inter,
                    key_cache,
                    value_cache,
                    block_table[:, : decode_meta.max_seq_len_inter],
                    decode_meta.seq_lens_inter,
                    softmax_scale,
                    causal=False,
                )
            )
            outputs_list.append(inter_output)
            softmax_lses_list.append(inter_softmax_lse)
        outputs = torch.stack(outputs_list, dim=0)
        del outputs_list
        softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32)
        del softmax_lses_list

        max_logits = torch.max(softmax_lses, dim=0).values
        stable_logits = softmax_lses - max_logits.unsqueeze(0)
        lse_s = torch.exp(stable_logits).detach()
        lse_sum = torch.sum(lse_s, dim=0)
        lse_s /= lse_sum
        outputs *= lse_s.unsqueeze(-1).transpose(2, 3)
        return outputs.sum(0)
    def _dual_chunk_flash_attn_decoding_with_exp_sums(
        self,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        causal: bool,
    ):
        out, softmax_lse, *rest_expand = flash_attn_with_kvcache(
            q=query,
            k_cache=key_cache,
            v_cache=value_cache,
            page_table=block_table,
            cache_seqlens=cache_seqlens,
            softmax_scale=softmax_scale,
            causal=causal,
            return_softmax_lse=True,
        )
        mask = cache_seqlens == 0
        out[mask] = 0
        softmax_lse[mask] = -float("inf")
        return out, softmax_lse
def _vertical_slash_sparse_attention(
    query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    key: torch.Tensor,  # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
    value: torch.Tensor,  # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
    v_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    s_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    softmax_scale: float,
    causal: bool = True,
    stage: str = "intra",
    block_size_M: int = 64,
    block_size_N: int = 64,
    vertical_indices_count: torch.Tensor = None,  # [N_HEADS,]
    slash_indices_count: torch.Tensor = None,
):
    if stage == "intra":
        assert causal
    else:
        assert not causal

    batch_size, num_heads, context_size, head_dim = query.shape
    _, _, kv_seq_len, _ = key.shape

    if head_dim not in [16, 32, 64, 128, 256, 512]:
        target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
        query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
        key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
        value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])

    v_idx = (
        v_idx.to(torch.int32)
        .reshape((batch_size, num_heads, -1))
        .sort(dim=-1, descending=False)[0]
    )
    s_idx = (
        s_idx.to(torch.int32)
        .reshape((batch_size, num_heads, -1))
        .sort(dim=-1, descending=True)[0]
    )
    q_seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
    kv_seqlens = torch.tensor([kv_seq_len], dtype=torch.int32, device=query.device)

    if vertical_indices_count is not None and slash_indices_count is not None:
        (
            block_count,
            block_offset,
            column_count,
            column_index,
        ) = convert_vertical_slash_indexes_mergehead(
            q_seqlens,
            kv_seqlens,
            v_idx,
            s_idx,
            vertical_indices_count,
            slash_indices_count,
            context_size,
            block_size_M,
            block_size_N,
            causal,
        )
    else:
        (
            block_count,
            block_offset,
            column_count,
            column_index,
        ) = convert_vertical_slash_indexes(
            q_seqlens,
            kv_seqlens,
            v_idx,
            s_idx,
            context_size,
            block_size_M,
            block_size_N,
            causal,
        )

    q = query.transpose(1, 2).contiguous()
    k = key.transpose(1, 2).contiguous()
    v = value.transpose(1, 2).contiguous()
    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        causal=causal,
        softmax_scale=softmax_scale,
        return_softmax_lse=True,
    )
    out = out.transpose(1, 2).contiguous()
    softmax_lse = lse.reshape(*lse.shape, 1)
    return (out[..., :context_size, :head_dim], softmax_lse[..., :context_size, :])
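As an aside on the padding branch above: head dims outside the supported set are zero-padded up to the next power of two before the sparse kernel runs, and the padding is sliced off again in the return statement. For an assumed head_dim of 96 (a hypothetical value, purely for illustration):

import math
head_dim = 96  # hypothetical
target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim  # 128 - 96 = 32 zero columns appended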
def _sum_all_diagonal_matrix(mat: torch.tensor):
    h, n, m = mat.shape
    # Zero matrix used for padding
    zero_mat = torch.zeros((h, n, n), device=mat.device)
    # Pad the matrix on the left and right
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)
    # Change the strides so that each column of the view lines up with one diagonal
    mat_strided = mat_padded.as_strided(
        (1, n, n + m), (n * (2 * n + m), 2 * n + m + 1, 1)
    )
    # Sum the columns of the strided view
    sum_diags = torch.sum(mat_strided, 1)
    return sum_diags[:, 1:]  # drop the left-bottom corner
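The as_strided trick makes each column of the padded view correspond to one diagonal of the input, so a single column sum yields all diagonal sums at once (for the first matrix in the batch only, since the view fixes the leading dimension at 1). A quick sanity check, assuming the helper above is in scope:

import torch

mat = torch.arange(9, dtype=torch.float32).reshape(1, 3, 3)  # [[0,1,2],[3,4,5],[6,7,8]]
print(_sum_all_diagonal_matrix(mat))
# tensor([[ 6., 10., 12.,  6.,  2.]])
# i.e. diagonal sums ordered from bottom-left to top-right:
# [6], [3+7], [0+4+8], [1+5], [2]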
def _get_block(block_table: torch.Tensor, block_size: int, begin: int, end: int):
    begin_block = begin // block_size
    end_block = (end - 1) // block_size + 1
    return block_table[begin_block:end_block]
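_get_block simply maps a token range onto the covering slice of the paging table. For example, with an assumed block_size of 16:

import torch

block_table = torch.arange(8)  # toy page table with 8 blocks
blocks = _get_block(block_table, block_size=16, begin=20, end=49)
# begin // 16 = 1 and (49 - 1) // 16 + 1 = 4, so blocks 1..3 are returned
print(blocks)  # tensor([1, 2, 3])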
python/sglang/srt/layers/rotary_embedding.py  View file @ b7cd7430
...
...
@@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding):
)
class DualChunkRotaryEmbedding(CustomOp):
    """Rotary positional embedding for Dual Chunk Attention."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        chunk_size: int,
        local_size: int,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.chunk_size = chunk_size
        self.local_size = local_size
        self.dtype = dtype
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
            self._compute_cos_sin_cache()
        )
        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
        self.register_buffer(
            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
        )
        self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (
            base
            ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        chunk_len = self.chunk_size - self.local_size
        q_t = torch.arange(chunk_len, dtype=torch.float)
        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
            max=self.chunk_size
        )
        k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
        # count from chunk_len, no clamp(self.chunk_size) restriction
        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
        # count from self.chunk_size for q_inter's rope
        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size

        q_freqs = torch.outer(q_t, inv_freq)
        qc_freqs = torch.outer(qc_t, inv_freq)
        k_freqs = torch.outer(k_t, inv_freq)
        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
        q_inter_freqs = torch.outer(q_inter_t, inv_freq)

        q_cos = q_freqs.cos()
        q_sin = q_freqs.sin()
        qc_cos = qc_freqs.cos()
        qc_sin = qc_freqs.sin()
        k_cos = k_freqs.cos()
        k_sin = k_freqs.sin()
        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
        q_inter_cos = q_inter_freqs.cos()
        q_inter_sin = q_inter_freqs.sin()

        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype, device=self.device)
        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype, device=self.device)
        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype, device=self.device)
        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        key_rot = key[..., : self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim :]
            key_pass = key[..., self.rotary_dim :]
        else:
            query_pass = None
            key_pass = None

        positions_with_offsets = (
            torch.add(positions, offsets) if offsets is not None else positions
        )
        key = self._apply_rotary_embedding(
            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
        )
        chunk_len = self.chunk_size - self.local_size
        query = self._apply_rotary_embedding(
            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        query_succ = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        query_inter = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
            query_rot,
            query_pass,
        )
        query_succ_critical = self._apply_rotary_embedding(
            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        query_inter_critical = self._apply_rotary_embedding(
            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        # merge query into one tensor to simplify the interfaces
        query = torch.cat(
            (
                query,
                query_succ,
                query_inter,
                query_succ_critical,
                query_inter_critical,
            ),
            dim=-1,
        )
        return query, key

    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin

        if self.rotary_dim < self.head_size:
            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
        else:
            hidden = hidden_rot
        return hidden.flatten(-2).squeeze(0)

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
        return s
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
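Since forward() packs the five rotated query variants into one tensor, concatenated along the last dimension in the fixed order query, query_succ, query_inter, query_succ_critical, query_inter_critical, a consumer can unpack them with a single chunk. A sketch, where merged_query stands for the tensor returned by forward() (the name is illustrative):

q, q_succ, q_inter, q_succ_critical, q_inter_critical = torch.chunk(merged_query, 5, dim=-1)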
...
...
@@ -1184,6 +1380,7 @@ def get_rope(
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
    if dtype is None:
        dtype = torch.get_default_dtype()
...
...
@@ -1195,6 +1392,17 @@ def get_rope(
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (
...
...
@@ -1204,12 +1412,28 @@ def get_rope(
        base,
        is_neox_style,
        rope_scaling_args,
        dual_chunk_attention_args,
        dtype,
    )
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]
    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            dtype,
            **extra_kwargs,
        )
    elif rope_scaling is None:
        rotary_emb = RotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        )
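As a usage sketch (values and model setup are illustrative, not taken from this diff), passing a dual_chunk_attention_config into get_rope routes construction to DualChunkRotaryEmbedding, with only "chunk_size" and "local_size" forwarded as kwargs; note the embedding places its cos/sin caches on the current CUDA device, so a GPU is required:

from sglang.srt.layers.rotary_embedding import DualChunkRotaryEmbedding, get_rope

rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=65536,          # illustrative
    base=1000000,                # illustrative
    is_neox_style=True,
    dual_chunk_attention_config={"chunk_size": 16384, "local_size": 2048},
)
assert isinstance(rope, DualChunkRotaryEmbedding)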
...
...
python/sglang/srt/managers/schedule_batch.py  View file @ b7cd7430
...
...
@@ -846,6 +846,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: torch.Tensor = None  # shape: [b], int32

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
...
...
@@ -1131,6 +1133,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]
...
...
@@ -1147,6 +1150,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
...
...
@@ -1260,6 +1266,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.orig_seq_lens = orig_seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
...
...
@@ -1507,6 +1514,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
...
...
@@ -1561,9 +1569,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
            self.orig_seq_lens = self.orig_seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
            self.orig_seq_lens.add_(1)
        self.seq_lens_sum += bs

        # free memory
...
...
@@ -1627,6 +1637,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
...
...
@@ -1659,6 +1670,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
...
...
@@ -1733,6 +1745,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            orig_seq_lens=self.orig_seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
...
...
@@ -1900,6 +1913,9 @@ class ModelWorkerBatch:
    # Sampling info
    sampling_info: SamplingBatchInfo

    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: Optional[torch.Tensor] = None

    # The input Embeds
    input_embeds: Optional[torch.Tensor] = None
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py  View file @ b7cd7430
...
...
@@ -589,6 +589,7 @@ class CudaGraphRunner:
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            next_token_logits_buffer=next_token_logits_buffer,
            orig_seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
...
...
python/sglang/srt/model_executor/forward_batch_info.py  View file @ b7cd7430
...
...
@@ -180,6 +180,9 @@ class ForwardBatch:
    # The sum of all sequence lengths
    seq_lens_sum: int

    # The original sequence lengths before chunking. Qwen-1M related.
    orig_seq_lens: Optional[torch.Tensor] = None

    # Optional seq_lens on cpu
    seq_lens_cpu: Optional[torch.Tensor] = None
...
...
@@ -321,6 +324,7 @@ class ForwardBatch:
            encoder_out_cache_loc=batch.encoder_out_cache_loc,
            seq_lens_sum=batch.seq_lens_sum,
            seq_lens_cpu=batch.seq_lens_cpu,
            orig_seq_lens=batch.orig_seq_lens,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
            token_ids_logprobs=batch.token_ids_logprobs,
...
...
python/sglang/srt/model_executor/model_runner.py  View file @ b7cd7430
...
...
@@ -1467,6 +1467,12 @@ class ModelRunner:
            logger.info(f"Intel AMX attention backend is enabled.")
            return IntelAMXAttnBackend(self)
        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                DualChunkFlashAttentionBackend,
            )

            return DualChunkFlashAttentionBackend(self)
        else:
            raise ValueError(f"Invalid attention backend: {backend_str}")
...
...
python/sglang/srt/models/qwen2.py  View file @ b7cd7430
...
...
@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 32768,
        quant_config: Optional[QuantizationConfig] = None,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
...
...
@@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module):
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
...
...
@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
        head_dim = getattr(config, "head_dim", None)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
...
...
@@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            dual_chunk_attention_config=dual_chunk_attention_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = Qwen2MLP(
...
...
python/sglang/srt/models/qwen2_moe.py  View file @ b7cd7430
...
...
@@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module):
        max_position_embeddings: int = 8192,
        qkv_bias: int = True,
        quant_config: Optional[QuantizationConfig] = None,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
...
...
@@ -267,6 +268,7 @@ class Qwen2MoeAttention(nn.Module):
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
...
...
@@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        qkv_bias = getattr(config, "qkv_bias", True)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
...
...
@@ -317,6 +322,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            dual_chunk_attention_config=dual_chunk_attention_config,
            qkv_bias=qkv_bias,
            prefix=add_prefix("self_attn", prefix),
        )
...
...
python/sglang/srt/models/qwen3_moe.py  View file @ b7cd7430
...
...
@@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module):
        attention_bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
...
...
@@ -353,6 +354,7 @@ class Qwen3MoeAttention(nn.Module):
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
...
...
@@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
        )
        rms_norm_eps = config.rms_norm_eps
        attention_bias = config.attention_bias
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
...
...
@@ -471,6 +476,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
            attention_bias=attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            dual_chunk_attention_config=dual_chunk_attention_config,
            alt_stream=alt_stream,
        )
...
...
python/sglang/srt/server_args.py  View file @ b7cd7430
...
...
@@ -502,6 +502,20 @@ class ServerArgs:
            # use bf16 for mxfp4 triton kernels
            self.dtype = "bfloat16"

        if self.attention_backend == "dual_chunk_flash_attn":
            logger.warning(
                "Mixed chunk is disabled because the dual chunk flash attention backend is used"
            )
            logger.warning(
                "Radix cache is disabled because the dual chunk flash attention backend is used"
            )
            logger.warning(
                "Cuda graph is disabled because the dual chunk flash attention backend is used"
            )
            self.enable_mixed_chunk = False
            self.disable_cuda_graph = True
            self.disable_radix_cache = True

        # Set page size
        if self.page_size is None:
            self.page_size = 1
...
...
@@ -1337,6 +1351,7 @@ class ServerArgs:
"triton"
,
"trtllm_mla"
,
"trtllm_mha"
,
"dual_chunk_flash_attn"
,
],
default
=
ServerArgs
.
attention_backend
,
help
=
"Choose the kernels for attention layers."
,
...
...
python/sglang/srt/two_batch_overlap.py  View file @ b7cd7430
...
...
@@ -661,6 +661,7 @@ class TboForwardBatchPreparer:
"padded_static_len"
,
"mrope_positions"
,
# only used by qwen2-vl, thus not care
"split_index"
,
# for split prefill
"orig_seq_lens"
,
# only used by qwen-1m, thus not care
]:
output_dict
[
key
]
=
getattr
(
batch
,
key
)
if
not
batch
.
forward_mode
.
is_target_verify
():
...
...