Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b7cd7430
Unverified
Commit
b7cd7430
authored
Aug 07, 2025
by
PGFLMG
Committed by
GitHub
Aug 06, 2025
Browse files
[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)
parent
a69b6370
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
2121 additions
and
4 deletions
+2121
-4
examples/runtime/engine/offline_batch_inference_qwen_1m.py
examples/runtime/engine/offline_batch_inference_qwen_1m.py
+74
-0
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+28
-0
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
.../sglang/srt/disaggregation/decode_schedule_batch_mixin.py
+3
-0
python/sglang/srt/hf_transformers_utils.py
python/sglang/srt/hf_transformers_utils.py
+30
-3
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
...srt/layers/attention/dual_chunk_flashattention_backend.py
+1700
-0
python/sglang/srt/layers/rotary_embedding.py
python/sglang/srt/layers/rotary_embedding.py
+225
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+16
-0
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+1
-0
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+4
-0
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+6
-0
python/sglang/srt/models/qwen2.py
python/sglang/srt/models/qwen2.py
+6
-0
python/sglang/srt/models/qwen2_moe.py
python/sglang/srt/models/qwen2_moe.py
+6
-0
python/sglang/srt/models/qwen3_moe.py
python/sglang/srt/models/qwen3_moe.py
+6
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+15
-0
python/sglang/srt/two_batch_overlap.py
python/sglang/srt/two_batch_overlap.py
+1
-0
No files found.
examples/runtime/engine/offline_batch_inference_qwen_1m.py
0 → 100644
View file @
b7cd7430
"""
Usage:
python3 offline_batch_inference.py
"""
from
urllib.request
import
urlopen
import
sglang
as
sgl
def
load_prompt
()
->
str
:
# Test cases with various lengths can be found at:
#
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
with
urlopen
(
"https://qianwen-res.oss-cn-beijing.aliyuncs.com"
"/Qwen2.5-1M/test-data/64k.txt"
,
timeout
=
5
,
)
as
response
:
prompt
=
response
.
read
().
decode
(
"utf-8"
)
return
prompt
# Processing the prompt.
def
process_requests
(
llm
:
sgl
.
Engine
,
prompts
:
list
[
str
])
->
None
:
# Create a sampling params object.
sampling_params
=
{
"temperature"
:
0.7
,
"top_p"
:
0.8
,
"top_k"
:
20
,
"repetition_penalty"
:
1.05
,
"max_new_tokens"
:
256
,
}
# Generate texts from the prompts.
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
for
output
in
outputs
:
prompt_token_ids
=
output
[
"meta_info"
][
"prompt_tokens"
]
generated_text
=
output
[
"text"
]
print
(
f
"Prompt length:
{
prompt_token_ids
}
, "
f
"Generated text:
{
generated_text
!
r
}
"
)
# Create an LLM.
def
initialize_engine
()
->
sgl
.
Engine
:
llm
=
sgl
.
Engine
(
model_path
=
"Qwen/Qwen2.5-7B-Instruct-1M"
,
context_length
=
1048576
,
page_size
=
256
,
attention_backend
=
"dual_chunk_flash_attn"
,
tp_size
=
4
,
disable_radix_cache
=
True
,
enable_mixed_chunk
=
False
,
enable_torch_compile
=
False
,
chunked_prefill_size
=
131072
,
mem_fraction_static
=
0.6
,
log_level
=
"DEBUG"
,
)
return
llm
def
main
():
llm
=
initialize_engine
()
prompt
=
load_prompt
()
process_requests
(
llm
,
[
prompt
])
if
__name__
==
"__main__"
:
main
()
python/sglang/srt/configs/model_config.py
View file @
b7cd7430
...
@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
...
@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
get_context_length
,
get_context_length
,
get_generation_config
,
get_generation_config
,
get_hf_text_config
,
get_hf_text_config
,
get_sparse_attention_config
,
)
)
from
sglang.srt.layers.quantization
import
QUANTIZATION_METHODS
from
sglang.srt.layers.quantization
import
QUANTIZATION_METHODS
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
...
@@ -270,6 +271,9 @@ class ModelConfig:
...
@@ -270,6 +271,9 @@ class ModelConfig:
# Verify quantization
# Verify quantization
self
.
_verify_quantization
()
self
.
_verify_quantization
()
# Verify dual-chunk attention config
self
.
_verify_dual_chunk_attention_config
()
# Cache attributes
# Cache attributes
self
.
hf_eos_token_id
=
self
.
get_hf_eos_token_id
()
self
.
hf_eos_token_id
=
self
.
get_hf_eos_token_id
()
...
@@ -297,6 +301,13 @@ class ModelConfig:
...
@@ -297,6 +301,13 @@ class ModelConfig:
**
kwargs
,
**
kwargs
,
)
)
def
get_total_num_attention_heads
(
self
)
->
int
:
return
self
.
num_attention_heads
def
get_num_attention_heads
(
self
,
tensor_parallel_size
)
->
int
:
total_num_attention_heads
=
self
.
num_attention_heads
return
max
(
1
,
total_num_attention_heads
//
tensor_parallel_size
)
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def
get_total_num_kv_heads
(
self
)
->
int
:
def
get_total_num_kv_heads
(
self
)
->
int
:
"""Returns the total number of KV heads."""
"""Returns the total number of KV heads."""
...
@@ -484,6 +495,23 @@ class ModelConfig:
...
@@ -484,6 +495,23 @@ class ModelConfig:
self
.
quantization
,
self
.
quantization
,
)
)
def
_verify_dual_chunk_attention_config
(
self
)
->
None
:
if
hasattr
(
self
.
hf_config
,
"dual_chunk_attention_config"
):
# Try loading the sparse attention config
sparse_attn_config
=
get_sparse_attention_config
(
self
.
model_path
)
if
not
sparse_attn_config
:
return
self
.
hf_config
.
dual_chunk_attention_config
[
"sparse_attention_config"
]
=
(
sparse_attn_config
)
if
(
"sparse_attention_enabled"
not
in
self
.
hf_config
.
dual_chunk_attention_config
):
self
.
hf_config
.
dual_chunk_attention_config
[
"sparse_attention_enabled"
]
=
True
def
get_hf_eos_token_id
(
self
)
->
Optional
[
Set
[
int
]]:
def
get_hf_eos_token_id
(
self
)
->
Optional
[
Set
[
int
]]:
eos_ids
=
getattr
(
self
.
hf_config
,
"eos_token_id"
,
None
)
eos_ids
=
getattr
(
self
.
hf_config
,
"eos_token_id"
,
None
)
if
eos_ids
is
not
None
:
if
eos_ids
is
not
None
:
...
...
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
View file @
b7cd7430
...
@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
...
@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
req_pool_indices
,
dtype
=
torch
.
int64
,
device
=
self
.
device
req_pool_indices
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
orig_seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
out_cache_loc
=
out_cache_loc
self
.
out_cache_loc
=
out_cache_loc
self
.
seq_lens_sum
=
sum
(
seq_lens
)
self
.
seq_lens_sum
=
sum
(
seq_lens
)
...
...
python/sglang/srt/hf_transformers_utils.py
View file @
b7cd7430
...
@@ -14,10 +14,11 @@
...
@@ -14,10 +14,11 @@
"""Utilities for Huggingface Transformers."""
"""Utilities for Huggingface Transformers."""
import
contextlib
import
contextlib
import
json
import
os
import
os
import
warnings
import
warnings
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
,
Type
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Type
,
Union
import
torch
import
torch
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
...
@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
...
@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
AutoConfig
.
register
(
name
,
cls
)
AutoConfig
.
register
(
name
,
cls
)
def
download_from_hf
(
model_path
:
str
):
def
download_from_hf
(
model_path
:
str
,
allow_patterns
:
Optional
[
Union
[
str
,
list
]]
=
None
,
):
if
os
.
path
.
exists
(
model_path
):
if
os
.
path
.
exists
(
model_path
):
return
model_path
return
model_path
return
snapshot_download
(
model_path
,
allow_patterns
=
[
"*.json"
,
"*.bin"
,
"*.model"
])
if
not
allow_patterns
:
allow_patterns
=
[
"*.json"
,
"*.bin"
,
"*.model"
]
return
snapshot_download
(
model_path
,
allow_patterns
=
allow_patterns
)
def
get_hf_text_config
(
config
:
PretrainedConfig
):
def
get_hf_text_config
(
config
:
PretrainedConfig
):
...
@@ -171,6 +178,26 @@ def get_generation_config(
...
@@ -171,6 +178,26 @@ def get_generation_config(
return
None
return
None
# Qwen-1M related
def
get_sparse_attention_config
(
model
:
str
,
sparse_attention_config_filename
:
str
=
"sparse_attention_config.json"
,
)
->
Dict
[
str
,
Any
]:
is_local
=
os
.
path
.
isdir
(
model
)
if
not
is_local
:
# Download the config files.
model
=
download_from_hf
(
model
,
allow_patterns
=
[
"*.json"
])
config_file
=
os
.
path
.
join
(
model
,
sparse_attention_config_filename
)
if
not
os
.
path
.
exists
(
config_file
):
return
{}
# Load the sparse attention config.
with
open
(
config_file
)
as
f
:
config
=
json
.
load
(
f
)
return
config
# Models don't use the same configuration key for determining the maximum
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# NOTE: The ordering here is important. Some models have two of these and we
...
...
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
0 → 100644
View file @
b7cd7430
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with Dual chunk flash attention and sparse attention.
"""
import
functools
import
logging
import
math
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
sgl_kernel.flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
sgl_kernel.sparse_flash_attn
import
(
convert_vertical_slash_indexes
,
convert_vertical_slash_indexes_mergehead
,
sparse_attn_func
,
)
from
sglang.srt.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.flashattention_backend
import
FlashAttentionMetadata
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
DualChunkFlashAttentionMetadata
:
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
=
None
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_seq_len
:
int
=
None
# (batch_size,). The orig sequence length per sequence.
orig_seq_lens
:
Optional
[
List
[
int
]]
=
None
# orig_seq_lens stored as a tensor.
orig_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Length scaling factor
scaling_factor
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,). Sequence lengths for intra attention.
seq_lens_intra
:
Optional
[
torch
.
Tensor
]
=
None
# Max sequence length for intra attention.
max_seq_len_intra
:
Optional
[
int
]
=
None
# (batch_size, num_blocks). Block table for intra attention.
block_tables_intra
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,). Sequence lengths for succ attention.
seq_lens_succ
:
Optional
[
torch
.
Tensor
]
=
None
# Max sequence length for succ attention.
max_seq_len_succ
:
Optional
[
int
]
=
None
# (batch_size, num_blocks). Block table for succ attention.
block_tables_succ
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,). Sequence lengths for inter attention.
seq_lens_inter
:
Optional
[
torch
.
Tensor
]
=
None
# Max sequence length for inter attention.
max_seq_len_inter
:
Optional
[
int
]
=
None
class
DualChunkFlashAttentionBackend
(
AttentionBackend
):
def
__init__
(
self
,
model_runner
:
"ModelRunner"
,
)
->
None
:
self
.
forward_metadata
:
FlashAttentionMetadata
=
None
self
.
device
=
model_runner
.
device
self
.
max_context_len
=
model_runner
.
model_config
.
context_len
self
.
num_heads
=
model_runner
.
model_config
.
get_num_attention_heads
(
model_runner
.
server_args
.
tp_size
)
self
.
num_kv_heads
=
model_runner
.
model_config
.
get_num_kv_heads
(
model_runner
.
server_args
.
tp_size
)
self
.
head_size
=
model_runner
.
model_config
.
head_dim
self
.
req_to_token
=
model_runner
.
req_to_token_pool
.
req_to_token
self
.
kv_cache_dtype
=
model_runner
.
kv_cache_dtype
self
.
kv_cache_dtype_str
=
model_runner
.
server_args
.
kv_cache_dtype
self
.
page_size
=
model_runner
.
page_size
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
dual_chunk_attention_config
=
getattr
(
model_runner
.
model_config
.
hf_config
,
"dual_chunk_attention_config"
,
None
)
assert
dual_chunk_attention_config
is
not
None
self
.
chunk_size
=
dual_chunk_attention_config
.
get
(
"chunk_size"
,
8192
)
self
.
local_size
=
dual_chunk_attention_config
.
get
(
"local_size"
,
1024
)
self
.
original_max_position_embeddings
=
dual_chunk_attention_config
.
get
(
"original_max_position_embeddings"
,
0
)
self
.
sparse_attention_config
=
dual_chunk_attention_config
.
get
(
"sparse_attention_config"
,
None
)
if
not
self
.
sparse_attention_config
:
logger
.
warning_once
(
"Sparse attention will not be enabled as "
"sparse attention config is not provided."
)
self
.
sparse_attention_enabled
=
dual_chunk_attention_config
.
get
(
"sparse_attention_enabled"
,
self
.
sparse_attention_config
is
not
None
)
self
.
sparse_attention_threshold
=
dual_chunk_attention_config
.
get
(
"sparse_attention_threshold"
,
32768
)
self
.
sparse_attention_last_q
=
dual_chunk_attention_config
.
get
(
"sparse_attention_last_q"
,
64
)
self
.
dual_chunk_attention_config
=
dual_chunk_attention_config
if
self
.
sparse_attention_enabled
:
self
.
arange
=
torch
.
arange
(
self
.
sparse_attention_last_q
,
device
=
"cuda"
)
self
.
last_q_mask
=
(
self
.
arange
[
None
,
None
,
:,
None
]
>=
self
.
arange
[
None
,
None
,
None
,
:]
)
@
functools
.
lru_cache
()
def
get_sparse_attention_config
(
self
,
layer_idx
)
->
List
[
Dict
[
str
,
Any
]]:
layer_sparse_attention_config
=
{
int
(
i
):
j
for
i
,
j
in
self
.
sparse_attention_config
[
layer_idx
].
items
()
}
start_head
=
self
.
num_heads
*
get_tensor_model_parallel_rank
()
end_head
=
start_head
+
self
.
num_heads
return
[
layer_sparse_attention_config
[
i
]
for
i
in
range
(
start_head
,
end_head
)]
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
forward_mode
:
ForwardMode
=
forward_batch
.
forward_mode
assert
forward_mode
.
is_prefill
()
or
forward_mode
.
is_decode
()
batch_size
=
forward_batch
.
batch_size
metadata
=
DualChunkFlashAttentionMetadata
()
metadata
.
seq_lens_tensor
=
forward_batch
.
seq_lens
.
to
(
torch
.
int32
)
metadata
.
seq_lens
=
forward_batch
.
seq_lens
.
tolist
()
metadata
.
max_seq_len
=
forward_batch
.
seq_lens
.
max
().
item
()
metadata
.
orig_seq_lens_tensor
=
forward_batch
.
orig_seq_lens
metadata
.
orig_seq_lens
=
forward_batch
.
orig_seq_lens
.
tolist
()
metadata
.
block_tables
=
forward_batch
.
req_to_token_pool
.
req_to_token
[
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len
]
# Convert the block table to a strided format.
if
self
.
page_size
>
1
:
strided_indices
=
torch
.
arange
(
0
,
metadata
.
block_tables
.
shape
[
1
],
self
.
page_size
,
device
=
self
.
device
)
metadata
.
block_tables
=
(
metadata
.
block_tables
[:,
strided_indices
]
//
self
.
page_size
)
metadata
.
query_start_loc
=
torch
.
zeros
(
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
metadata
.
seq_lens_tensor
.
device
)
if
forward_mode
.
is_prefill
():
metadata
.
query_start_loc
[
1
:]
=
torch
.
cumsum
(
forward_batch
.
extend_seq_lens
.
to
(
torch
.
int32
),
dim
=
0
,
dtype
=
torch
.
int32
)
else
:
metadata
.
query_start_loc
[
1
:]
=
torch
.
cumsum
(
torch
.
arange
(
batch_size
,
dtype
=
metadata
.
query_start_loc
.
dtype
,
device
=
metadata
.
query_start_loc
.
device
,
),
dim
=
0
,
dtype
=
torch
.
int32
,
)
metadata
.
seq_start_loc
=
torch
.
zeros
(
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
metadata
.
seq_lens_tensor
.
device
)
metadata
.
seq_start_loc
[
1
:]
=
torch
.
cumsum
(
metadata
.
seq_lens_tensor
,
dim
=
0
,
dtype
=
torch
.
int32
)
if
self
.
original_max_position_embeddings
>
0
:
if
forward_mode
.
is_prefill
():
metadata
.
scaling_factor
=
(
0.1
*
torch
.
log
(
metadata
.
orig_seq_lens_tensor
/
self
.
original_max_position_embeddings
)
+
1.0
).
clip
(
min
=
1
)
else
:
metadata
.
scaling_factor
=
(
0.1
*
torch
.
log
(
metadata
.
orig_seq_lens_tensor
/
self
.
original_max_position_embeddings
)
+
1.0
).
clip
(
min
=
1
)
if
forward_mode
.
is_decode
():
cache_seq_lens
=
metadata
.
orig_seq_lens_tensor
chunk_len
=
self
.
chunk_size
-
self
.
local_size
chunk_num_curr
=
(
cache_seq_lens
-
1
)
//
chunk_len
seq_lens_intra
=
cache_seq_lens
-
chunk_num_curr
*
chunk_len
max_seq_len_intra
=
seq_lens_intra
.
max
().
item
()
metadata
.
seq_lens_intra
=
seq_lens_intra
metadata
.
max_seq_len_intra
=
max_seq_len_intra
block_tables_intra
=
torch
.
zeros
(
batch_size
,
(
max_seq_len_intra
-
1
)
//
self
.
page_size
+
1
,
dtype
=
metadata
.
block_tables
.
dtype
,
device
=
metadata
.
block_tables
.
device
,
)
for
i
in
range
(
batch_size
):
st
=
chunk_num_curr
[
i
]
*
chunk_len
//
self
.
page_size
ed
=
min
(
st
+
(
max_seq_len_intra
-
1
)
//
self
.
page_size
+
1
,
(
cache_seq_lens
[
i
]
-
1
)
//
self
.
page_size
+
1
,
)
block_tables_intra
[
i
,
:
ed
-
st
]
=
metadata
.
block_tables
[
i
,
st
:
ed
]
metadata
.
block_tables_intra
=
block_tables_intra
metadata
.
seq_lens_succ
=
(
chunk_num_curr
-
(
chunk_num_curr
-
1
).
clip
(
min
=
0
)
)
*
chunk_len
metadata
.
max_seq_len_succ
=
metadata
.
seq_lens_succ
.
max
().
item
()
if
metadata
.
max_seq_len_succ
:
block_tables_succ
=
torch
.
zeros
(
batch_size
,
(
metadata
.
max_seq_len_succ
-
1
)
//
self
.
page_size
+
1
,
dtype
=
metadata
.
block_tables
.
dtype
,
device
=
metadata
.
block_tables
.
device
,
)
for
i
in
range
(
batch_size
):
start
=
(
(
chunk_num_curr
[
i
]
-
1
).
clip
(
min
=
0
)
*
chunk_len
//
self
.
page_size
)
end
=
min
(
start
+
(
metadata
.
max_seq_len_succ
-
1
)
//
self
.
page_size
+
1
,
(
cache_seq_lens
[
i
]
-
1
)
//
self
.
page_size
+
1
,
)
block_tables_succ
[
i
,
:
end
-
start
]
=
metadata
.
block_tables
[
i
,
start
:
end
]
metadata
.
block_tables_succ
=
block_tables_succ
metadata
.
seq_lens_inter
=
(
chunk_num_curr
-
1
).
clip
(
min
=
0
)
*
chunk_len
metadata
.
max_seq_len_inter
=
metadata
.
seq_lens_inter
.
max
().
item
()
self
.
forward_metadata
=
metadata
def
forward_extend
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
layer
:
"RadixAttention"
,
forward_batch
:
ForwardBatch
,
save_kv_cache
=
True
,
):
# Use precomputed metadata across all layers
metadata
=
self
.
forward_metadata
(
query
,
query_succ
,
query_inter
,
query_succ_critical
,
query_inter_critical
,
)
=
torch
.
split
(
q
,
q
.
shape
[
-
1
]
//
5
,
dim
=-
1
)
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_succ
=
query_succ
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_inter
=
query_inter
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_succ_critical
=
query_succ_critical
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_inter_critical
=
query_inter_critical
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
# apply DCA scaling
if
self
.
original_max_position_embeddings
>
0
:
assert
metadata
.
scaling_factor
is
not
None
assert
metadata
.
query_start_loc
is
not
None
assert
metadata
.
orig_seq_lens
is
not
None
current_start
=
0
query_start_loc_cpu
=
metadata
.
query_start_loc
.
cpu
()
for
i
in
range
(
len
(
metadata
.
orig_seq_lens
)):
current_end
=
(
current_start
+
(
query_start_loc_cpu
[
i
+
1
]
-
query_start_loc_cpu
[
i
]).
item
()
)
key
[
current_start
:
current_end
].
mul_
(
metadata
.
scaling_factor
[
i
])
current_start
=
current_end
assert
current_end
<=
self
.
max_context_len
# Do multi-head attention
key_cache
,
value_cache
=
forward_batch
.
token_to_kv_pool
.
get_kv_buffer
(
layer
.
layer_id
)
key_cache
=
key_cache
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_k_head_num
,
layer
.
head_dim
)
value_cache
=
value_cache
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_v_head_num
,
layer
.
head_dim
)
if
key
is
not
None
and
value
is
not
None
:
if
save_kv_cache
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
forward_batch
.
out_cache_loc
,
key
,
value
,
layer
.
k_scale
,
layer
.
v_scale
,
)
if
not
save_kv_cache
:
# profile run
o
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
metadata
.
seq_start_loc
,
cu_seqlens_k
=
metadata
.
seq_start_loc
,
max_seqlen_q
=
metadata
.
max_seq_len
,
max_seqlen_k
=
metadata
.
max_seq_len
,
softmax_scale
=
layer
.
scaling
,
causal
=
True
,
)
else
:
# prefill/chunked-prefill
# get per layer sparse attention config
if
self
.
sparse_attention_enabled
:
self
.
layer_sparse_attention_config
=
self
.
get_sparse_attention_config
(
layer
.
layer_id
)
assert
metadata
.
orig_seq_lens
is
not
None
o
=
self
.
_dual_chunk_flash_attn_prefill
(
q
=
query
,
q_succ
=
query_succ
,
q_inter
=
query_inter
,
q_succ_critical
=
query_succ_critical
,
q_inter_critical
=
query_inter_critical
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
metadata
.
query_start_loc
,
cu_seqlens_k
=
metadata
.
seq_start_loc
,
orig_seq_lens
=
metadata
.
orig_seq_lens
,
scaling_factor
=
metadata
.
scaling_factor
,
softmax_scale
=
layer
.
scaling
,
causal
=
True
,
window_size
=
(
-
1
,
-
1
),
block_table
=
metadata
.
block_tables
,
chunk_size
=
self
.
chunk_size
,
local_size
=
self
.
local_size
,
)
return
o
.
view
(
-
1
,
layer
.
tp_q_head_num
*
layer
.
v_head_dim
)
def
forward_decode
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
layer
:
"RadixAttention"
,
forward_batch
:
ForwardBatch
,
save_kv_cache
=
True
,
)
->
torch
.
Tensor
:
# Use precomputed metadata across all layers
metadata
=
self
.
forward_metadata
(
query
,
query_succ
,
query_inter
,
query_succ_critical
,
query_inter_critical
,
)
=
torch
.
split
(
q
,
q
.
shape
[
-
1
]
//
5
,
dim
=-
1
)
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_succ
=
query_succ
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_inter
=
query_inter
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_succ_critical
=
query_succ_critical
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_inter_critical
=
query_inter_critical
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key_cache
,
value_cache
=
forward_batch
.
token_to_kv_pool
.
get_kv_buffer
(
layer
.
layer_id
)
key_cache
=
key_cache
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_k_head_num
,
layer
.
head_dim
)
value_cache
=
value_cache
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_v_head_num
,
layer
.
head_dim
)
if
key
is
not
None
and
value
is
not
None
:
if
save_kv_cache
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
forward_batch
.
out_cache_loc
,
key
,
value
,
layer
.
k_scale
,
layer
.
v_scale
,
)
# apply DCA scaling
if
self
.
original_max_position_embeddings
>
0
:
assert
metadata
.
scaling_factor
is
not
None
scaling_factor
=
metadata
.
scaling_factor
key
.
mul_
(
scaling_factor
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
))
o
=
self
.
_dual_chunk_flash_attn_decoding
(
query
.
unsqueeze
(
1
),
query_succ
.
unsqueeze
(
1
),
query_inter
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
block_table
=
metadata
.
block_tables
,
cache_seqlens
=
metadata
.
seq_lens_tensor
,
softmax_scale
=
layer
.
scaling
,
causal
=
True
,
chunk_size
=
self
.
chunk_size
,
local_size
=
self
.
local_size
,
original_max_position_embeddings
=
self
.
original_max_position_embeddings
,
decode_meta
=
metadata
,
).
squeeze
(
1
)
return
o
.
view
(
-
1
,
layer
.
tp_q_head_num
*
layer
.
v_head_dim
)
def
init_cuda_graph_state
(
self
,
max_bs
:
int
):
"""Initialize CUDA graph state for the attention backend.
Args:
max_bs (int): Maximum batch size to support in CUDA graphs
This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations.
"""
self
.
decode_metadata
=
{
"seq_lens_tensor"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"orig_seq_lens_tensor"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"scaling_factor"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
float32
,
device
=
self
.
device
),
"block_tables"
:
torch
.
zeros
(
max_bs
,
(
self
.
max_context_len
-
1
)
//
self
.
page_size
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
),
"block_tables_intra"
:
torch
.
zeros
(
max_bs
,
(
self
.
max_context_len
-
1
)
//
self
.
page_size
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
),
"seq_lens_intra"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"block_tables_succ"
:
torch
.
zeros
(
max_bs
,
(
self
.
max_context_len
-
1
)
//
self
.
page_size
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
),
"seq_lens_succ"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"seq_lens_inter"
:
torch
.
zeros
(
max_bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
),
}
def
init_forward_metadata_capture_cuda_graph
(
self
,
bs
:
int
,
num_tokens
:
int
,
req_pool_indices
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
None
],
):
metadata
=
DualChunkFlashAttentionMetadata
()
if
forward_mode
.
is_decode_or_idle
():
if
self
.
original_max_position_embeddings
>
0
:
metadata
.
scaling_factor
=
self
.
decode_metadata
[
"scaling_factor"
][:
bs
]
metadata
.
seq_lens_tensor
=
self
.
decode_metadata
[
"seq_lens_tensor"
][:
bs
]
metadata
.
orig_seq_lens_tensor
=
self
.
decode_metadata
[
"orig_seq_lens_tensor"
][:
bs
]
metadata
.
max_seq_len
=
self
.
max_context_len
metadata
.
block_tables
=
self
.
decode_metadata
[
"block_tables"
][
req_pool_indices
,
:
]
# intra
metadata
.
max_seq_len_intra
=
self
.
max_context_len
metadata
.
seq_lens_intra
=
self
.
decode_metadata
[
"seq_lens_intra"
][:
bs
]
metadata
.
block_tables_intra
=
self
.
decode_metadata
[
"block_tables_intra"
][
:
bs
,
:
]
# succ
metadata
.
seq_lens_succ
=
self
.
decode_metadata
[
"seq_lens_succ"
][:
bs
]
metadata
.
max_seq_len_succ
=
self
.
max_context_len
metadata
.
block_tables_succ
=
self
.
decode_metadata
[
"block_tables_succ"
][
:
bs
,
:
]
metadata
.
seq_lens_inter
=
self
.
decode_metadata
[
"seq_lens_inter"
][:
bs
]
metadata
.
max_seq_len_inter
=
self
.
max_context_len
self
.
decode_metadata
[
bs
]
=
metadata
self
.
forward_metadata
=
metadata
def
init_forward_metadata_replay_cuda_graph
(
self
,
bs
:
int
,
req_pool_indices
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
None
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
out_cache_loc
:
torch
.
Tensor
=
None
,
):
"""Initialize forward metadata for replaying CUDA graph."""
assert
forward_mode
.
is_decode
()
seq_lens
=
seq_lens
[:
bs
]
req_pool_indices
=
req_pool_indices
[:
bs
]
metadata
=
self
.
decode_metadata
[
bs
]
metadata
.
seq_lens_tensor
.
copy_
(
seq_lens
.
to
(
torch
.
int32
))
metadata
.
seq_lens
=
seq_lens
.
tolist
()
metadata
.
max_seq_len
=
seq_lens
.
max
().
item
()
metadata
.
orig_seq_lens_tensor
.
copy_
(
seq_lens
)
metadata
.
orig_seq_lens
=
seq_lens
.
tolist
()
block_tables
=
self
.
req_to_token
[
req_pool_indices
,
:
metadata
.
max_seq_len
]
# Convert the block table to a strided format.
if
self
.
page_size
>
1
:
strided_indices
=
torch
.
arange
(
0
,
block_tables
.
shape
[
1
],
self
.
page_size
,
device
=
self
.
device
)
block_tables
=
block_tables
[:,
strided_indices
]
//
self
.
page_size
metadata
.
block_tables
.
fill_
(
0
)
metadata
.
block_tables
[:
block_tables
.
shape
[
0
],
:
block_tables
.
shape
[
1
]].
copy_
(
block_tables
)
if
self
.
original_max_position_embeddings
>
0
:
scaling_factor
=
(
0.1
*
torch
.
log
(
metadata
.
orig_seq_lens_tensor
/
self
.
original_max_position_embeddings
)
+
1.0
).
clip
(
min
=
1
)
metadata
.
scaling_factor
.
copy_
(
scaling_factor
)
cache_seq_lens
=
metadata
.
orig_seq_lens_tensor
chunk_len
=
self
.
chunk_size
-
self
.
local_size
chunk_num_curr
=
(
cache_seq_lens
-
1
)
//
chunk_len
seq_lens_intra
=
cache_seq_lens
-
chunk_num_curr
*
chunk_len
max_seq_len_intra
=
seq_lens_intra
.
max
().
item
()
metadata
.
seq_lens_intra
.
copy_
(
seq_lens_intra
)
metadata
.
max_seq_len_intra
=
max_seq_len_intra
metadata
.
block_tables_intra
.
fill_
(
0
)
for
i
in
range
(
bs
):
st
=
chunk_num_curr
[
i
]
*
chunk_len
//
self
.
page_size
ed
=
min
(
st
+
(
max_seq_len_intra
-
1
)
//
self
.
page_size
+
1
,
(
cache_seq_lens
[
i
]
-
1
)
//
self
.
page_size
+
1
,
)
metadata
.
block_tables_intra
[
i
,
:
ed
-
st
]
=
metadata
.
block_tables
[
i
,
st
:
ed
]
seq_lens_succ
=
(
chunk_num_curr
-
(
chunk_num_curr
-
1
).
clip
(
min
=
0
))
*
chunk_len
metadata
.
seq_lens_succ
.
copy_
(
seq_lens_succ
)
metadata
.
max_seq_len_succ
=
metadata
.
seq_lens_succ
.
max
().
item
()
if
metadata
.
max_seq_len_succ
:
metadata
.
block_tables_succ
.
fill_
(
0
)
for
i
in
range
(
bs
):
start
=
(
(
chunk_num_curr
[
i
]
-
1
).
clip
(
min
=
0
)
*
chunk_len
//
self
.
page_size
)
end
=
min
(
start
+
(
metadata
.
max_seq_len_succ
-
1
)
//
self
.
page_size
+
1
,
(
cache_seq_lens
[
i
]
-
1
)
//
self
.
page_size
+
1
,
)
metadata
.
block_tables_succ
[
i
,
:
end
-
start
]
=
metadata
.
block_tables
[
i
,
start
:
end
]
seq_lens_inter
=
(
chunk_num_curr
-
1
).
clip
(
min
=
0
)
*
chunk_len
metadata
.
seq_lens_inter
.
copy_
(
seq_lens_inter
)
metadata
.
max_seq_len_inter
=
metadata
.
seq_lens_inter
.
max
().
item
()
self
.
forward_metadata
=
metadata
def
get_cuda_graph_seq_len_fill_value
(
self
):
"""Get the fill value for sequence length in CUDA graph."""
return
1
def
_dual_chunk_flash_attn_prefill
(
self
,
q
,
q_succ
,
q_inter
,
q_succ_critical
,
q_inter_critical
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
orig_seq_lens
:
List
[
int
],
scaling_factor
:
torch
.
Tensor
,
softmax_scale
:
float
,
causal
:
Optional
[
bool
]
=
True
,
window_size
:
Tuple
[
int
,
int
]
=
(
-
1
,
-
1
),
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
int
=
8192
,
local_size
:
int
=
1024
,
):
if
not
causal
:
raise
ValueError
(
"Dual Chunk Attention does not support causal=False"
)
if
window_size
!=
(
-
1
,
-
1
):
raise
ValueError
(
"Dual Chunk Attention does not support window_size"
)
cu_seqlens_q_cpu
=
cu_seqlens_q
.
cpu
().
tolist
()
cu_seqlens_k_cpu
=
cu_seqlens_k
.
cpu
().
tolist
()
all_outputs
=
[]
for
i
in
range
(
0
,
len
(
cu_seqlens_q_cpu
)
-
1
):
qs
=
cu_seqlens_q_cpu
[
i
]
qe
=
cu_seqlens_q_cpu
[
i
:
i
+
2
][
-
1
]
ks
=
cu_seqlens_k_cpu
[
i
]
ke
=
cu_seqlens_k_cpu
[
i
:
i
+
2
][
-
1
]
current_q
=
q
[
qs
:
qe
]
current_q_succ
=
q_succ
[
qs
:
qe
]
current_q_inter
=
q_inter
[
qs
:
qe
]
current_q_succ_critical
=
q_succ_critical
[
qs
:
qe
]
current_q_inter_critical
=
q_inter_critical
[
qs
:
qe
]
if
block_table
is
None
:
current_k
=
k
[
ks
:
ke
]
current_v
=
v
[
ks
:
ke
]
current_block_table
=
None
current_orig_seq_len
=
orig_seq_lens
[
i
]
else
:
current_block_table
=
block_table
[
i
]
current_orig_seq_len
=
orig_seq_lens
[
i
]
current_k
=
k
current_v
=
v
sparse_attn_enabled
=
(
self
.
sparse_attention_enabled
and
current_orig_seq_len
>
self
.
sparse_attention_threshold
)
if
current_q
.
shape
[
0
]
==
0
:
continue
if
current_k
.
shape
[
0
]
==
0
:
all_outputs
.
append
(
torch
.
zeros
(
(
current_q
.
shape
[
0
],
current_q
.
shape
[
1
],
v
.
shape
[
2
]),
device
=
q
.
device
,
dtype
=
q
.
dtype
,
)
)
continue
current_output
=
torch
.
empty_like
(
current_q
)
group_size
=
int
(
current_q
.
size
(
-
2
)
/
current_k
.
size
(
-
2
))
if
sparse_attn_enabled
:
num_device_q_heads
=
current_q
.
size
(
-
2
)
heads_vertical_size
=
torch
.
empty
(
size
=
(
num_device_q_heads
,),
dtype
=
torch
.
int32
)
heads_slash_size
=
torch
.
empty
(
size
=
(
num_device_q_heads
,),
dtype
=
torch
.
int32
)
for
head_id
in
range
(
current_q
.
size
(
-
2
)):
(
ty
,
vertical_size
,
slash_size
,
_
,
)
=
self
.
layer_sparse_attention_config
[
head_id
]
assert
ty
==
"vertical_and_slash"
,
"only support slash mode"
if
vertical_size
==
30
:
vertical_size
+=
100
heads_vertical_size
[
head_id
]
=
vertical_size
heads_slash_size
[
head_id
]
=
slash_size
current_output
=
self
.
_dual_chunk_flash_attn_prefill_func
(
current_q
,
# allheads
current_q_succ
,
current_q_inter
,
current_q_succ_critical
,
current_q_inter_critical
,
current_k
,
current_v
,
current_block_table
,
softmax_scale
,
chunk_size
,
local_size
,
scaling_factor
[
i
].
item
(),
ke
-
ks
,
sparse_attn_enabled
=
sparse_attn_enabled
,
heads_vertical_size
=
heads_vertical_size
,
heads_slash_size
=
heads_slash_size
,
group_size
=
group_size
,
)
else
:
for
head_id
in
range
(
current_q
.
size
(
-
2
)):
# (seq_len, num_heads, head_size)
current_q_head
=
current_q
[:,
head_id
,
:].
unsqueeze
(
1
)
current_q_succ_head
=
current_q_succ
[:,
head_id
,
:].
unsqueeze
(
1
)
current_q_inter_head
=
current_q_inter
[:,
head_id
,
:].
unsqueeze
(
1
)
current_q_succ_head_critical
=
current_q_succ_critical
[
:,
head_id
,
:
].
unsqueeze
(
1
)
current_q_inter_head_critical
=
current_q_inter_critical
[
:,
head_id
,
:
].
unsqueeze
(
1
)
if
block_table
is
not
None
:
current_k_head
=
current_k
[
...,
head_id
//
group_size
,
:
].
unsqueeze
(
2
)
current_v_head
=
current_v
[
...,
head_id
//
group_size
,
:
].
unsqueeze
(
2
)
else
:
current_k_head
=
current_k
[:,
head_id
,
:].
unsqueeze
(
1
)
current_v_head
=
current_v
[:,
head_id
,
:].
unsqueeze
(
1
)
current_out
=
self
.
_dual_chunk_flash_attn_prefill_func
(
current_q_head
,
current_q_succ_head
,
current_q_inter_head
,
current_q_succ_head_critical
,
current_q_inter_head_critical
,
current_k_head
,
current_v_head
,
current_block_table
,
softmax_scale
,
chunk_size
,
local_size
,
scaling_factor
[
i
].
item
(),
ke
-
ks
,
sparse_attn_enabled
=
sparse_attn_enabled
,
)
current_output
[:,
head_id
:
head_id
+
1
,
:]
=
current_out
all_outputs
.
append
(
current_output
)
return
torch
.
cat
(
all_outputs
,
dim
=
0
)
def
_dual_chunk_flash_attn_prefill_func
(
self
,
q
,
q_succ
,
q_inter
,
q_succ_critical
,
q_inter_critical
,
k
,
v
,
block_table
,
softmax_scale
:
float
,
chunk_size
:
int
,
local_size
:
int
,
scaling_factor
:
float
,
k_length
:
int
,
sparse_attn_enabled
:
Optional
[
bool
]
=
True
,
heads_vertical_size
=
None
,
heads_slash_size
=
None
,
group_size
=
None
,
):
flash_results
=
[]
chunk_len
=
chunk_size
-
local_size
if
block_table
is
not
None
:
block_size
=
v
.
shape
[
1
]
if
chunk_len
%
block_size
!=
0
:
raise
ValueError
(
"chunk_len must be divisible by block_size."
)
else
:
block_size
=
1
if
self
.
original_max_position_embeddings
>
0
:
softmax_scale
=
softmax_scale
*
scaling_factor
begin
=
k_length
-
q
.
shape
[
0
]
while
begin
<
k_length
:
flash_per_chunk
=
[]
prev_chunk_end_pos
=
(
begin
//
chunk_len
)
*
chunk_len
next_chunk_end_pos
=
prev_chunk_end_pos
+
chunk_len
end
=
min
(
next_chunk_end_pos
,
k_length
)
qbegin
=
begin
-
(
k_length
-
q
.
shape
[
0
])
qend
=
end
-
(
k_length
-
q
.
shape
[
0
])
qk_chunks
=
[]
q_states_intra
=
q
[
qbegin
:
qend
]
# choose critical token
if
block_table
is
not
None
:
block_tables_intra
=
_get_block
(
block_table
,
block_size
,
prev_chunk_end_pos
,
end
)
k_states_intra
=
k
[
block_tables_intra
].
view
(
-
1
,
*
k
.
shape
[
-
2
:])[
:
(
end
-
prev_chunk_end_pos
)
]
v_states_intra
=
v
[
block_tables_intra
].
view
(
-
1
,
*
v
.
shape
[
-
2
:])[
:
(
end
-
prev_chunk_end_pos
)
]
else
:
block_tables_intra
=
None
k_states_intra
=
k
[
prev_chunk_end_pos
:
end
]
v_states_intra
=
v
[
prev_chunk_end_pos
:
end
]
if
sparse_attn_enabled
:
last_q_size
=
min
(
qend
-
qbegin
,
self
.
sparse_attention_last_q
)
_
,
num_device_k_heads
,
head_dim
=
k_states_intra
.
shape
k_states_intra
=
(
k_states_intra
.
unsqueeze
(
2
)
.
repeat
(
1
,
1
,
group_size
,
1
)
.
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
)
)
v_states_intra
=
(
v_states_intra
.
unsqueeze
(
2
)
.
repeat
(
1
,
1
,
group_size
,
1
)
.
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
)
)
qk_chunks
.
append
(
(
q_states_intra
.
transpose
(
0
,
1
)[:,
-
last_q_size
:]
*
softmax_scale
)
@
k_states_intra
.
permute
(
1
,
2
,
0
)
)
if
prev_chunk_end_pos
-
chunk_len
>=
0
:
q_states_succ
=
q_succ
[
qbegin
:
qend
]
q_states_succ_critical
=
q_succ_critical
[
qbegin
:
qend
]
if
block_table
is
not
None
:
block_tables_succ
=
_get_block
(
block_table
,
block_size
,
prev_chunk_end_pos
-
chunk_len
,
prev_chunk_end_pos
,
)
k_states_succ
=
k
[
block_tables_succ
].
view
(
-
1
,
*
k
.
shape
[
-
2
:])[
:
chunk_len
]
v_states_succ
=
v
[
block_tables_succ
].
view
(
-
1
,
*
v
.
shape
[
-
2
:])[
:
chunk_len
]
else
:
k_states_succ
=
k
[
prev_chunk_end_pos
-
chunk_len
:
prev_chunk_end_pos
]
v_states_succ
=
v
[
prev_chunk_end_pos
-
chunk_len
:
prev_chunk_end_pos
]
if
sparse_attn_enabled
:
k_states_succ
=
(
k_states_succ
.
unsqueeze
(
2
)
.
repeat
(
1
,
1
,
group_size
,
1
)
.
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
)
)
v_states_succ
=
(
v_states_succ
.
unsqueeze
(
2
)
.
repeat
(
1
,
1
,
group_size
,
1
)
.
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
)
)
qk_chunks
.
append
(
(
q_states_succ_critical
.
transpose
(
0
,
1
)[:,
-
last_q_size
:]
*
softmax_scale
)
@
k_states_succ
.
permute
(
1
,
2
,
0
)
)
if
prev_chunk_end_pos
-
chunk_len
*
2
>=
0
:
q_states_inter
=
q_inter
[
qbegin
:
qend
]
q_states_inter_critical
=
q_inter_critical
[
qbegin
:
qend
]
if
block_table
is
not
None
:
block_tables_inter
=
_get_block
(
block_table
,
block_size
,
0
,
prev_chunk_end_pos
-
chunk_len
)
k_states_inter
=
k
[
block_tables_inter
].
view
(
-
1
,
*
k
.
shape
[
-
2
:])[
:
(
prev_chunk_end_pos
-
chunk_len
)
]
v_states_inter
=
v
[
block_tables_inter
].
view
(
-
1
,
*
v
.
shape
[
-
2
:])[
:
(
prev_chunk_end_pos
-
chunk_len
)
]
else
:
k_states_inter
=
k
[:
prev_chunk_end_pos
-
chunk_len
]
v_states_inter
=
v
[:
prev_chunk_end_pos
-
chunk_len
]
if
sparse_attn_enabled
:
k_states_inter
=
(
k_states_inter
.
unsqueeze
(
2
)
.
repeat
(
1
,
1
,
group_size
,
1
)
.
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
)
)
v_states_inter
=
(
v_states_inter
.
unsqueeze
(
2
)
.
repeat
(
1
,
1
,
group_size
,
1
)
.
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
)
)
qk_chunks
.
append
(
(
q_states_inter_critical
.
transpose
(
0
,
1
)[:,
-
last_q_size
:]
*
softmax_scale
)
@
k_states_inter
.
permute
(
1
,
2
,
0
)
)
if
sparse_attn_enabled
:
reversed_qk
=
qk_chunks
[::
-
1
]
qk
=
torch
.
cat
(
reversed_qk
,
dim
=-
1
)
qk
[:,
:,
-
last_q_size
:]
=
torch
.
where
(
self
.
last_q_mask
[...,
-
last_q_size
:,
-
last_q_size
:].
to
(
qk
.
device
),
qk
[:,
:,
-
last_q_size
:],
-
torch
.
inf
,
)
qk
=
F
.
softmax
(
qk
,
dim
=-
1
,
dtype
=
torch
.
float32
)
vertical
=
qk
.
sum
(
-
2
,
keepdim
=
True
)
vertical
[...,
:
30
]
=
torch
.
inf
# Avoid sorting by using the min/max ints to fill the indexer
# buffers.
int32_max
=
torch
.
iinfo
(
torch
.
int32
).
max
int32_min
=
torch
.
iinfo
(
torch
.
int32
).
min
n_heads
=
qk
.
size
()[
0
]
max_slash_topk
=
torch
.
max
(
heads_slash_size
).
item
()
max_vertical_topk
=
torch
.
max
(
heads_vertical_size
).
item
()
# store each head's slash topk, vertical topk
vertical
=
vertical
.
reshape
((
n_heads
,
-
1
))
# prevent out of range when prompt size < max_vertical_topk
max_vertical_topk
=
min
(
vertical
.
shape
[
-
1
],
max_vertical_topk
)
vertical_topk_buffer
=
torch
.
topk
(
vertical
,
max_vertical_topk
,
-
1
).
indices
slash_topk_buffer
=
torch
.
empty
(
size
=
(
n_heads
,
max_slash_topk
),
dtype
=
torch
.
int64
,
device
=
qk
.
device
)
for
head_i
in
range
(
n_heads
):
# (nqheads=1, lastq, k_len)
head_score
=
qk
[
head_i
:
head_i
+
1
,
:,
:]
slash_scores
=
_sum_all_diagonal_matrix
(
head_score
)
if
head_score
.
size
(
1
)
!=
1
:
# drop right up corner
slash_scores
=
slash_scores
[...,
:
-
last_q_size
+
1
]
slash_scores
[...,
-
100
:]
=
torch
.
inf
head_slash_size
=
heads_slash_size
[
head_i
]
head_slash_size
=
min
(
head_slash_size
,
vertical
.
size
(
-
1
))
slash_topk
=
torch
.
topk
(
slash_scores
,
head_slash_size
,
-
1
).
indices
# (nheads, max_topk)
slash_topk_buffer
[
head_i
,
:
head_slash_size
]
=
slash_topk
# reset heads topk
heads_slash_size
[
head_i
]
=
head_slash_size
heads_vertical_size
[
head_i
]
=
min
(
heads_vertical_size
[
head_i
],
max_vertical_topk
)
# store
vertical_buffer
=
torch
.
full
(
(
n_heads
,
max_vertical_topk
),
int32_max
,
dtype
=
torch
.
int64
,
device
=
q
.
device
,
)
slash_buffer
=
torch
.
full
(
(
n_heads
,
max_slash_topk
),
int32_min
,
dtype
=
torch
.
int64
,
device
=
q
.
device
,
)
succ_vertical_buffer
=
torch
.
full
(
(
n_heads
,
max_vertical_topk
),
int32_max
,
dtype
=
torch
.
int64
,
device
=
q
.
device
,
)
succ_slash_buffer
=
torch
.
full
(
(
n_heads
,
max_slash_topk
),
int32_min
,
dtype
=
torch
.
int64
,
device
=
q
.
device
,
)
inter_vertical_buffer
=
torch
.
full
(
(
n_heads
,
max_vertical_topk
),
int32_max
,
dtype
=
torch
.
int64
,
device
=
q
.
device
,
)
inter_slash_buffer
=
torch
.
full
(
(
n_heads
,
max_slash_topk
),
int32_min
,
dtype
=
torch
.
int64
,
device
=
q
.
device
,
)
vertical_size_buffer
=
torch
.
empty
(
size
=
(
n_heads
,),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
slash_sizes_buffer
=
torch
.
empty
(
size
=
(
n_heads
,),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
succ_vertical_size_buffer
=
torch
.
empty
(
size
=
(
n_heads
,),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
succ_slash_sizes_buffer
=
torch
.
empty
(
size
=
(
n_heads
,),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
inter_vertical_size_buffer
=
torch
.
empty
(
size
=
(
n_heads
,),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
inter_slash_sizes_buffer
=
torch
.
empty
(
size
=
(
n_heads
,),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
for
head_i
in
range
(
n_heads
):
vertical_topk
=
vertical_topk_buffer
[
head_i
,
:
heads_vertical_size
[
head_i
]
]
# intra
intra_vertical_indices
=
(
vertical_topk
[
vertical_topk
>=
prev_chunk_end_pos
]
-
prev_chunk_end_pos
)
if
intra_vertical_indices
.
nelement
()
==
0
:
intra_vertical_indices
=
torch
.
cat
(
[
intra_vertical_indices
,
torch
.
arange
(
0
,
k_states_intra
.
size
(
0
),
max
(
1
,
k_states_intra
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
,
),
]
)
slash_topk
=
slash_topk_buffer
[
head_i
,
:
heads_slash_size
[
head_i
]]
intra_slash_indices
=
(
qk
.
size
(
-
1
)
-
1
)
-
slash_topk
[
slash_topk
>=
prev_chunk_end_pos
]
# fill buffer
v_count
=
intra_vertical_indices
.
nelement
()
s_count
=
intra_slash_indices
.
nelement
()
vertical_size_buffer
[
head_i
]
=
v_count
slash_sizes_buffer
[
head_i
]
=
s_count
vertical_buffer
[
head_i
,
:
v_count
].
copy_
(
intra_vertical_indices
)
slash_buffer
[
head_i
,
:
s_count
].
copy_
(
intra_slash_indices
)
# succ
if
prev_chunk_end_pos
-
chunk_len
>=
0
:
succ_vertical_indices
=
vertical_topk
[
(
vertical_topk
<
prev_chunk_end_pos
)
&
(
vertical_topk
>=
prev_chunk_end_pos
-
chunk_len
)
]
-
(
prev_chunk_end_pos
-
chunk_len
)
# TODO: support no vertical
if
succ_vertical_indices
.
nelement
()
==
0
:
succ_vertical_indices
=
torch
.
cat
(
[
succ_vertical_indices
,
torch
.
arange
(
0
,
k_states_succ
.
size
(
0
),
max
(
1
,
k_states_succ
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
,
),
]
)
succ_slash_indices
=
(
prev_chunk_end_pos
+
(
qend
-
qbegin
)
-
1
)
-
slash_topk
[
(
(
slash_topk
>=
(
prev_chunk_end_pos
-
chunk_len
))
&
(
slash_topk
<
(
prev_chunk_end_pos
+
(
qend
-
qbegin
)))
)
]
if
succ_slash_indices
.
nelement
()
==
0
:
succ_slash_indices
=
torch
.
cat
(
[
succ_slash_indices
,
torch
.
arange
(
0
,
k_states_succ
.
size
(
0
),
max
(
1
,
k_states_succ
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
,
),
]
)
# fill buffer
v_count
=
succ_vertical_indices
.
nelement
()
s_count
=
succ_slash_indices
.
nelement
()
succ_vertical_size_buffer
[
head_i
]
=
v_count
succ_slash_sizes_buffer
[
head_i
]
=
s_count
succ_vertical_buffer
[
head_i
,
:
v_count
].
copy_
(
succ_vertical_indices
)
succ_slash_buffer
[
head_i
,
:
s_count
].
copy_
(
succ_slash_indices
)
if
prev_chunk_end_pos
-
2
*
chunk_len
>=
0
:
inter_vertical_indices
=
vertical_topk
[
vertical_topk
<
prev_chunk_end_pos
-
chunk_len
]
if
inter_vertical_indices
.
nelement
()
==
0
:
inter_vertical_indices
=
torch
.
cat
(
[
inter_vertical_indices
,
torch
.
arange
(
0
,
k_states_inter
.
size
(
0
),
max
(
1
,
k_states_inter
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
,
),
]
)
inter_slash_indices
=
(
prev_chunk_end_pos
-
chunk_len
+
(
qend
-
qbegin
)
-
1
)
-
slash_topk
[
slash_topk
<
(
prev_chunk_end_pos
-
chunk_len
+
(
qend
-
qbegin
))
]
if
inter_slash_indices
.
nelement
()
==
0
:
inter_slash_indices
=
torch
.
cat
(
[
inter_slash_indices
,
torch
.
arange
(
0
,
k_states_inter
.
size
(
0
),
max
(
1
,
k_states_inter
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
,
),
]
)
# fill buffer
v_count
=
inter_vertical_indices
.
nelement
()
s_count
=
inter_slash_indices
.
nelement
()
inter_vertical_size_buffer
[
head_i
]
=
v_count
inter_slash_sizes_buffer
[
head_i
]
=
s_count
inter_vertical_buffer
[
head_i
,
:
v_count
].
copy_
(
inter_vertical_indices
)
inter_slash_buffer
[
head_i
,
:
s_count
].
copy_
(
inter_slash_indices
)
else
:
intra_vertical_indices
,
intra_slash_indices
=
None
,
None
succ_vertical_indices
,
succ_slash_indices
=
None
,
None
inter_vertical_indices
,
inter_slash_indices
=
None
,
None
if
sparse_attn_enabled
:
flash_result
=
self
.
_do_flash_attn
(
q_states_intra
,
k_states_intra
,
v_states_intra
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
stage
=
"intra"
,
vertical_indices
=
vertical_buffer
,
slash_indices
=
slash_buffer
,
vertical_indices_count
=
vertical_size_buffer
,
slash_indices_count
=
slash_sizes_buffer
,
mergehead_softmax_scale
=
softmax_scale
,
sparse_attn_enabled
=
sparse_attn_enabled
,
)
else
:
flash_result
=
self
.
_do_flash_attn
(
q_states_intra
,
k_states_intra
,
v_states_intra
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
stage
=
"intra"
,
vertical_indices
=
intra_vertical_indices
,
slash_indices
=
intra_slash_indices
,
sparse_attn_enabled
=
sparse_attn_enabled
,
)
flash_per_chunk
.
append
(
flash_result
)
if
prev_chunk_end_pos
-
chunk_len
>=
0
:
if
sparse_attn_enabled
:
flash_result
=
self
.
_do_flash_attn
(
q_states_succ
,
k_states_succ
,
v_states_succ
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
stage
=
"succ"
,
vertical_indices
=
succ_vertical_buffer
,
slash_indices
=
succ_slash_buffer
,
vertical_indices_count
=
succ_vertical_size_buffer
,
slash_indices_count
=
succ_slash_sizes_buffer
,
mergehead_softmax_scale
=
softmax_scale
,
sparse_attn_enabled
=
sparse_attn_enabled
,
)
else
:
flash_result
=
self
.
_do_flash_attn
(
q_states_succ
,
k_states_succ
,
v_states_succ
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
stage
=
"succ"
,
vertical_indices
=
succ_vertical_indices
,
slash_indices
=
succ_slash_indices
,
sparse_attn_enabled
=
sparse_attn_enabled
,
)
flash_per_chunk
.
append
(
flash_result
)
if
prev_chunk_end_pos
-
chunk_len
*
2
>=
0
:
if
sparse_attn_enabled
:
flash_result
=
self
.
_do_flash_attn
(
q_states_inter
,
k_states_inter
,
v_states_inter
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
stage
=
"inter"
,
vertical_indices
=
inter_vertical_buffer
,
slash_indices
=
inter_slash_buffer
,
vertical_indices_count
=
inter_vertical_size_buffer
,
slash_indices_count
=
inter_slash_sizes_buffer
,
mergehead_softmax_scale
=
softmax_scale
,
sparse_attn_enabled
=
sparse_attn_enabled
,
)
else
:
flash_result
=
self
.
_do_flash_attn
(
q_states_inter
,
k_states_inter
,
v_states_inter
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
stage
=
"inter"
,
vertical_indices
=
inter_vertical_indices
,
slash_indices
=
inter_slash_indices
,
sparse_attn_enabled
=
sparse_attn_enabled
,
)
flash_per_chunk
.
append
(
flash_result
)
flash_results
.
append
(
flash_per_chunk
)
begin
=
end
attn_output
=
self
.
_merge_attn_outputs
(
flash_results
)
del
flash_results
return
attn_output
def
_do_flash_attn
(
self
,
query_states
:
torch
.
Tensor
,
key_states
:
torch
.
Tensor
,
value_states
:
torch
.
Tensor
,
softmax_scale
:
float
,
causal
:
bool
=
True
,
max_seqlen_k
:
Optional
[
int
]
=
None
,
stage
:
str
=
"intra"
,
vertical_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
slash_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
vertical_indices_count
:
Optional
[
torch
.
Tensor
]
=
None
,
slash_indices_count
:
Optional
[
torch
.
Tensor
]
=
None
,
mergehead_softmax_scale
:
Optional
[
float
]
=
None
,
sparse_attn_enabled
:
Optional
[
bool
]
=
False
,
):
if
max_seqlen_k
is
None
:
max_seqlen_k
=
key_states
.
shape
[
0
]
q_len
=
query_states
.
shape
[
0
]
q_heads
=
query_states
.
shape
[
1
]
h_dim
=
query_states
.
shape
[
-
1
]
if
sparse_attn_enabled
:
assert
slash_indices
is
not
None
if
stage
==
"intra"
:
assert
causal
else
:
assert
not
causal
query_states
=
query_states
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
key_states
=
key_states
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
value_states
=
value_states
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
q
=
query_states
k
=
key_states
v
=
value_states
if
vertical_indices_count
is
not
None
and
slash_indices_count
is
not
None
:
assert
mergehead_softmax_scale
is
not
None
res
,
s_lse
=
_vertical_slash_sparse_attention
(
q
,
k
,
v
,
vertical_indices
,
slash_indices
,
mergehead_softmax_scale
,
causal
=
causal
,
stage
=
stage
,
vertical_indices_count
=
vertical_indices_count
,
slash_indices_count
=
slash_indices_count
,
)
res
=
res
.
view
(
q_heads
,
q_len
,
h_dim
).
transpose
(
0
,
1
)
# (qlen,nhead,h_dim)
s_lse
=
(
s_lse
.
view
(
q_heads
,
q_len
,
1
).
squeeze
(
-
1
).
unsqueeze
(
0
).
float
()
)
# (1, nhead,qlen)
else
:
res
,
s_lse
=
_vertical_slash_sparse_attention
(
q
,
k
,
v
,
vertical_indices
,
slash_indices
,
softmax_scale
,
causal
=
causal
,
stage
=
stage
,
)
res
=
res
.
view
(
q_len
,
q_heads
,
h_dim
)
s_lse
=
s_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
2
).
float
()
return
res
,
s_lse
output
,
softmax_lse
,
*
rest
=
flash_attn_varlen_func
(
q
=
query_states
,
k
=
key_states
,
v
=
value_states
,
softmax_scale
=
softmax_scale
,
cu_seqlens_q
=
torch
.
tensor
(
[
0
,
query_states
.
shape
[
0
]],
dtype
=
torch
.
int32
,
device
=
query_states
.
device
,
),
max_seqlen_q
=
query_states
.
shape
[
0
],
cu_seqlens_k
=
torch
.
tensor
(
[
0
,
max_seqlen_k
],
dtype
=
torch
.
int32
,
device
=
query_states
.
device
),
max_seqlen_k
=
max_seqlen_k
,
causal
=
causal
,
return_softmax_lse
=
True
,
)
softmax_lse
=
softmax_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
2
).
float
()
return
output
,
softmax_lse
def
_merge_attn_outputs
(
self
,
flash_results
:
List
[
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]],
return_lse
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
attn_outputs_all
=
[]
logits_all
=
[]
for
flash_per_chunk
in
flash_results
:
if
len
(
flash_per_chunk
)
==
1
:
attn_outputs_all
.
append
(
flash_per_chunk
[
0
][
0
])
if
return_lse
:
logits_all
.
append
(
flash_per_chunk
[
0
][
1
])
continue
attn_outputs
=
torch
.
stack
(
[
flash_attn_output
[
0
]
for
flash_attn_output
in
flash_per_chunk
]
)
logits
=
torch
.
stack
(
[
flash_attn_output
[
1
]
for
flash_attn_output
in
flash_per_chunk
]
)
logits
=
logits
.
to
(
torch
.
float32
)
if
return_lse
:
max_val
=
torch
.
max
(
logits
,
dim
=
0
).
values
diff
=
torch
.
abs
(
logits
[
0
]
-
logits
[
1
])
log_sum_exp
=
max_val
+
torch
.
log1p
(
torch
.
exp
(
-
diff
))
logits_all
.
append
(
log_sum_exp
)
max_logits
=
torch
.
max
(
logits
,
dim
=
0
).
values
stable_logits
=
logits
-
max_logits
.
unsqueeze
(
0
)
lse_s
=
torch
.
exp
(
stable_logits
).
detach
()
lse_sum
=
torch
.
sum
(
lse_s
,
dim
=
0
)
lse_s
/=
lse_sum
attn_outputs
*=
lse_s
.
unsqueeze
(
-
1
).
transpose
(
2
,
3
).
squeeze
(
1
)
attn_outputs_all
.
append
(
attn_outputs
.
sum
(
dim
=
0
))
if
return_lse
:
return
(
torch
.
cat
(
attn_outputs_all
,
dim
=
0
),
torch
.
cat
(
logits_all
,
dim
=-
1
))
else
:
return
torch
.
cat
(
attn_outputs_all
,
dim
=
0
)
    def _dual_chunk_flash_attn_decoding(
        self,
        query: torch.Tensor,
        query_succ: torch.Tensor,
        query_inter: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        causal: bool,
        chunk_size: int,
        local_size: int,
        original_max_position_embeddings: int,
        decode_meta: DualChunkFlashAttentionMetadata,
    ):
        if not causal:
            raise ValueError("Dual Chunk Attention does not support causal=False")
        block_size = value_cache.shape[1]
        chunk_len = chunk_size - local_size
        if chunk_len % block_size != 0:
            raise ValueError("chunk_len must be divisible by block_size.")
        if original_max_position_embeddings > 0:
            assert decode_meta.scaling_factor is not None
            scaling_factor = decode_meta.scaling_factor
            query = (query * scaling_factor.view(-1, 1, 1, 1)).to(
                query.dtype
            )  # possible for numerical issue, need to fused in the kernel
            query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to(query.dtype)
            query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to(
                query.dtype
            )
        outputs_list = []
        softmax_lses_list = []
        # intra-attention
        intra_output, intra_softmax_lse = (
            self._dual_chunk_flash_attn_decoding_with_exp_sums(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables_intra,
                decode_meta.seq_lens_intra,
                softmax_scale,
                causal=False,
            )
        )
        outputs_list.append(intra_output)
        softmax_lses_list.append(intra_softmax_lse)
        # succ-attention
        if decode_meta.max_seq_len_succ:
            succ_output, succ_softmax_lse = (
                self._dual_chunk_flash_attn_decoding_with_exp_sums(
                    query_succ,
                    key_cache,
                    value_cache,
                    decode_meta.block_tables_succ,
                    decode_meta.seq_lens_succ,
                    softmax_scale,
                    causal=False,
                )
            )
            outputs_list.append(succ_output)
            softmax_lses_list.append(succ_softmax_lse)
        # inter-attention
        if decode_meta.max_seq_len_inter:
            inter_output, inter_softmax_lse = (
                self._dual_chunk_flash_attn_decoding_with_exp_sums(
                    query_inter,
                    key_cache,
                    value_cache,
                    block_table[:, : decode_meta.max_seq_len_inter],
                    decode_meta.seq_lens_inter,
                    softmax_scale,
                    causal=False,
                )
            )
            outputs_list.append(inter_output)
            softmax_lses_list.append(inter_softmax_lse)
        outputs = torch.stack(outputs_list, dim=0)
        del outputs_list
        softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32)
        del softmax_lses_list
        max_logits = torch.max(softmax_lses, dim=0).values
        stable_logits = softmax_lses - max_logits.unsqueeze(0)
        lse_s = torch.exp(stable_logits).detach()
        lse_sum = torch.sum(lse_s, dim=0)
        lse_s /= lse_sum
        outputs *= lse_s.unsqueeze(-1).transpose(2, 3)
        return outputs.sum(0)
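
For intuition, here is a hypothetical helper (not part of this change) that shows how a decode token's key positions split into the three regions the intra / succ / inter passes above attend to; the real backend precomputes equivalent per-request lengths and block tables in DualChunkFlashAttentionMetadata:

def split_kv_regions(seq_len: int, chunk_size: int, local_size: int):
    # Assumed interpretation of the dual-chunk layout: the in-progress chunk
    # (intra), the immediately preceding chunk (succ), and everything before
    # that (inter), using chunk_len = chunk_size - local_size as in the code.
    chunk_len = chunk_size - local_size
    chunk_begin = (seq_len - 1) // chunk_len * chunk_len
    intra = (chunk_begin, seq_len)
    succ = (max(chunk_begin - chunk_len, 0), chunk_begin)
    inter = (0, max(chunk_begin - chunk_len, 0))
    return intra, succ, inter

print(split_kv_regions(seq_len=100_000, chunk_size=32_768, local_size=4_096))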
    def _dual_chunk_flash_attn_decoding_with_exp_sums(
        self,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        causal: bool,
    ):
        out, softmax_lse, *rest_expand = flash_attn_with_kvcache(
            q=query,
            k_cache=key_cache,
            v_cache=value_cache,
            page_table=block_table,
            cache_seqlens=cache_seqlens,
            softmax_scale=softmax_scale,
            causal=causal,
            return_softmax_lse=True,
        )
        mask = cache_seqlens == 0
        out[mask] = 0
        softmax_lse[mask] = -float("inf")
        return out, softmax_lse
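
A tiny illustrative check (shapes simplified to [tokens]) that a branch whose cache is empty, i.e. zeroed output with LSE set to -inf as above, contributes nothing once the LSE-weighted merge runs:

import torch

lse = torch.tensor([[0.5], [float("-inf")]])   # branch 0 has keys, branch 1 is empty
w = torch.exp(lse - lse.max(dim=0).values)
w = w / w.sum(dim=0)
print(w)  # tensor([[1.], [0.]]) -- the empty branch gets zero weight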
def _vertical_slash_sparse_attention(
    query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    key: torch.Tensor,  # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
    value: torch.Tensor,  # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
    v_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    s_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    softmax_scale: float,
    causal: bool = True,
    stage: str = "intra",
    block_size_M: int = 64,
    block_size_N: int = 64,
    vertical_indices_count: torch.Tensor = None,  # [N_HEADS,]
    slash_indices_count: torch.Tensor = None,
):
    if stage == "intra":
        assert causal
    else:
        assert not causal

    batch_size, num_heads, context_size, head_dim = query.shape
    _, _, kv_seq_len, _ = key.shape

    if head_dim not in [16, 32, 64, 128, 256, 512]:
        target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
        query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
        key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
        value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])

    v_idx = (
        v_idx.to(torch.int32)
        .reshape((batch_size, num_heads, -1))
        .sort(dim=-1, descending=False)[0]
    )
    s_idx = (
        s_idx.to(torch.int32)
        .reshape((batch_size, num_heads, -1))
        .sort(dim=-1, descending=True)[0]
    )
    q_seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
    kv_seqlens = torch.tensor([kv_seq_len], dtype=torch.int32, device=query.device)

    if vertical_indices_count is not None and slash_indices_count is not None:
        (
            block_count,
            block_offset,
            column_count,
            column_index,
        ) = convert_vertical_slash_indexes_mergehead(
            q_seqlens,
            kv_seqlens,
            v_idx,
            s_idx,
            vertical_indices_count,
            slash_indices_count,
            context_size,
            block_size_M,
            block_size_N,
            causal,
        )
    else:
        (
            block_count,
            block_offset,
            column_count,
            column_index,
        ) = convert_vertical_slash_indexes(
            q_seqlens,
            kv_seqlens,
            v_idx,
            s_idx,
            context_size,
            block_size_M,
            block_size_N,
            causal,
        )

    q = query.transpose(1, 2).contiguous()
    k = key.transpose(1, 2).contiguous()
    v = value.transpose(1, 2).contiguous()
    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        causal=causal,
        softmax_scale=softmax_scale,
        return_softmax_lse=True,
    )
    out = out.transpose(1, 2).contiguous()
    softmax_lse = lse.reshape(*lse.shape, 1)
    return (
        out[..., :context_size, :head_dim],
        softmax_lse[..., :context_size, :],
    )
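
The sparse kernel above consumes precomputed vertical-column and slash-diagonal indices. As an illustrative sketch only (not the backend's actual heuristic), such indices can be estimated for a single head from the attention mass of the last few queries:

import torch

def pick_vertical_slash_indices(qk: torch.Tensor, nnz_v: int, nnz_s: int):
    # qk: [last_q, kv_len] unnormalized scores for one head.
    scores = torch.softmax(qk.float(), dim=-1)
    last_q, kv_len = scores.shape
    # "Vertical" pattern: key columns that receive the most total attention.
    vertical_mass = scores.sum(dim=0)
    v_idx = vertical_mass.topk(min(nnz_v, kv_len)).indices
    # "Slash" pattern: diagonals (constant query-key offset) with the most mass.
    slash_mass = torch.stack(
        [scores.diagonal(offset=o).sum() for o in range(-(last_q - 1), kv_len)]
    )
    s_idx = slash_mass.topk(min(nnz_s, slash_mass.numel())).indices
    return v_idx, s_idx

qk = torch.randn(64, 1024)
v_idx, s_idx = pick_vertical_slash_indices(qk, nnz_v=16, nnz_s=8)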
def _sum_all_diagonal_matrix(mat: torch.tensor):
    h, n, m = mat.shape
    # Zero matrix used for padding
    zero_mat = torch.zeros((h, n, n), device=mat.device)
    # pads the matrix on left and right
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)
    # Change the strides
    mat_strided = mat_padded.as_strided(
        (1, n, n + m), (n * (2 * n + m), 2 * n + m + 1, 1)
    )
    # Sums the resulting matrix's columns
    sum_diags = torch.sum(mat_strided, 1)
    return sum_diags[:, 1:]  # drop left bottom corner
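
A quick sanity check of the stride trick against an explicit per-diagonal sum (illustrative; it follows the helper's (h, n, m) layout with h == 1 and relies on the function defined just above):

import torch

mat = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)
fast = _sum_all_diagonal_matrix(mat)  # one value per diagonal offset, ascending
slow = torch.stack(
    [mat[0].diagonal(offset=o).sum() for o in range(-(3 - 1), 4)]
)
assert torch.allclose(fast[0], slow)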
def _get_block(block_table: torch.Tensor, block_size: int, begin: int, end: int):
    begin_block = begin // block_size
    end_block = (end - 1) // block_size + 1
    return block_table[begin_block:end_block]
python/sglang/srt/layers/rotary_embedding.py
View file @ b7cd7430
...
@@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding):
         )
+
+
+class DualChunkRotaryEmbedding(CustomOp):
+    """Rotary positional embedding for Dual Chunk Attention."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        chunk_size: int,
+        local_size: int,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+        self.chunk_size = chunk_size
+        self.local_size = local_size
+        self.dtype = dtype
+        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
+        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
+            self._compute_cos_sin_cache()
+        )
+        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
+        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
+        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
+        self.register_buffer(
+            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
+        )
+        self.register_buffer(
+            "cos_sin_q_inter_cache", q_inter_cache, persistent=False
+        )
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+                / self.rotary_dim
+            )
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        chunk_len = self.chunk_size - self.local_size
+        q_t = torch.arange(chunk_len, dtype=torch.float)
+        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
+            max=self.chunk_size
+        )
+        k_t = (
+            torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
+        )
+        # count from chunk_len, no clamp(self.chunk_size) restriction
+        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
+        # count from self.chunk_size for q_inter's rope
+        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size
+
+        q_freqs = torch.outer(q_t, inv_freq)
+        qc_freqs = torch.outer(qc_t, inv_freq)
+        k_freqs = torch.outer(k_t, inv_freq)
+        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
+        q_inter_freqs = torch.outer(q_inter_t, inv_freq)
+
+        q_cos = q_freqs.cos()
+        q_sin = q_freqs.sin()
+        qc_cos = qc_freqs.cos()
+        qc_sin = qc_freqs.sin()
+        k_cos = k_freqs.cos()
+        k_sin = k_freqs.sin()
+        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
+        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
+        q_inter_cos = q_inter_freqs.cos()
+        q_inter_sin = q_inter_freqs.sin()
+
+        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+        else:
+            query_pass = None
+            key_pass = None
+        positions_with_offsets = (
+            torch.add(positions, offsets) if offsets is not None else positions
+        )
+        key = self._apply_rotary_embedding(
+            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
+        )
+        chunk_len = self.chunk_size - self.local_size
+        query = self._apply_rotary_embedding(
+            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_succ = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
+            query_rot,
+            query_pass,
+        )
+        query_succ_critical = self._apply_rotary_embedding(
+            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter_critical = self._apply_rotary_embedding(
+            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        # merge query into one tensor to simplify the interfaces
+        query = torch.cat(
+            (
+                query,
+                query_succ,
+                query_inter,
+                query_succ_critical,
+                query_inter_critical,
+            ),
+            dim=-1,
+        )
+        return query, key
+
+    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
+        if self.rotary_dim < self.head_size:
+            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
+        else:
+            hidden = hidden_rot
+        return hidden.flatten(-2).squeeze(0)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
+        return s
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
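
The forward above returns the five rotated query variants packed along the last dimension so a single tensor crosses the layer interface. A toy, self-contained illustration of the unpacking a consumer would do; the role comments are an interpretation of the caches built in _compute_cos_sin_cache, not text from this change:

import torch

num_tokens, num_heads, head_size = 4, 2, 8
packed = torch.randn(num_tokens, 5 * num_heads * head_size)  # stand-in for the rope output
(
    q,             # intra-chunk positions (cos_sin_q_cache)
    q_succ,        # clamped successor-chunk positions (cos_sin_qc_cache)
    q_inter,       # fixed position chunk_len - 1 (cos_sin_qc_cache)
    q_succ_crit,   # unclamped successor positions (cos_sin_qc_no_clamp_cache)
    q_inter_crit,  # positions offset by chunk_size (cos_sin_q_inter_cache)
) = packed.chunk(5, dim=-1)
print(q.shape)  # torch.Size([4, 16])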
...
@@ -1184,6 +1380,7 @@ def get_rope(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
+    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
...
@@ -1195,6 +1392,17 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
+    if dual_chunk_attention_config is not None:
+        dual_chunk_attention_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in dual_chunk_attention_config.items()
+            if k != "sparse_attention_config"
+        }
+        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
+    else:
+        dual_chunk_attention_args = None
     if partial_rotary_factor < 1.0:
         rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (
...
@@ -1204,12 +1412,28 @@ def get_rope(
         base,
         is_neox_style,
         rope_scaling_args,
+        dual_chunk_attention_args,
         dtype,
     )
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
-    if rope_scaling is None:
+    if dual_chunk_attention_config is not None:
+        extra_kwargs = {
+            k: v
+            for k, v in dual_chunk_attention_config.items()
+            if k in ("chunk_size", "local_size")
+        }
+        rotary_emb = DualChunkRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            dtype,
+            **extra_kwargs,
+        )
+    elif rope_scaling is None:
         rotary_emb = RotaryEmbedding(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
...
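
A hedged usage sketch for the new branch: when a dual chunk attention config is passed, get_rope is expected to return the DualChunkRotaryEmbedding above. The keyword names follow the positional call inside get_rope, the numeric values are placeholders rather than an official Qwen configuration, and a CUDA device is required because the caches are built on GPU:

from sglang.srt.layers.rotary_embedding import get_rope

rotary_emb = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=1_010_000,          # placeholder context length
    base=10_000_000,                 # placeholder rope theta
    is_neox_style=True,
    dual_chunk_attention_config={"chunk_size": 262_144, "local_size": 8_192},
)
print(rotary_emb)  # DualChunkRotaryEmbedding(head_size=128, rotary_dim=128, ...)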
python/sglang/srt/managers/schedule_batch.py
View file @ b7cd7430
...
@@ -846,6 +846,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # The sum of all sequence lengths
     seq_lens_sum: int = None
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: torch.Tensor = None  # shape: [b], int32
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
...
@@ -1131,6 +1133,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = [len(r.fill_ids) for r in reqs]
+        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
...
@@ -1147,6 +1150,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
...
@@ -1260,6 +1266,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)
...
@@ -1507,6 +1514,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
...
@@ -1561,9 +1569,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
         # free memory
...
@@ -1627,6 +1637,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]
...
@@ -1659,6 +1670,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:
...
@@ -1733,6 +1745,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
+            orig_seq_lens=self.orig_seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_cpu=seq_lens_cpu,
             seq_lens_sum=self.seq_lens_sum,
...
@@ -1900,6 +1913,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: Optional[torch.Tensor] = None
     # The input Embeds
     input_embeds: Optional[torch.Tensor] = None
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @ b7cd7430
...
@@ -589,6 +589,7 @@ class CudaGraphRunner:
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
             next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @ b7cd7430
...
@@ -180,6 +180,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
+    # The original sequence length without being chunked. Qwen-1M related.
+    orig_seq_lens: Optional[torch.Tensor] = None
     # Optional seq_lens on cpu
     seq_lens_cpu: Optional[torch.Tensor] = None
...
@@ -321,6 +324,7 @@ class ForwardBatch:
             encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             seq_lens_cpu=batch.seq_lens_cpu,
+            orig_seq_lens=batch.orig_seq_lens,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
...
python/sglang/srt/model_executor/model_runner.py
View file @ b7cd7430
...
@@ -1467,6 +1467,12 @@ class ModelRunner:
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")
...
python/sglang/srt/models/qwen2.py
View file @ b7cd7430
...
@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
...
@@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
...
@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         head_dim = getattr(config, "head_dim", None)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
...
@@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = Qwen2MLP(
...
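
The decoder layer reads dual_chunk_attention_config straight from the Hugging Face config object. For orientation only, a hypothetical shape of that section; the key names are the ones referenced in this change, the values are placeholders, and whether original_max_position_embeddings lives here or elsewhere in the model config may differ:

dual_chunk_attention_config = {
    "chunk_size": 262144,                       # placeholder
    "local_size": 8192,                         # placeholder
    "original_max_position_embeddings": 262144, # placeholder; used for query scaling at decode
    # filtered out before building the rope cache key in get_rope;
    # consumed by the sparse attention path of the backend instead
    "sparse_attention_config": {},
}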
python/sglang/srt/models/qwen2_moe.py
View file @ b7cd7430
...
@@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module):
         max_position_embeddings: int = 8192,
         qkv_bias: int = True,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
...
@@ -267,6 +268,7 @@ class Qwen2MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
...
@@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         qkv_bias = getattr(config, "qkv_bias", True)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
...
@@ -317,6 +322,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             qkv_bias=qkv_bias,
             prefix=add_prefix("self_attn", prefix),
         )
...
python/sglang/srt/models/qwen3_moe.py
View file @ b7cd7430
...
@@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
...
@@ -353,6 +354,7 @@ class Qwen3MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
...
@@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
...
@@ -471,6 +476,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            dual_chunk_attention_config=dual_chunk_attention_config,
             alt_stream=alt_stream,
         )
...
python/sglang/srt/server_args.py
View file @ b7cd7430
...
@@ -502,6 +502,20 @@ class ServerArgs:
             # use bf16 for mxfp4 triton kernels
             self.dtype = "bfloat16"
+        if self.attention_backend == "dual_chunk_flash_attn":
+            logger.warning(
+                "Mixed chunk is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Radix cache is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Cuda graph is disabled because of using dual chunk flash attention backend"
+            )
+            self.enable_mixed_chunk = False
+            self.disable_cuda_graph = True
+            self.disable_radix_cache = True
+
         # Set page size
         if self.page_size is None:
             self.page_size = 1
...
@@ -1337,6 +1351,7 @@ class ServerArgs:
                 "triton",
                 "trtllm_mla",
                 "trtllm_mha",
+                "dual_chunk_flash_attn",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
...
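
A minimal sketch of what the post-init branch above implies, assuming ServerArgs can be constructed directly in a GPU environment (in practice these fields are populated by the CLI parser, and the model path below is a placeholder):

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="Qwen/Qwen2.5-7B-Instruct-1M",
    attention_backend="dual_chunk_flash_attn",
)
# Selecting the dual chunk backend should leave the incompatible features off:
assert args.enable_mixed_chunk is False
assert args.disable_cuda_graph is True
assert args.disable_radix_cache is True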
python/sglang/srt/two_batch_overlap.py
View file @ b7cd7430
...
@@ -661,6 +661,7 @@ class TboForwardBatchPreparer:
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
             "split_index",  # for split prefill
+            "orig_seq_lens",  # only used by qwen-1m, thus not care
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():
...