Commit b7cd7430 (Unverified)
Authored Aug 07, 2025 by PGFLMG, committed by GitHub on Aug 06, 2025

[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)
Parent: a69b6370

Showing 15 changed files with 2121 additions and 4 deletions (+2121, -4)
examples/runtime/engine/offline_batch_inference_qwen_1m.py (+74, -0)
python/sglang/srt/configs/model_config.py (+28, -0)
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py (+3, -0)
python/sglang/srt/hf_transformers_utils.py (+30, -3)
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py (+1700, -0)
python/sglang/srt/layers/rotary_embedding.py (+225, -1)
python/sglang/srt/managers/schedule_batch.py (+16, -0)
python/sglang/srt/model_executor/cuda_graph_runner.py (+1, -0)
python/sglang/srt/model_executor/forward_batch_info.py (+4, -0)
python/sglang/srt/model_executor/model_runner.py (+6, -0)
python/sglang/srt/models/qwen2.py (+6, -0)
python/sglang/srt/models/qwen2_moe.py (+6, -0)
python/sglang/srt/models/qwen3_moe.py (+6, -0)
python/sglang/srt/server_args.py (+15, -0)
python/sglang/srt/two_batch_overlap.py (+1, -0)
examples/runtime/engine/offline_batch_inference_qwen_1m.py (new file, 0 → 100644)

"""
Usage:
python3 offline_batch_inference.py
"""

from urllib.request import urlopen

import sglang as sgl


def load_prompt() -> str:
    # Test cases with various lengths can be found at:
    #
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
    with urlopen(
        "https://qianwen-res.oss-cn-beijing.aliyuncs.com"
        "/Qwen2.5-1M/test-data/64k.txt",
        timeout=5,
    ) as response:
        prompt = response.read().decode("utf-8")
    return prompt


# Processing the prompt.
def process_requests(llm: sgl.Engine, prompts: list[str]) -> None:
    # Create a sampling params object.
    sampling_params = {
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
        "repetition_penalty": 1.05,
        "max_new_tokens": 256,
    }
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt_token_ids = output["meta_info"]["prompt_tokens"]
        generated_text = output["text"]
        print(
            f"Prompt length: {prompt_token_ids}, "
            f"Generated text: {generated_text!r}"
        )


# Create an LLM.
def initialize_engine() -> sgl.Engine:
    llm = sgl.Engine(
        model_path="Qwen/Qwen2.5-7B-Instruct-1M",
        context_length=1048576,
        page_size=256,
        attention_backend="dual_chunk_flash_attn",
        tp_size=4,
        disable_radix_cache=True,
        enable_mixed_chunk=False,
        enable_torch_compile=False,
        chunked_prefill_size=131072,
        mem_fraction_static=0.6,
        log_level="DEBUG",
    )
    return llm


def main():
    llm = initialize_engine()
    prompt = load_prompt()
    process_requests(llm, [prompt])


if __name__ == "__main__":
    main()
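A note on the engine arguments in this example: attention_backend="dual_chunk_flash_attn" is the backend added by this commit, and the server_args.py change below forces enable_mixed_chunk=False, disable_radix_cache=True, and disable_cuda_graph=True whenever that backend is selected, so the explicit disable_radix_cache=True and enable_mixed_chunk=False here match what the runtime would enforce anyway. The page_size=256 and chunked_prefill_size=131072 values are simply this example's settings, not requirements stated elsewhere in the diff.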
python/sglang/srt/configs/model_config.py

@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
     get_context_length,
     get_generation_config,
     get_hf_text_config,
+    get_sparse_attention_config,
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
@@ -270,6 +271,9 @@ class ModelConfig:
         # Verify quantization
         self._verify_quantization()
 
+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
@@ -297,6 +301,13 @@ class ModelConfig:
             **kwargs,
         )
 
+    def get_total_num_attention_heads(self) -> int:
+        return self.num_attention_heads
+
+    def get_num_attention_heads(self, tensor_parallel_size) -> int:
+        total_num_attention_heads = self.num_attention_heads
+        return max(1, total_num_attention_heads // tensor_parallel_size)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -484,6 +495,23 @@ class ModelConfig:
             self.quantization,
         )
 
+    def _verify_dual_chunk_attention_config(self) -> None:
+        if hasattr(self.hf_config, "dual_chunk_attention_config"):
+            # Try loading the sparse attention config
+            sparse_attn_config = get_sparse_attention_config(self.model_path)
+            if not sparse_attn_config:
+                return
+            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
+                sparse_attn_config
+            )
+            if (
+                "sparse_attention_enabled"
+                not in self.hf_config.dual_chunk_attention_config
+            ):
+                self.hf_config.dual_chunk_attention_config[
+                    "sparse_attention_enabled"
+                ] = True
+
     def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
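For readers who want to see what _verify_dual_chunk_attention_config effectively does to the config object, here is a minimal, self-contained sketch using plain dicts in place of the HF config. The key names inside the sparse-attention JSON are hypothetical placeholders; only "sparse_attention_config" and "sparse_attention_enabled" come from the diff above.

def merge_sparse_attention_config(dual_chunk_cfg: dict, sparse_attn_cfg: dict) -> dict:
    # Mirrors _verify_dual_chunk_attention_config: attach the loaded JSON and
    # turn sparse attention on unless the config already says otherwise.
    if not sparse_attn_cfg:
        return dual_chunk_cfg
    dual_chunk_cfg["sparse_attention_config"] = sparse_attn_cfg
    dual_chunk_cfg.setdefault("sparse_attention_enabled", True)
    return dual_chunk_cfg


if __name__ == "__main__":
    cfg = {"chunk_size": 262144, "local_size": 8192}  # illustrative values
    sparse = {"pattern": "block_sparse"}              # hypothetical JSON content
    print(merge_sparse_attention_config(cfg, sparse))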
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.tensor(
+            seq_lens, dtype=torch.int32, device=self.device
+        )
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
python/sglang/srt/hf_transformers_utils.py

@@ -14,10 +14,11 @@
 """Utilities for Huggingface Transformers."""
 
 import contextlib
+import json
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union
 
 import torch
 from huggingface_hub import snapshot_download
@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
     AutoConfig.register(name, cls)
 
 
-def download_from_hf(model_path: str):
+def download_from_hf(
+    model_path: str,
+    allow_patterns: Optional[Union[str, list]] = None,
+):
     if os.path.exists(model_path):
         return model_path
 
-    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
+    if not allow_patterns:
+        allow_patterns = ["*.json", "*.bin", "*.model"]
+    return snapshot_download(model_path, allow_patterns=allow_patterns)
 
 
 def get_hf_text_config(config: PretrainedConfig):
@@ -171,6 +178,26 @@ def get_generation_config(
     return None
 
 
+# Qwen-1M related
+def get_sparse_attention_config(
+    model: str,
+    sparse_attention_config_filename: str = "sparse_attention_config.json",
+) -> Dict[str, Any]:
+    is_local = os.path.isdir(model)
+    if not is_local:
+        # Download the config files.
+        model = download_from_hf(model, allow_patterns=["*.json"])
+
+    config_file = os.path.join(model, sparse_attention_config_filename)
+    if not os.path.exists(config_file):
+        return {}
+
+    # Load the sparse attention config.
+    with open(config_file) as f:
+        config = json.load(f)
+    return config
+
+
 # Models don't use the same configuration key for determining the maximum
 # context length. Store them here so we can sanely check them.
 # NOTE: The ordering here is important. Some models have two of these and we
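A sketch of how the new helper might be called on its own; the import path and function come from the diff above, while the printed keys depend entirely on whatever sparse_attention_config.json ships with the checkpoint.

from sglang.srt.hf_transformers_utils import get_sparse_attention_config

# For hub ids, only "*.json" files are downloaded (see download_from_hf above),
# so this stays cheap even for a multi-gigabyte checkpoint.
cfg = get_sparse_attention_config("Qwen/Qwen2.5-7B-Instruct-1M")
if cfg:
    print("sparse attention config keys:", sorted(cfg))
else:
    # An empty dict means no sparse_attention_config.json was found, and
    # ModelConfig._verify_dual_chunk_attention_config returns without touching
    # the dual_chunk_attention_config.
    print("no sparse_attention_config.json found")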
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py (new file, 0 → 100644)

This 1,700-line diff is collapsed on the page and not reproduced here.
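Since the backend itself is not visible here, the following is only a conceptual sketch (not the collapsed implementation) of the partitioning that dual chunk attention performs: each query attends to keys in its own chunk, in the immediately preceding chunk, and in all earlier chunks, with a differently re-mapped query position for each group (see the query / query_succ / query_inter streams in the rotary_embedding.py diff below).

def partition_kv_positions(q_pos: int, chunk_size: int) -> dict:
    # Split key positions 0..q_pos into the three groups dual chunk attention
    # treats differently: intra-chunk, successive chunk, and inter (earlier) chunks.
    chunk_start = (q_pos // chunk_size) * chunk_size
    prev_start = max(chunk_start - chunk_size, 0)
    return {
        "intra": range(chunk_start, q_pos + 1),
        "succ": range(prev_start, chunk_start),
        "inter": range(0, prev_start),
    }


print(partition_kv_positions(q_pos=700, chunk_size=256))
# intra covers 512..700, succ covers 256..511, inter covers 0..255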
python/sglang/srt/layers/rotary_embedding.py

@@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding):
         )
 
 
+class DualChunkRotaryEmbedding(CustomOp):
+    """Rotary positional embedding for Dual Chunk Attention."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        chunk_size: int,
+        local_size: int,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+        self.chunk_size = chunk_size
+        self.local_size = local_size
+        self.dtype = dtype
+        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
+        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
+            self._compute_cos_sin_cache()
+        )
+        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
+        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
+        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
+        self.register_buffer(
+            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
+        )
+        self.register_buffer(
+            "cos_sin_q_inter_cache", q_inter_cache, persistent=False
+        )
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+                / self.rotary_dim
+            )
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        chunk_len = self.chunk_size - self.local_size
+        q_t = torch.arange(chunk_len, dtype=torch.float)
+        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
+            max=self.chunk_size
+        )
+        k_t = (
+            torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
+        )
+        # count from chunk_len, no clamp(self.chunk_size) restriction
+        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
+        # count from self.chunk_size for q_inter's rope
+        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size
+
+        q_freqs = torch.outer(q_t, inv_freq)
+        qc_freqs = torch.outer(qc_t, inv_freq)
+        k_freqs = torch.outer(k_t, inv_freq)
+        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
+        q_inter_freqs = torch.outer(q_inter_t, inv_freq)
+
+        q_cos = q_freqs.cos()
+        q_sin = q_freqs.sin()
+        qc_cos = qc_freqs.cos()
+        qc_sin = qc_freqs.sin()
+        k_cos = k_freqs.cos()
+        k_sin = k_freqs.sin()
+        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
+        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
+        q_inter_cos = q_inter_freqs.cos()
+        q_inter_sin = q_inter_freqs.sin()
+
+        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+        else:
+            query_pass = None
+            key_pass = None
+        positions_with_offsets = (
+            torch.add(positions, offsets) if offsets is not None else positions
+        )
+        key = self._apply_rotary_embedding(
+            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
+        )
+        chunk_len = self.chunk_size - self.local_size
+        query = self._apply_rotary_embedding(
+            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_succ = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
+            query_rot,
+            query_pass,
+        )
+        query_succ_critical = self._apply_rotary_embedding(
+            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter_critical = self._apply_rotary_embedding(
+            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        # merge query into one tensor to simplify the interfaces
+        query = torch.cat(
+            (
+                query,
+                query_succ,
+                query_inter,
+                query_succ_critical,
+                query_inter_critical,
+            ),
+            dim=-1,
+        )
+        return query, key
+
+    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
+        else:
+            hidden = hidden_rot
+        return hidden.flatten(-2).squeeze(0)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
+        return s
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
@@ -1184,6 +1380,7 @@ def get_rope(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
+    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -1195,6 +1392,17 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
+
+    if dual_chunk_attention_config is not None:
+        dual_chunk_attention_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in dual_chunk_attention_config.items()
+            if k != "sparse_attention_config"
+        }
+        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
+    else:
+        dual_chunk_attention_args = None
+
     if partial_rotary_factor < 1.0:
         rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (
@@ -1204,12 +1412,28 @@ def get_rope(
         base,
         is_neox_style,
         rope_scaling_args,
+        dual_chunk_attention_args,
         dtype,
     )
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
-    if rope_scaling is None:
+
+    if dual_chunk_attention_config is not None:
+        extra_kwargs = {
+            k: v
+            for k, v in dual_chunk_attention_config.items()
+            if k in ("chunk_size", "local_size")
+        }
+        rotary_emb = DualChunkRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            dtype,
+            **extra_kwargs,
+        )
+    elif rope_scaling is None:
         rotary_emb = RotaryEmbedding(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
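To make the five cos/sin caches above easier to follow, here is a small runnable illustration of the underlying position indices with toy sizes (chunk_size=8 and local_size=2 are chosen only for readability; real configs use far larger values):

import torch

chunk_size, local_size, max_pos = 8, 2, 24
chunk_len = chunk_size - local_size                               # 6

# intra-chunk query positions: 0..5
q_t = torch.arange(chunk_len, dtype=torch.float)
# successive-chunk query positions, clamped at chunk_size: 6, 7, 8, 8, 8, 8
qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(max=chunk_size)
# key positions wrap around every chunk_len tokens
k_t = torch.arange(max_pos, dtype=torch.float) % chunk_len
# "critical" successive stream, not clamped: 6..11
qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
# "critical" inter stream, counted from chunk_size: 8..13
q_inter_t = torch.arange(chunk_len, dtype=torch.float) + chunk_size

print(q_t.tolist())
print(qc_t.tolist())
print(k_t.tolist()[:12])
print(qc_no_clamp_t.tolist(), q_inter_t.tolist())

In forward, the same modulo-chunk_len index selects rows from these caches, and the resulting query, query_succ, query_inter, query_succ_critical, and query_inter_critical streams are concatenated along the last dimension so the attention backend can pick the appropriate one per key group.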
python/sglang/srt/managers/schedule_batch.py

@@ -846,6 +846,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # The sum of all sequence lengths
     seq_lens_sum: int = None
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: torch.Tensor = None  # shape: [b], int32
 
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
@@ -1131,6 +1133,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = [len(r.fill_ids) for r in reqs]
+        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
@@ -1147,6 +1150,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
@@ -1260,6 +1266,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)
@@ -1507,6 +1514,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
@@ -1561,9 +1569,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
 
         # free memory
@@ -1627,6 +1637,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]
@@ -1659,6 +1670,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:
@@ -1733,6 +1745,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
+            orig_seq_lens=self.orig_seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_cpu=seq_lens_cpu,
             seq_lens_sum=self.seq_lens_sum,
@@ -1900,6 +1913,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # The input Embeds
     input_embeds: Optional[torch.Tensor] = None
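A toy illustration (not sglang code) of why orig_seq_lens is tracked alongside seq_lens: under chunked prefill only part of a long prompt is scheduled in a given step, so the per-step length and the full, un-chunked prompt length diverge, and the dual chunk backend needs the latter. The numbers below are made up; chunked_prefill_size=131072 just mirrors the example script above.

full_prompt_len = 300_000          # hypothetical prompt length in tokens
chunked_prefill_size = 131_072

fill_len_after_step1 = min(chunked_prefill_size, full_prompt_len)
seq_len = fill_len_after_step1                               # length scheduled so far
orig_seq_len = max(fill_len_after_step1, full_prompt_len)    # mirrors max(len(fill_ids), len(origin_input_ids))

print(seq_len, orig_seq_len)       # 131072 300000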
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -589,6 +589,7 @@ class CudaGraphRunner:
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            next_token_logits_buffer=next_token_logits_buffer,
+           orig_seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
python/sglang/srt/model_executor/forward_batch_info.py

@@ -180,6 +180,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
+    # The original sequence length without being chunked. Qwen-1M related.
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # Optional seq_lens on cpu
     seq_lens_cpu: Optional[torch.Tensor] = None
@@ -321,6 +324,7 @@ class ForwardBatch:
            encoder_out_cache_loc=batch.encoder_out_cache_loc,
            seq_lens_sum=batch.seq_lens_sum,
            seq_lens_cpu=batch.seq_lens_cpu,
+           orig_seq_lens=batch.orig_seq_lens,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
            token_ids_logprobs=batch.token_ids_logprobs,
python/sglang/srt/model_executor/model_runner.py

@@ -1467,6 +1467,12 @@ class ModelRunner:
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")
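Note that the DualChunkFlashAttentionBackend import sits inside the elif branch rather than at module top level, presumably so the large backend module is only imported when this backend is actually selected.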
python/sglang/srt/models/qwen2.py

@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module):
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
+           dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         head_dim = getattr(config, "head_dim", None)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
+           dual_chunk_attention_config=dual_chunk_attention_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = Qwen2MLP(
python/sglang/srt/models/qwen2_moe.py

@@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module):
         max_position_embeddings: int = 8192,
         qkv_bias: int = True,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -267,6 +268,7 @@ class Qwen2MoeAttention(nn.Module):
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
+           dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
@@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         qkv_bias = getattr(config, "qkv_bias", True)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -317,6 +322,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
+           dual_chunk_attention_config=dual_chunk_attention_config,
            qkv_bias=qkv_bias,
            prefix=add_prefix("self_attn", prefix),
        )
python/sglang/srt/models/qwen3_moe.py

@@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
@@ -353,6 +354,7 @@ class Qwen3MoeAttention(nn.Module):
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
+           dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
@@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -471,6 +476,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
            attention_bias=attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
+           dual_chunk_attention_config=dual_chunk_attention_config,
            alt_stream=alt_stream,
        )
python/sglang/srt/server_args.py

@@ -502,6 +502,20 @@ class ServerArgs:
             # use bf16 for mxfp4 triton kernels
             self.dtype = "bfloat16"
 
+        if self.attention_backend == "dual_chunk_flash_attn":
+            logger.warning(
+                "Mixed chunk is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Radix cache is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Cuda graph is disabled because of using dual chunk flash attention backend"
+            )
+            self.enable_mixed_chunk = False
+            self.disable_cuda_graph = True
+            self.disable_radix_cache = True
+
         # Set page size
         if self.page_size is None:
             self.page_size = 1
@@ -1337,6 +1351,7 @@ class ServerArgs:
                 "triton",
                 "trtllm_mla",
                 "trtllm_mha",
+                "dual_chunk_flash_attn",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
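The post-init behaviour above, restated as a tiny standalone check (a plain dict stands in for the real ServerArgs dataclass so the sketch is self-contained):

def apply_dual_chunk_constraints(args: dict) -> dict:
    # Mirrors the server_args.py hunk: selecting the dual chunk flash attention
    # backend currently turns off mixed chunking, CUDA graphs, and the radix cache.
    if args.get("attention_backend") == "dual_chunk_flash_attn":
        args["enable_mixed_chunk"] = False
        args["disable_cuda_graph"] = True
        args["disable_radix_cache"] = True
    return args


print(apply_dual_chunk_constraints({"attention_backend": "dual_chunk_flash_attn"}))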
python/sglang/srt/two_batch_overlap.py

@@ -661,6 +661,7 @@ class TboForwardBatchPreparer:
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
             "split_index",  # for split prefill
+            "orig_seq_lens",  # only used by qwen-1m, thus not care
         ]:
             output_dict[key] = getattr(batch, key)
 
         if not batch.forward_mode.is_target_verify():