sglang · Commit b7cd7430 (unverified)

[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)

Authored Aug 07, 2025 by PGFLMG; committed via GitHub on Aug 06, 2025. Parent: a69b6370.

Showing 15 changed files with 2121 additions and 4 deletions (+2121 −4):

- examples/runtime/engine/offline_batch_inference_qwen_1m.py (+74 −0)
- python/sglang/srt/configs/model_config.py (+28 −0)
- python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py (+3 −0)
- python/sglang/srt/hf_transformers_utils.py (+30 −3)
- python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py (+1700 −0)
- python/sglang/srt/layers/rotary_embedding.py (+225 −1)
- python/sglang/srt/managers/schedule_batch.py (+16 −0)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+1 −0)
- python/sglang/srt/model_executor/forward_batch_info.py (+4 −0)
- python/sglang/srt/model_executor/model_runner.py (+6 −0)
- python/sglang/srt/models/qwen2.py (+6 −0)
- python/sglang/srt/models/qwen2_moe.py (+6 −0)
- python/sglang/srt/models/qwen3_moe.py (+6 −0)
- python/sglang/srt/server_args.py (+15 −0)
- python/sglang/srt/two_batch_overlap.py (+1 −0)
examples/runtime/engine/offline_batch_inference_qwen_1m.py (new file, +74 −0)

```python
"""
Usage:
python3 offline_batch_inference_qwen_1m.py
"""

from urllib.request import urlopen

import sglang as sgl


def load_prompt() -> str:
    # Test cases with various lengths can be found at:
    #
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
    with urlopen(
        "https://qianwen-res.oss-cn-beijing.aliyuncs.com"
        "/Qwen2.5-1M/test-data/64k.txt",
        timeout=5,
    ) as response:
        prompt = response.read().decode("utf-8")
    return prompt


# Processing the prompt.
def process_requests(llm: sgl.Engine, prompts: list[str]) -> None:
    # Create a sampling params object.
    sampling_params = {
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
        "repetition_penalty": 1.05,
        "max_new_tokens": 256,
    }
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt_token_ids = output["meta_info"]["prompt_tokens"]
        generated_text = output["text"]
        print(
            f"Prompt length: {prompt_token_ids}, "
            f"Generated text: {generated_text!r}"
        )


# Create an LLM.
def initialize_engine() -> sgl.Engine:
    llm = sgl.Engine(
        model_path="Qwen/Qwen2.5-7B-Instruct-1M",
        context_length=1048576,
        page_size=256,
        attention_backend="dual_chunk_flash_attn",
        tp_size=4,
        disable_radix_cache=True,
        enable_mixed_chunk=False,
        enable_torch_compile=False,
        chunked_prefill_size=131072,
        mem_fraction_static=0.6,
        log_level="DEBUG",
    )
    return llm


def main():
    llm = initialize_engine()
    prompt = load_prompt()
    process_requests(llm, [prompt])


if __name__ == "__main__":
    main()
```
python/sglang/srt/configs/model_config.py

@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
```python
    get_context_length,
    get_generation_config,
    get_hf_text_config,
    get_sparse_attention_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
```

@@ -270,6 +271,9 @@ class ModelConfig:
```python
        # Verify quantization
        self._verify_quantization()

        # Verify dual-chunk attention config
        self._verify_dual_chunk_attention_config()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()
```

@@ -297,6 +301,13 @@ class ModelConfig:
```python
            **kwargs,
        )

    def get_total_num_attention_heads(self) -> int:
        return self.num_attention_heads

    def get_num_attention_heads(self, tensor_parallel_size) -> int:
        total_num_attention_heads = self.num_attention_heads
        return max(1, total_num_attention_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
```

@@ -484,6 +495,23 @@ class ModelConfig:
```python
            self.quantization,
        )

    def _verify_dual_chunk_attention_config(self) -> None:
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            sparse_attn_config = get_sparse_attention_config(self.model_path)
            if not sparse_attn_config:
                return
            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
                sparse_attn_config
            )
            if (
                "sparse_attention_enabled"
                not in self.hf_config.dual_chunk_attention_config
            ):
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"
                ] = True

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
```
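For orientation, here is a hedged sketch (illustrative values only, not taken from the commit or from any Qwen checkpoint) of the shape `hf_config.dual_chunk_attention_config` has after `_verify_dual_chunk_attention_config()` runs; the `chunk_size`/`local_size` keys are the ones `get_rope` later consumes, and the sparse-attention entries come from `sparse_attention_config.json`.

```python
# Hypothetical post-verification shape; all values below are assumptions.
dual_chunk_attention_config = {
    "chunk_size": 262144,              # assumed example; consumed by get_rope
    "local_size": 8192,                # assumed example; consumed by get_rope
    "sparse_attention_config": {},     # filled from sparse_attention_config.json
    "sparse_attention_enabled": True,  # defaulted to True when the key is absent
}
```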
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
```python
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.orig_seq_lens = torch.tensor(
            seq_lens, dtype=torch.int32, device=self.device
        )
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
```
python/sglang/srt/hf_transformers_utils.py

@@ -14,10 +14,11 @@
```diff
 """Utilities for Huggingface Transformers."""
 
 import contextlib
+import json
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union
 
 import torch
 from huggingface_hub import snapshot_download
```

@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
```diff
     AutoConfig.register(name, cls)
 
 
-def download_from_hf(model_path: str):
+def download_from_hf(
+    model_path: str,
+    allow_patterns: Optional[Union[str, list]] = None,
+):
     if os.path.exists(model_path):
         return model_path
 
-    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
+    if not allow_patterns:
+        allow_patterns = ["*.json", "*.bin", "*.model"]
+
+    return snapshot_download(model_path, allow_patterns=allow_patterns)
 
 
 def get_hf_text_config(config: PretrainedConfig):
```

@@ -171,6 +178,26 @@ def get_generation_config(
```python
    return None


# Qwen-1M related
def get_sparse_attention_config(
    model: str,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
    is_local = os.path.isdir(model)
    if not is_local:
        # Download the config files.
        model = download_from_hf(model, allow_patterns=["*.json"])

    config_file = os.path.join(model, sparse_attention_config_filename)
    if not os.path.exists(config_file):
        return {}

    # Load the sparse attention config.
    with open(config_file) as f:
        config = json.load(f)
    return config


# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
```
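A minimal usage sketch of the new helper, assuming the checkpoint publishes a `sparse_attention_config.json` alongside its other JSON config files (the model ID is the one reused from the example script above):

```python
from sglang.srt.hf_transformers_utils import get_sparse_attention_config

# When given a hub ID, only the *.json files are downloaded; the helper then
# loads sparse_attention_config.json if present and returns {} otherwise.
sparse_cfg = get_sparse_attention_config("Qwen/Qwen2.5-7B-Instruct-1M")
print("sparse attention config keys:", sorted(sparse_cfg))
```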
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py (new file, +1700 −0)

This diff is collapsed in the web view. The new file provides the DualChunkFlashAttentionBackend that model_runner.py (below) selects when the dual_chunk_flash_attn attention backend is configured.
python/sglang/srt/layers/rotary_embedding.py

@@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding):
```python
        )


class DualChunkRotaryEmbedding(CustomOp):
    """Rotary positional embedding for Dual Chunk Attention."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        chunk_size: int,
        local_size: int,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.chunk_size = chunk_size
        self.local_size = local_size
        self.dtype = dtype
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
            self._compute_cos_sin_cache()
        )
        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
        self.register_buffer(
            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
        )
        self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (
            base
            ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        chunk_len = self.chunk_size - self.local_size
        q_t = torch.arange(chunk_len, dtype=torch.float)
        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
            max=self.chunk_size
        )
        k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
        # count from chunk_len, no clamp(self.chunk_size) restriction
        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
        # count from self.chunk_size for q_inter's rope
        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size

        q_freqs = torch.outer(q_t, inv_freq)
        qc_freqs = torch.outer(qc_t, inv_freq)
        k_freqs = torch.outer(k_t, inv_freq)
        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
        q_inter_freqs = torch.outer(q_inter_t, inv_freq)
        q_cos = q_freqs.cos()
        q_sin = q_freqs.sin()
        qc_cos = qc_freqs.cos()
        qc_sin = qc_freqs.sin()
        k_cos = k_freqs.cos()
        k_sin = k_freqs.sin()
        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
        q_inter_cos = q_inter_freqs.cos()
        q_inter_sin = q_inter_freqs.sin()
        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
            dtype=self.dtype, device=self.device
        )
        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        key_rot = key[..., : self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim :]
            key_pass = key[..., self.rotary_dim :]
        else:
            query_pass = None
            key_pass = None
        positions_with_offsets = (
            torch.add(positions, offsets) if offsets is not None else positions
        )
        key = self._apply_rotary_embedding(
            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
        )
        chunk_len = self.chunk_size - self.local_size
        query = self._apply_rotary_embedding(
            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        query_succ = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        query_inter = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
            query_rot,
            query_pass,
        )
        query_succ_critical = self._apply_rotary_embedding(
            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        query_inter_critical = self._apply_rotary_embedding(
            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
            query_rot,
            query_pass,
        )
        # merge query into one tensor to simplify the interfaces
        query = torch.cat(
            (
                query,
                query_succ,
                query_inter,
                query_succ_critical,
                query_inter_critical,
            ),
            dim=-1,
        )
        return query, key

    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin

        if self.rotary_dim < self.head_size:
            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
        else:
            hidden = hidden_rot
        return hidden.flatten(-2).squeeze(0)

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
        return s


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
```
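To make the five cached tables concrete, the following standalone sketch (toy `chunk_size`/`local_size` values, chosen here purely for illustration) computes the effective RoPE position each cache assigns to a token at absolute position `p`, mirroring the index arithmetic in `_compute_cos_sin_cache` and `forward`:

```python
import torch

# Toy values for illustration; real checkpoints define these in
# dual_chunk_attention_config.
chunk_size, local_size = 16, 4
chunk_len = chunk_size - local_size  # 12

p = torch.arange(40)                              # absolute token positions
intra = p % chunk_len                             # cos_sin_q_cache / cos_sin_k_cache
succ = (intra + chunk_len).clamp(max=chunk_size)  # cos_sin_qc_cache
succ_critical = intra + chunk_len                 # cos_sin_qc_no_clamp_cache
inter_critical = intra + chunk_size               # cos_sin_q_inter_cache
inter = torch.full_like(p, min(2 * chunk_len - 1, chunk_size))  # qc_cache[chunk_len - 1]

print(intra[:16].tolist())
print(succ[:16].tolist())
```

`forward()` applies each of these tables to the same `query_rot` and concatenates the five results along the last dimension, which is why the query returned by this embedding is five rotary widths wide.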
@@ -1184,6 +1380,7 @@ def get_rope(
```python
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
    if dtype is None:
        dtype = torch.get_default_dtype()
```

@@ -1195,6 +1392,17 @@ def get_rope(
```python
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None

    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (
```

@@ -1204,12 +1412,28 @@ def get_rope(
```diff
         base,
         is_neox_style,
         rope_scaling_args,
+        dual_chunk_attention_args,
         dtype,
     )
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
 
-    if rope_scaling is None:
+    if dual_chunk_attention_config is not None:
+        extra_kwargs = {
+            k: v
+            for k, v in dual_chunk_attention_config.items()
+            if k in ("chunk_size", "local_size")
+        }
+        rotary_emb = DualChunkRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            dtype,
+            **extra_kwargs,
+        )
+    elif rope_scaling is None:
         rotary_emb = RotaryEmbedding(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
```
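A hedged construction sketch for the new `get_rope` path, assuming a CUDA device is available (the embedding allocates its cos/sin caches on the current GPU) and using small illustrative sizes rather than the real Qwen-1M settings:

```python
import torch

from sglang.srt.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=4096,  # toy value; the example engine above uses a 1048576-token context
    base=10000,
    is_neox_style=True,
    dtype=torch.bfloat16,
    # Extra keys such as sparse_attention_config are filtered out before the
    # _ROPE_DICT cache key is built; only chunk_size/local_size reach the class.
    dual_chunk_attention_config={"chunk_size": 1024, "local_size": 128},
)
print(type(rope).__name__)  # DualChunkRotaryEmbedding
```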
python/sglang/srt/managers/schedule_batch.py

@@ -846,6 +846,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
```python
    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: torch.Tensor = None  # shape: [b], int32

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
```

@@ -1131,6 +1133,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
```python
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]
```

@@ -1147,6 +1150,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
```python
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
```

@@ -1260,6 +1266,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
```python
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.orig_seq_lens = orig_seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
```

@@ -1507,6 +1514,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
```python
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
```

@@ -1561,9 +1569,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
```python
        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
            self.orig_seq_lens = self.orig_seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
            self.orig_seq_lens.add_(1)
        self.seq_lens_sum += bs

        # free memory
```

@@ -1627,6 +1637,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
```python
        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
```

@@ -1659,6 +1670,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
```python
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
```

@@ -1733,6 +1745,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
```python
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            orig_seq_lens=self.orig_seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
```

@@ -1900,6 +1913,9 @@ class ModelWorkerBatch:
```python
    # Sampling info
    sampling_info: SamplingBatchInfo

    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: Optional[torch.Tensor] = None

    # The input Embeds
    input_embeds: Optional[torch.Tensor] = None
```
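A small numeric illustration (assumed chunk boundaries, mirroring the `chunked_prefill_size=131072` used in the example script) of why the batch now carries both fields: during chunked prefill, `seq_lens` reflects only the tokens filled so far, while `orig_seq_lens` keeps the full, unchunked prompt length that the dual-chunk backend needs.

```python
# Assumed toy request: a 300k-token prompt prefilled in 131072-token chunks.
origin_input_len = 300_000     # len(r.origin_input_ids)
fill_len_this_round = 131_072  # len(r.fill_ids) after the first chunk

seq_len = fill_len_this_round
orig_seq_len = max(fill_len_this_round, origin_input_len)

print(seq_len, orig_seq_len)  # 131072 300000
```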
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -589,6 +589,7 @@ class CudaGraphRunner:
```python
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            next_token_logits_buffer=next_token_logits_buffer,
            orig_seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
```
python/sglang/srt/model_executor/forward_batch_info.py

@@ -180,6 +180,9 @@ class ForwardBatch:
```python
    # The sum of all sequence lengths
    seq_lens_sum: int

    # The original sequence length without being chunked. Qwen-1M related.
    orig_seq_lens: Optional[torch.Tensor] = None

    # Optional seq_lens on cpu
    seq_lens_cpu: Optional[torch.Tensor] = None
```

@@ -321,6 +324,7 @@ class ForwardBatch:
```python
            encoder_out_cache_loc=batch.encoder_out_cache_loc,
            seq_lens_sum=batch.seq_lens_sum,
            seq_lens_cpu=batch.seq_lens_cpu,
            orig_seq_lens=batch.orig_seq_lens,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
            token_ids_logprobs=batch.token_ids_logprobs,
```
python/sglang/srt/model_executor/model_runner.py

@@ -1467,6 +1467,12 @@ class ModelRunner:
```python
            logger.info(f"Intel AMX attention backend is enabled.")
            return IntelAMXAttnBackend(self)
        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                DualChunkFlashAttentionBackend,
            )

            return DualChunkFlashAttentionBackend(self)
        else:
            raise ValueError(f"Invalid attention backend: {backend_str}")
```
python/sglang/srt/models/qwen2.py

@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
```python
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 32768,
        quant_config: Optional[QuantizationConfig] = None,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
```

@@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module):
```python
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
```

@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
```python
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
        head_dim = getattr(config, "head_dim", None)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
```

@@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
```python
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            dual_chunk_attention_config=dual_chunk_attention_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = Qwen2MLP(
```
python/sglang/srt/models/qwen2_moe.py

@@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module):
```python
        max_position_embeddings: int = 8192,
        qkv_bias: int = True,
        quant_config: Optional[QuantizationConfig] = None,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
```

@@ -267,6 +268,7 @@ class Qwen2MoeAttention(nn.Module):
```python
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
```

@@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
```python
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        qkv_bias = getattr(config, "qkv_bias", True)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
```

@@ -317,6 +322,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
```python
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            dual_chunk_attention_config=dual_chunk_attention_config,
            qkv_bias=qkv_bias,
            prefix=add_prefix("self_attn", prefix),
        )
```
python/sglang/srt/models/qwen3_moe.py

@@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module):
```python
        attention_bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
```

@@ -353,6 +354,7 @@ class Qwen3MoeAttention(nn.Module):
```python
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
```

@@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
```python
        )
        rms_norm_eps = config.rms_norm_eps
        attention_bias = config.attention_bias
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
```

@@ -471,6 +476,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
```python
            attention_bias=attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            dual_chunk_attention_config=dual_chunk_attention_config,
            alt_stream=alt_stream,
        )
```
python/sglang/srt/server_args.py

@@ -502,6 +502,20 @@ class ServerArgs:
```python
            # use bf16 for mxfp4 triton kernels
            self.dtype = "bfloat16"

        if self.attention_backend == "dual_chunk_flash_attn":
            logger.warning(
                "Mixed chunk is disabled because of using dual chunk flash attention backend"
            )
            logger.warning(
                "Radix cache is disabled because of using dual chunk flash attention backend"
            )
            logger.warning(
                "Cuda graph is disabled because of using dual chunk flash attention backend"
            )
            self.enable_mixed_chunk = False
            self.disable_cuda_graph = True
            self.disable_radix_cache = True

        # Set page size
        if self.page_size is None:
            self.page_size = 1
```

@@ -1337,6 +1351,7 @@ class ServerArgs:
```python
                "triton",
                "trtllm_mla",
                "trtllm_mha",
                "dual_chunk_flash_attn",
            ],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
```
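The constraint block above can be restated as a tiny self-contained helper (not sglang API, just a sketch of the rule it enforces): selecting the dual-chunk backend always disables mixed chunking, the radix cache, and CUDA graphs, regardless of the other flags passed in.

```python
# Standalone restatement of the rule enforced in ServerArgs (illustrative only).
def apply_dual_chunk_constraints(args: dict) -> dict:
    if args.get("attention_backend") == "dual_chunk_flash_attn":
        args["enable_mixed_chunk"] = False
        args["disable_cuda_graph"] = True
        args["disable_radix_cache"] = True
    return args


print(apply_dual_chunk_constraints({"attention_backend": "dual_chunk_flash_attn"}))
```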
python/sglang/srt/two_batch_overlap.py

@@ -661,6 +661,7 @@ class TboForwardBatchPreparer:
```python
            "padded_static_len",
            "mrope_positions",  # only used by qwen2-vl, thus not care
            "split_index",  # for split prefill
            "orig_seq_lens",  # only used by qwen-1m, thus not care
        ]:
            output_dict[key] = getattr(batch, key)

        if not batch.forward_mode.is_target_verify():
```