zhaoyu6 / sglang / Commits / 852a49c5

Commit 852a49c5, authored Sep 30, 2025 by maxiao

    adapt to dsv32 on dcu

Parent: 8f7453e3
Changes: 159 files in the commit; this page shows 20 changed files with 295 additions and 1379 deletions (+295 / -1379).
Changed files shown (additions / deletions):

  python/pyproject.toml                                               +4    -4
  python/pyproject_other.toml                                         +5    -5
  python/sglang/bench_one_batch.py                                    +8    -6
  python/sglang/bench_one_batch_server.py                             +32   -305
  python/sglang/bench_serving.py                                      +1    -7
  python/sglang/environ.py                                            +0    -2
  python/sglang/global_config.py                                      +2    -2
  python/sglang/launch_server.py                                      +0    -14
  python/sglang/srt/configs/load_config.py                            +0    -8
  python/sglang/srt/configs/model_config.py                           +131  -123
  python/sglang/srt/configs/qwen3_vl.py                               +0    -586
  python/sglang/srt/disaggregation/ascend/transfer_engine.py          +47   -9
  python/sglang/srt/disaggregation/common/conn.py                     +8    -27
  python/sglang/srt/disaggregation/decode.py                          +6    -21
  python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py  +0    -185
  python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py     +15   -23
  python/sglang/srt/disaggregation/mooncake/conn.py                   +31   -14
  python/sglang/srt/disaggregation/prefill.py                         +0    -2
  python/sglang/srt/disaggregation/utils.py                           +4    -35
  python/sglang/srt/entrypoints/engine.py                             +1    -1
python/pyproject.toml

@@ -57,7 +57,7 @@ dependencies = [
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.24",
-    "sgl-kernel==0.3.13",
+    "sgl-kernel==0.3.11",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
@@ -67,7 +67,7 @@ dependencies = [
     "tiktoken",
     "anthropic>=0.20.0",
     "torch_memory_saver==0.0.8",
-    "nvidia-cutlass-dsl==4.2.1",
+    "nvidia-cutlass-dsl==4.2.0",
 ]

 [project.optional-dependencies]
@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"]
     "srt/layers/moe/fused_moe_triton/configs/*/*.json",
     "srt/layers/quantization/configs/*.json",
     "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
-    "srt/speculative/cpp_ngram/*.cpp",
-    "srt/speculative/cpp_ngram/*.h",
+    "srt/speculative/cpp_lookahead/*.cpp",
+    "srt/speculative/cpp_lookahead/*.h",
 ]

 [tool.setuptools.packages.find]
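For reference, a small standalone sketch (not part of the commit) of how one could confirm that an installed environment matches the pins introduced in these hunks; the package names and versions come from the diff above, everything else is an assumption.

    # Hypothetical helper: verify installed wheels against the DCU pins.
    from importlib.metadata import PackageNotFoundError, version

    EXPECTED = {
        "sgl-kernel": "0.3.11",
        "torch": "2.8.0",
        "nvidia-cutlass-dsl": "4.2.0",
    }

    def check_pins(expected=EXPECTED):
        for name, want in expected.items():
            try:
                have = version(name)
            except PackageNotFoundError:
                print(f"{name}: not installed (expected {want})")
                continue
            status = "ok" if have == want else f"mismatch (expected {want})"
            print(f"{name}: {have} {status}")

    if __name__ == "__main__":
        check_pins()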
python/pyproject_other.toml

@@ -65,23 +65,23 @@ tracing = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.13",
+    "sgl-kernel==0.3.11",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
+    "flashinfer_python==0.3.1",
 ]

 blackwell = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.13",
+    "sgl-kernel==0.3.11",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
+    "flashinfer_python==0.3.1",
-    "nvidia-cutlass-dsl==4.2.1",
+    "nvidia-cutlass-dsl==4.2.0",
 ]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
python/sglang/bench_one_batch.py

@@ -443,9 +443,11 @@ def latency_test_run_once(
     if profile:
         profiler.stop()
-        trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
-        _save_profile_trace_results(profiler, trace_filename)
-        rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, profile_filename)
+        rank_print(f"torch profiler chrome trace for prefill saved to {profile_filename}")

     # Decode
     decode_latencies = []
@@ -477,10 +479,10 @@ def latency_test_run_once(
         if profile and i == output_len / 2:
             profiler.stop()
-            trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
-            _save_profile_trace_results(profiler, trace_filename)
+            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, profile_filename)
             rank_print(
-                f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
+                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
             )

     # Record decode timing from 2nd output
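The hunks above only rename the local variable that holds the trace path; the naming scheme itself is unchanged. As a rough sketch (an assumption, not the repository's actual `_save_profile_trace_results` helper), exporting and gzipping a Chrome trace with that naming scheme could look like this:

    # Sketch only: dump a torch profiler Chrome trace and gzip it under the
    # f"{prefix}_batch{bs}_input{isl}_output{osl}_prefill.trace.json.gz" scheme.
    import gzip
    import shutil
    import tempfile

    def save_profile_trace(profiler, profile_filename: str) -> None:
        # torch.profiler.profile objects expose export_chrome_trace(path)
        with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp:
            tmp_path = tmp.name
        profiler.export_chrome_trace(tmp_path)
        with open(tmp_path, "rb") as src, gzip.open(profile_filename, "wb") as dst:
            shutil.copyfileobj(src, dst)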
python/sglang/bench_one_batch_server.py

@@ -9,7 +9,6 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --
 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
-python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
 """

 import argparse
@@ -20,17 +19,12 @@ import multiprocessing
 import os
 import random
 import time
-from typing import List, Optional, Tuple
+from typing import List, Tuple

 import numpy as np
 import requests
-from pydantic import BaseModel

-from sglang.bench_serving import (
-    get_tokenizer,
-    sample_mmmu_requests,
-    sample_random_requests,
-)
+from sglang.bench_serving import get_tokenizer, sample_random_requests
 from sglang.profiler import run_profile
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
@@ -38,108 +32,6 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
 from sglang.test.test_utils import is_in_ci, write_github_step_summary

-class ProfileLinks(BaseModel):
-    """Pydantic model for profile trace links."""
-
-    extend: Optional[str] = None
-    decode: Optional[str] = None
-
-
-class BenchmarkResult(BaseModel):
-    """Pydantic model for benchmark results table data, for a single isl and osl"""
-
-    model_path: str
-    run_name: str
-    batch_size: int
-    input_len: int
-    output_len: int
-    latency: float
-    ttft: float
-    input_throughput: float
-    output_throughput: float
-    overall_throughput: float
-    last_gen_throughput: float
-    acc_length: Optional[float] = None
-    profile_links: Optional[ProfileLinks] = None
-
-    @staticmethod
-    def help_str() -> str:
-        return f"""
-Note: To view the traces through perfetto-ui, please:
-1. open with Google Chrome
-2. allow popup
-"""
-
-    def to_markdown_row(self, trace_dir, base_url: str = "", relay_base: str = "") -> str:
-        """Convert this benchmark result to a markdown table row."""
-        # Calculate costs (assuming H100 pricing for now)
-        hourly_cost_per_gpu = 2  # $2/hour for one H100
-        hourly_cost = hourly_cost_per_gpu * 1  # Assuming tp_size = 1 for simplicity
-        input_util = 0.7
-        accept_length = round(self.acc_length, 2) if self.acc_length is not None else "n/a"
-        itl = 1 / (self.output_throughput / self.batch_size) * 1000
-        input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
-        output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost
-
-        def get_perfetto_relay_link_from_trace_file(trace_file: str):
-            import os
-            from urllib.parse import quote
-
-            rel_path = os.path.relpath(trace_file, trace_dir)
-            raw_file_link = f"{base_url}/{rel_path}"
-            relay_link = (
-                f"{relay_base}?src={quote(raw_file_link, safe='')}"
-                if relay_base and quote
-                else raw_file_link
-            )
-            return relay_link
-
-        # Handle profile links
-        profile_link = "NA | NA"
-        if self.profile_links:
-            if self.profile_links.extend or self.profile_links.decode:
-                # Create a combined link or use the first available one
-                trace_files = [self.profile_links.extend, self.profile_links.decode]
-                trace_files_relay_links = [
-                    f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
-                    for trace_file in trace_files
-                ]
-                profile_link = " | ".join(trace_files_relay_links)
-
-        # Build the row
-        return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"
-
-    @classmethod
-    def generate_markdown_report(cls, trace_dir, results: List["BenchmarkResult"]) -> str:
-        """Generate a markdown report from a list of BenchmarkResult object from a single run."""
-        import os
-
-        summary = f"### {results[0].model_path}\n"
-
-        # summary += (
-        #     f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
-        # )
-        summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
-        summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"
-
-        # all results should share the same isl & osl
-        for result in results:
-            base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
-            relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
-            relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
-            # base_url = "https://github.com/sgl-project/ci-data/traces"
-            summary += result.to_markdown_row(trace_dir, base_url, relay_base)
-
-        return summary
-
 @dataclasses.dataclass
 class BenchArgs:
     run_name: str = "default"
@@ -158,12 +50,8 @@ class BenchArgs:
     profile: bool = False
     profile_steps: int = 3
     profile_by_stage: bool = False
-    profile_filename_prefix: str = None
-    append_to_github_summary: bool = True
     dataset_path: str = ""
     parallel_batch: bool = False
-    dataset_name: str = "random"
-    output_path: Optional[str] = None

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -179,13 +67,6 @@ class BenchArgs:
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
-        parser.add_argument(
-            "--dataset-name",
-            type=str,
-            default=BenchArgs.dataset_name,
-            choices=["mmmu", "random"],
-            help="Name of the dataset to benchmark on.",
-        )
         parser.add_argument("--return-logprob", action="store_true")
         parser.add_argument(
             "--client-stream-interval",
@@ -215,36 +96,14 @@ class BenchArgs:
             help="Path to the dataset.",
         )
         parser.add_argument("--parallel-batch", action="store_true")
-        parser.add_argument(
-            "--profile-filename-prefix",
-            type=str,
-            default=BenchArgs.profile_filename_prefix,
-        )
-        parser.add_argument(
-            "--no-append-to-github-summary",
-            action="store_false",
-            dest="append_to_github_summary",
-            help="Disable appending the output of this run to github ci summary",
-        )
-        parser.add_argument(
-            "--output-path",
-            type=str,
-            default=BenchArgs.output_path,
-            help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         # use the default value's type to cast the args into correct types.
         attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
-        kwargs = {}
-        for attr, attr_type in attrs:
-            val = getattr(args, attr)
-            if attr_type is type(None):
-                kwargs[attr] = val
-            else:
-                kwargs[attr] = attr_type(val)
-        return cls(**kwargs)
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )

 def launch_server_internal(server_args):
@@ -289,35 +148,23 @@ def run_one_case(
     run_name: str,
     result_filename: str,
     tokenizer,
-    dataset_name="",
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
-    profile_filename_prefix: str = None,
     dataset_path: str = "",
     parallel_batch: bool = False,
 ):
     requests.post(url + "/flush_cache")

-    # TODO: reuse bench_serving.get_dataset ?
-    if dataset_name == "mmmu":
-        input_requests = sample_mmmu_requests(
-            num_requests=batch_size,
-            tokenizer=tokenizer,
-            fixed_output_len=output_len,
-            apply_chat_template=True,
-            random_sample=False,
-        )
-    elif dataset_name == "random":
-        input_requests = sample_random_requests(
-            input_len=input_len,
-            output_len=output_len,
-            num_prompts=batch_size,
-            range_ratio=1.0,
-            tokenizer=tokenizer,
-            dataset_path=dataset_path,
-            random_sample=True,
-            return_text=False,
-        )
+    input_requests = sample_random_requests(
+        input_len=input_len,
+        output_len=output_len,
+        num_prompts=batch_size,
+        range_ratio=1.0,
+        tokenizer=tokenizer,
+        dataset_path=dataset_path,
+        random_sample=True,
+        return_text=False,
+    )

     use_structured_outputs = False
     if use_structured_outputs:
@@ -334,48 +181,26 @@ def run_one_case(
     profile_link = None
     if profile:
-        output_dir, profile_name = None, None
-        if profile_filename_prefix:
-            output_dir = os.path.dirname(profile_filename_prefix)
-            profile_name = os.path.basename(profile_filename_prefix)
-
         profile_link: str = run_profile(
             url,
-            profile_steps,
-            ["CPU", "GPU"],
-            output_dir,
-            profile_name,
-            profile_by_stage,
+            profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
         )

     tic = time.perf_counter()
-    payload = {
-        "sampling_params": {
-            "temperature": temperature,
-            "max_new_tokens": output_len,
-            "ignore_eos": True,
-            "json_schema": json_schema,
-            "stream_interval": stream_interval,
-        },
-        "return_logprob": return_logprob,
-        "stream": True,
-        **({"parallel_batch": parallel_batch} if parallel_batch else {}),
-    }
-    if dataset_name == "mmmu":
-        # vlm
-        input_ids = []
-        for input_req in input_requests:
-            input_ids += [tokenizer.encode(input_req.prompt)]
-        payload["image_data"] = [req.image_data for req in input_requests]
-    else:
-        input_ids = [req.prompt for req in input_requests]
-    payload["input_ids"] = input_ids
-
     response = requests.post(
         url + "/generate",
-        json=payload,
+        json={
+            "input_ids": [req.prompt for req in input_requests],
+            "sampling_params": {
+                "temperature": temperature,
+                "max_new_tokens": output_len,
+                "ignore_eos": True,
+                "json_schema": json_schema,
+                "stream_interval": stream_interval,
+            },
+            "return_logprob": return_logprob,
+            "stream": True,
+            **({"parallel_batch": parallel_batch} if parallel_batch else {}),
+        },
         stream=True,
     )
@@ -439,100 +264,10 @@ def run_one_case(
         overall_throughput,
         last_gen_throughput,
         acc_length,
-        profile_link,
+        profile_link if profile else None,
     )

-def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
-    """Save benchmark results as JSON using Pydantic models."""
-    json_results = []
-
-    # Generate all parameter combinations to match with results
-    param_combinations = list(
-        itertools.product(
-            bench_args.batch_size, bench_args.input_len, bench_args.output_len
-        )
-    )
-
-    for i, (
-        batch_size,
-        latency,
-        ttft,
-        input_throughput,
-        output_throughput,
-        overall_throughput,
-        last_gen_throughput,
-        acc_length,
-        profile_link,
-    ) in enumerate(result):
-        # Get the corresponding parameters for this result
-        bs, input_len, output_len = param_combinations[i]
-
-        # Parse profile links if available
-        profile_links = None
-        if profile_link:
-            profile_links = parse_profile_links(
-                profile_link, batch_size, input_len, output_len
-            )
-
-        benchmark_result = BenchmarkResult(
-            model_path=model,
-            run_name=bench_args.run_name,
-            batch_size=batch_size,
-            input_len=input_len,
-            output_len=output_len,
-            latency=latency,
-            ttft=ttft,
-            input_throughput=input_throughput,
-            output_throughput=output_throughput,
-            overall_throughput=overall_throughput,
-            last_gen_throughput=last_gen_throughput,
-            acc_length=acc_length,
-            profile_links=profile_links,
-        )
-        json_results.append(benchmark_result.model_dump())
-
-    # Save to JSON file
-    with open(bench_args.output_path, "w", encoding="utf-8") as f:
-        json.dump(json_results, f, indent=2, ensure_ascii=False)
-    print(f"Results saved as JSON to {bench_args.output_path}")
-
-
-def parse_profile_links(
-    profile_dir: str, batch_size: int, input_len: int, output_len: int
-) -> Optional[ProfileLinks]:
-    """Parse profile directory to extract extend and decode trace file links."""
-    if not profile_dir or not os.path.exists(profile_dir):
-        return None
-
-    extend_link = None
-    decode_link = None
-
-    # Look for extend/prefill trace files
-    for file in os.listdir(profile_dir):
-        if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
-            if "extend" in file.lower() or "prefill" in file.lower():
-                extend_link = os.path.join(profile_dir, file)
-            elif "decode" in file.lower():
-                decode_link = os.path.join(profile_dir, file)
-
-    # If no specific extend/decode files found, try to find files with batch/input/output info
-    if not extend_link or not decode_link:
-        for file in os.listdir(profile_dir):
-            if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
-                if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
-                    if "prefill" in file.lower() or "extend" in file.lower():
-                        extend_link = os.path.join(profile_dir, file)
-                    elif "decode" in file.lower():
-                        decode_link = os.path.join(profile_dir, file)
-
-    if extend_link or decode_link:
-        return ProfileLinks(extend=extend_link, decode=decode_link)
-    return None
-
 def get_report_summary(
     result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
 ):
@@ -623,7 +358,6 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 return_logprob=bench_args.return_logprob,
                 stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
-                dataset_name=bench_args.dataset_name,
                 run_name="",
                 result_filename="",
                 tokenizer=tokenizer,
@@ -650,12 +384,10 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
-                dataset_name=bench_args.dataset_name,
                 result_filename=bench_args.result_filename,
                 tokenizer=tokenizer,
                 dataset_path=bench_args.dataset_path,
                 parallel_batch=bench_args.parallel_batch,
-                profile_filename_prefix=bench_args.profile_filename_prefix,
             )
         )
@@ -678,13 +410,11 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
                 tokenizer=tokenizer,
-                dataset_name=bench_args.dataset_name,
                 profile=bench_args.profile,
                 profile_steps=bench_args.profile_steps,
                 profile_by_stage=bench_args.profile_by_stage,
                 dataset_path=bench_args.dataset_path,
                 parallel_batch=bench_args.parallel_batch,
-                profile_filename_prefix=bench_args.profile_filename_prefix,
             )[-1],
         )
     )
@@ -697,16 +427,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     print(f"\nResults are saved to {bench_args.result_filename}")

-    # Save results as JSON if output_path is specified
-    if bench_args.output_path:
-        save_results_as_json(result, bench_args, model=server_args.model_path)
-
     if not bench_args.show_report:
         return

     summary = get_report_summary(result, server_args, bench_args)
-    print(summary)

-    if is_in_ci() and bench_args.append_to_github_summary:
+    if is_in_ci():
         write_github_step_summary(summary)
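One detail worth noting in the from_cli_args hunk: the removed upstream loop special-cases dataclass fields whose default is None, while the one-line comprehension restored here casts every value through the default's type. A self-contained illustration (not sglang code) of why that special case existed:

    # type(None) is NoneType, and NoneType("x") raises TypeError, so a field
    # like profile_filename_prefix (default None) cannot be cast blindly.
    import dataclasses

    @dataclasses.dataclass
    class Demo:
        steps: int = 3
        prefix: str = None  # default is None, so type(default) is NoneType

    attrs = [(f.name, type(f.default)) for f in dataclasses.fields(Demo)]
    raw = {"steps": "7", "prefix": "trace"}

    kwargs = {}
    for name, attr_type in attrs:
        val = raw[name]
        if attr_type is type(None):   # keep the raw value untouched
            kwargs[name] = val
        else:
            kwargs[name] = attr_type(val)

    print(Demo(**kwargs))  # Demo(steps=7, prefix='trace')

    # The restored one-liner, Demo(**{n: t(raw[n]) for n, t in attrs}),
    # would call type(None)("trace") here and raise TypeError. It is safe in
    # the restored BenchArgs because no remaining field defaults to None.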
python/sglang/bench_serving.py

@@ -208,10 +208,6 @@ async def async_request_openai_completions(
         "ignore_eos": not args.disable_ignore_eos,
         **request_func_input.extra_request_body,
     }

-    if request_func_input.image_data:
-        payload.update({"image_data": request_func_input.image_data})

     headers = get_auth_headers()

     output = RequestFuncOutput.init_new(request_func_input)
@@ -1763,9 +1759,7 @@ async def benchmark(
         pbar.close()

     if "sglang" in backend:
         server_info = requests.get(
-            base_url + "/get_server_info"
+            base_url + "/get_server_info", headers=get_auth_headers()
         )
         if server_info.status_code == 200:
             server_info_json = server_info.json()
             if "decode" in server_info_json:
python/sglang/srt/environ.py → python/sglang/environ.py (renamed)

@@ -124,8 +124,6 @@ class Envs:
     SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
     SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
     SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
-    SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
-    SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")

     # Model Parallel
     SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
python/sglang/global_config.py

@@ -37,8 +37,8 @@ class GlobalConfig:
         )

         # Runtime constants: others
         self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = int(
-            os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
+        self.flashinfer_workspace_size = os.environ.get(
+            "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
         )

         # Output tokenization configs
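A minor behavioral consequence of this hunk: os.environ values are strings, so the int(...) wrapper in the removed line normalized the type, while the restored line returns whatever os.environ.get hands back. A minimal standalone illustration (assumption: snippet only, not sglang code):

    import os

    os.environ["FLASHINFER_WORKSPACE_SIZE"] = "134217728"

    with_cast = int(os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024))
    without_cast = os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)

    print(type(with_cast).__name__, with_cast)        # int 134217728
    print(type(without_cast).__name__, without_cast)  # str 134217728 when the env var is set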
python/sglang/launch_server.py

@@ -7,23 +7,9 @@ from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_process_tree

-MOVE_ENVS_WARN = """
-########################################################################
-# For contributors and developers:                                     #
-# Please move environment variable definitions to sglang.srt.environ   #
-# using the following pattern:                                         #
-#     SGLANG_XXX = EnvBool(False)                                      #
-#                                                                      #
-########################################################################
-"""

 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])

-    from sglang.srt.server_args import print_deprecated_warning
-
-    print_deprecated_warning(MOVE_ENVS_WARN)

     try:
         launch_server(server_args)
     finally:
python/sglang/srt/configs/load_config.py

@@ -24,8 +24,6 @@ class LoadFormat(str, enum.Enum):
     JAX = "jax"
     REMOTE = "remote"
     REMOTE_INSTANCE = "remote_instance"
-    RDMA = "rdma"
-    LOCAL_CACHED = "local_cached"

 @dataclass
@@ -49,7 +47,6 @@ class LoadConfig:
             checkpoints.
         decryption_key_file: If set, decrypts the output files with a password read
             from this file (after PBKDF2).
-        decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
     """

     load_format: Union[str, LoadFormat] = LoadFormat.AUTO
@@ -57,11 +54,6 @@ class LoadConfig:
     model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
     ignore_patterns: Optional[Union[List[str], str]] = None
     decryption_key_file: Optional[str] = None
-    decrypt_max_concurrency: int = -1
-    tp_rank: Optional[int] = None
-    remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
-    remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
-    remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None

     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
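Since LoadFormat is a str-based enum, dropping members here simply means those strings stop parsing as valid load formats. A quick sketch (not from the repo; enum values taken from the context lines above) of that behavior:

    import enum

    class LoadFormat(str, enum.Enum):
        AUTO = "auto"
        JAX = "jax"
        REMOTE = "remote"
        REMOTE_INSTANCE = "remote_instance"

    def parse_load_format(value: str) -> LoadFormat:
        try:
            return LoadFormat(value)
        except ValueError:
            raise ValueError(f"unsupported load format: {value!r}") from None

    print(parse_load_format("remote_instance"))  # LoadFormat.REMOTE_INSTANCE
    # parse_load_format("rdma") now raises, since RDMA was dropped from the enum.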
python/sglang/srt/configs/model_config.py

@@ -31,7 +31,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, is_hip, retry
+from sglang.srt.utils import get_bool_env_var, is_hip
 from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)
@@ -48,6 +48,30 @@ class ModelImpl(str, Enum):
     TRANSFORMERS = "transformers"

+def is_deepseek_nsa(config: PretrainedConfig) -> bool:
+    return (
+        config.architectures is not None
+        and config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
+        and getattr(config, "index_topk", None) is not None
+    )
+
+
+def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
+    assert is_deepseek_nsa(config)
+    return config.index_head_dim
+
+
+def get_nsa_index_topk(config: PretrainedConfig) -> int:
+    assert is_deepseek_nsa(config)
+    return config.index_topk
+
+
+def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
+    assert is_deepseek_nsa(config)
+    return config.index_n_heads
+
 class ModelConfig:
     def __init__(
         self,
@@ -64,20 +88,35 @@ class ModelConfig:
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+        tp_rank: Optional[int] = None,
+        remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
+        remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
+        remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None,
     ) -> None:
         # Parse args
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
-        self.is_draft_model = is_draft_model
         self.model_impl = model_impl
+        self.tp_rank = tp_rank
+        self.remote_instance_weight_loader_seed_instance_ip = (
+            remote_instance_weight_loader_seed_instance_ip
+        )
+        self.remote_instance_weight_loader_seed_instance_service_port = (
+            remote_instance_weight_loader_seed_instance_service_port
+        )
+        self.remote_instance_weight_loader_send_weights_group_ports = (
+            remote_instance_weight_loader_send_weights_group_ports
+        )

-        # Get hf config
-        self._maybe_pull_model_tokenizer_from_remote()
+        self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
         kwargs = {}
         if override_config_file and override_config_file.strip():
             kwargs["_configuration_file"] = override_config_file.strip()
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=trust_remote_code,
@@ -85,7 +124,7 @@ class ModelConfig:
             model_override_args=self.model_override_args,
             **kwargs,
         )
+        self.hf_text_config = get_hf_text_config(self.hf_config)
         self.hf_generation_config = get_generation_config(
             self.model_path,
             trust_remote_code=trust_remote_code,
@@ -93,25 +132,7 @@ class ModelConfig:
             **kwargs,
         )

-        # Set enable_multimodal
-        self.hf_text_config = get_hf_text_config(self.hf_config)
-        if enable_multimodal is None:
-            mm_disabled_models = [
-                "Gemma3ForConditionalGeneration",
-                "Llama4ForConditionalGeneration",
-                "Step3VLForConditionalGeneration",
-            ]
-            if self.hf_config.architectures[0] in mm_disabled_models:
-                enable_multimodal = False
-                logger.info(
-                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
-                )
-            else:
-                enable_multimodal = True
-
-        # Config draft model
-        self._config_draft_model()
-
-        # Check model type
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
         )
@@ -127,70 +148,20 @@ class ModelConfig:
                 self.hf_config.architectures, self.hf_text_config.num_hidden_layers
             )
         )
-        self.is_generation = is_generation_model(
-            self.hf_config.architectures, is_embedding
-        )
-        self.is_multimodal = enable_multimodal and is_multimodal_model(
-            self.hf_config.architectures
-        )
-        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
-            self.hf_config.architectures
-        )
-        self.is_image_gen = enable_multimodal and is_image_gen_model(
-            self.hf_config.architectures
-        )
-        self.is_audio_model = enable_multimodal and is_audio_model(
-            self.hf_config.architectures
-        )
-        self.is_multimodal_chunked_prefill_supported = (
-            enable_multimodal
-            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
-        )
-        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
-        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-
-        # Derive context length and model shapes
-        self._derive_context_length(context_length)
-        self._derive_model_shapes()
-
-        # Verify quantization
-        self._verify_quantization()
-
-        # Verify dual-chunk attention config
-        self._verify_dual_chunk_attention_config()
-
-        # Cache attributes
-        self.hf_eos_token_id = self._get_hf_eos_token_id()
-
-        # multimodal
-        self.image_token_id = getattr(self.hf_config, "image_token_id", None) or getattr(
-            self.hf_config, "image_token_index", None
-        )
-
-    @staticmethod
-    def from_server_args(
-        server_args: ServerArgs,
-        model_path: str = None,
-        model_revision: str = None,
-        **kwargs,
-    ):
-        return ModelConfig(
-            model_path=model_path or server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=model_revision or server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
-            model_impl=server_args.model_impl,
-            **kwargs,
-        )
-
-    def _config_draft_model(self):
-        is_draft_model = self.is_draft_model
+        if enable_multimodal is None:
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+                "Step3VLForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
+                enable_multimodal = False
+                logger.info(
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
+                )
+            else:
+                enable_multimodal = True

         if (
             is_draft_model
@@ -225,10 +196,31 @@ class ModelConfig:
             self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
             self.hf_config.num_nextn_predict_layers = 1

-    def _derive_context_length(self, context_length: int):
-        is_draft_model = self.is_draft_model
-        derived_context_len = get_context_length(self.hf_text_config)
+        # Check model type
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, is_embedding
+        )
+        self.is_multimodal = enable_multimodal and is_multimodal_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_image_gen = enable_multimodal and is_image_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_audio_model = enable_multimodal and is_audio_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_chunked_prefill_supported = (
+            enable_multimodal
+            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
+        )
+        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+
+        # Derive context length
+        derived_context_len = get_context_length(self.hf_text_config)
         if context_length is not None:
             if context_length > derived_context_len:
                 reason = "Target model's" if is_draft_model else "User-specified"
@@ -242,11 +234,6 @@ class ModelConfig:
             ):
                 logger.warning(msg)
                 self.context_len = context_length
-                if is_draft_model:
-                    self.hf_text_config.max_position_embeddings = context_length
-                    logger.warning(
-                        f"Overriding the draft model's max_position_embeddings to {context_length}."
-                    )
             else:
                 raise ValueError(
                     f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
@@ -256,10 +243,6 @@ class ModelConfig:
         else:
             self.context_len = derived_context_len

-        # Transfer context_len to HuggingFace config so models can access it
-        self.hf_config.context_len = self.context_len
-
-    def _derive_model_shapes(self):
         # Unify the config keys for hf_text_config
         self.head_dim = getattr(
             self.hf_text_config,
@@ -270,6 +253,7 @@ class ModelConfig:
         # FIXME: temporary special judge for MLA architecture
         if (
             "DeepseekV2ForCausalLM" in self.hf_config.architectures
+            or "DeepseekV32ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
             or "LongcatFlashForCausalLM" in self.hf_config.architectures
@@ -282,6 +266,11 @@ class ModelConfig:
             self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
             self.v_head_dim = self.hf_config.v_head_dim
+            self.index_head_dim = (
+                get_nsa_index_head_dim(self.hf_config)
+                if is_deepseek_nsa(self.hf_config)
+                else None
+            )

             # Handle rope scaling with yarn
             self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
@@ -354,6 +343,45 @@ class ModelConfig:
             )
         self.vocab_size = self.hf_text_config.vocab_size

+        # Verify quantization
+        self._verify_quantization()
+
+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
+        # Cache attributes
+        self.hf_eos_token_id = self.get_hf_eos_token_id()
+
+        # multimodal
+        self.image_token_id = getattr(self.hf_config, "image_token_id", None) or getattr(
+            self.hf_config, "image_token_index", None
+        )
+
+    @staticmethod
+    def from_server_args(
+        server_args: ServerArgs,
+        model_path: str = None,
+        model_revision: str = None,
+        **kwargs,
+    ):
+        return ModelConfig(
+            model_path=model_path or server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=model_revision or server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
+            model_impl=server_args.model_impl,
+            remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
+            remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
+            remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
+            **kwargs,
+        )
+
     def get_total_num_attention_heads(self) -> int:
         return self.num_attention_heads
@@ -454,31 +482,13 @@ class ModelConfig:
             from huggingface_hub import HfApi

             hf_api = HfApi()
-
-            def check_hf_quant_config():
-                return hf_api.file_exists(self.model_path, "hf_quant_config.json")
-
-            # Retry HF API call up to 3 times
-            file_exists = retry(
-                check_hf_quant_config,
-                max_retry=2,
-                initial_delay=1.0,
-                max_delay=5.0,
-            )
-
-            if file_exists:
+            if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                 quant_cfg = modelopt_quant_config
         except huggingface_hub.errors.OfflineModeIsEnabled:
             logger.warning(
                 "Offline mode is enabled, skipping hf_quant_config.json check"
             )
-        except Exception as e:
-            logger.warning(
-                f"Failed to check hf_quant_config.json: {self.model_path} {e}"
-            )
+            pass
         elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
             quant_config_file = os.path.join(
@@ -606,7 +616,7 @@ class ModelConfig:
                         "sparse_attention_enabled"
                     ] = True

-    def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
+    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
             # it can be either int or list of int
@@ -626,7 +636,7 @@ class ModelConfig:
             eos_ids = eos_ids | generation_eos_ids
         return eos_ids

-    def _maybe_pull_model_tokenizer_from_remote(self) -> None:
+    def maybe_pull_model_tokenizer_from_remote(self) -> None:
         """
         Pull the model config files to a temporary
         directory in case of remote.
@@ -769,8 +779,6 @@ multimodal_model_archs = [
     "Qwen2AudioForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
-    "Qwen3VLForConditionalGeneration",
-    "Qwen3VLMoeForConditionalGeneration",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
     "InternS1ForConditionalGeneration",
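A hedged usage sketch for the NSA helpers added in the hunk above (this is how DeepSeek-V3.2 sparse-attention fields get read off a config). The functions are the ones introduced by this commit; the config object below is a stand-in built with SimpleNamespace, and the numeric values are placeholders, not real checkpoint values.

    from types import SimpleNamespace

    from sglang.srt.configs.model_config import (
        get_nsa_index_head_dim,
        get_nsa_index_n_heads,
        get_nsa_index_topk,
        is_deepseek_nsa,
    )

    fake_cfg = SimpleNamespace(
        architectures=["DeepseekV32ForCausalLM"],
        index_topk=2048,      # placeholder values; field names mirror the helpers
        index_head_dim=128,
        index_n_heads=64,
    )

    if is_deepseek_nsa(fake_cfg):
        print(get_nsa_index_topk(fake_cfg))      # 2048
        print(get_nsa_index_head_dim(fake_cfg))  # 128
        print(get_nsa_index_n_heads(fake_cfg))   # 64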
python/sglang/srt/configs/qwen3_vl.py (deleted, 100644 → 0; content as of parent 8f7453e3)

from typing import Optional, Union

from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation


class Qwen3VLVisionConfig(PretrainedConfig):
    model_type = "qwen3_vl"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=27,
        hidden_size=1152,
        hidden_act="gelu_pytorch_tanh",
        intermediate_size=4304,
        num_heads=16,
        in_channels=3,
        patch_size=16,
        spatial_merge_size=2,
        temporal_patch_size=2,
        out_hidden_size=3584,
        num_position_embeddings=2304,
        deepstack_visual_indexes=[8, 16, 24],
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.out_hidden_size = out_hidden_size
        self.num_position_embeddings = num_position_embeddings
        self.initializer_range = initializer_range
        self.deepstack_visual_indexes = deepstack_visual_indexes


class Qwen3VLTextConfig(PretrainedConfig):
    r"""
This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a
Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen3VLModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22016):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 32):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
head_dim (`int`, *optional*, defaults to 128):
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 128000):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 5000000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
>>> # Initializing a Qwen3VL style configuration
>>> configuration = Qwen3VLTextConfig()
>>> # Initializing a model from the Qwen3-VL-7B style configuration
>>> model = Qwen3VLTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
    model_type = "qwen3_vl_text"
    base_config_key = "text_config"

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=128000,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=5000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class Qwen3VLConfig(PretrainedConfig):
    r"""
This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a
Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
The config object or dictionary of the text backbone.
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
The config object or dictionary of the vision backbone.
image_token_id (`int`, *optional*, defaults to 151655):
The image token index to encode the image prompt.
video_token_id (`int`, *optional*, defaults to 151656):
The video token index to encode the video prompt.
vision_start_token_id (`int`, *optional*, defaults to 151652):
The start token index to encode the image prompt.
vision_end_token_id (`int`, *optional*, defaults to 151653):
The end token index to encode the image prompt.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the word embeddings.
```python
>>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig
>>> # Initializing a Qwen3-VL style configuration
>>> configuration = Qwen3VLConfig()
>>> # Initializing a model from the Qwen3-VL-4B style configuration
>>> model = Qwen3VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
    model_type = "qwen3_vl"
    sub_configs = {
        "vision_config": Qwen3VLVisionConfig,
        "text_config": Qwen3VLTextConfig,
    }
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        vision_end_token_id=151653,
        tie_word_embeddings=False,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"]()

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id

        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
class Qwen3VLMoeTextConfig(PretrainedConfig):
    r"""
This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2MoeModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 128000):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 5000000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 1408):
Intermediate size of the routed expert.
num_experts_per_tok (`int`, *optional*, defaults to 4):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 60):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the topk probabilities.
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
head_dim (`int`, *optional*):
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
```python
>>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
>>> # Initializing a Qwen3VLMoe style configuration
>>> configuration = Qwen3VLMoeConfig()
>>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
>>> model = Qwen3VLMoeForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
    model_type = "qwen3_vl_moe_text"
    base_config_key = "text_config"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3VLMoe`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_key_value_heads=16,
        hidden_act="silu",
        max_position_embeddings=128000,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=5000000.0,
        attention_bias=False,
        attention_dropout=0.0,
        decoder_sparse_step=1,
        moe_intermediate_size=1408,
        num_experts_per_tok=4,
        num_experts=60,
        norm_topk_prob=True,
        mlp_only_layers=None,
        rope_scaling=None,
        head_dim=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling
        self.head_dim = head_dim or hidden_size // num_attention_heads
        rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class Qwen3VLMoeVisionConfig(PretrainedConfig):
    model_type = "qwen3_vl_moe"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=27,
        hidden_size=1152,
        hidden_act="gelu_pytorch_tanh",
        intermediate_size=4304,
        num_heads=16,
        in_channels=3,
        patch_size=16,
        spatial_merge_size=2,
        temporal_patch_size=2,
        out_hidden_size=3584,
        num_position_embeddings=2304,
        deepstack_visual_indexes=[8, 16, 24],
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.out_hidden_size = out_hidden_size
        self.num_position_embeddings = num_position_embeddings
        self.initializer_range = initializer_range
        self.deepstack_visual_indexes = deepstack_visual_indexes
class Qwen3VLMoeConfig(PretrainedConfig):
    r"""
This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
The config object or dictionary of the text backbone.
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
The config object or dictionary of the vision backbone.
image_token_id (`int`, *optional*, defaults to 151655):
The image token index to encode the image prompt.
video_token_id (`int`, *optional*, defaults to 151656):
The video token index to encode the video prompt.
vision_start_token_id (`int`, *optional*, defaults to 151652):
The start token index to encode the image prompt.
vision_end_token_id (`int`, *optional*, defaults to 151653):
The end token index to encode the image prompt.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the word embeddings.
```python
>>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
>>> # Initializing a Qwen3-VL-MOE style configuration
>>> configuration = Qwen3VLMoeConfig()
>>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
>>> model = Qwen3VLMoeForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
    model_type = "qwen3_vl_moe"
    sub_configs = {
        "vision_config": Qwen3VLMoeVisionConfig,
        "text_config": Qwen3VLMoeTextConfig,
    }
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        vision_end_token_id=151653,
        tie_word_embeddings=False,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"]()

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id

        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)


__all__ = [
    "Qwen3VLMoeConfig",
    "Qwen3VLMoeVisionConfig",
    "Qwen3VLConfig",
    "Qwen3VLVisionConfig",
]
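Reviewer note: a minimal usage sketch (not part of this file) of the `rope_scaling` dictionary documented above, assuming a transformers release that ships `Qwen3VLTextConfig`; the YaRN numbers are illustrative placeholders, not tuned values.

```python
from transformers import Qwen3VLTextConfig

# YaRN-style rope_scaling using the keys documented in the docstrings above.
rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,  # handle sequences ~4x the pretrained length
    "original_max_position_embeddings": 32768,  # placeholder pretraining length
    "beta_fast": 32,
    "beta_slow": 1,
}

config = Qwen3VLTextConfig(rope_scaling=rope_scaling)
print(config.rope_scaling["rope_type"])  # "yarn"
```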
python/sglang/srt/disaggregation/ascend/transfer_engine.py

@@ -2,9 +2,19 @@ import logging
 import os
 from typing import List, Optional

+import torch
+
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode

+try:
+    from mf_adapter import TransferEngine
+
+    import_error = None
+except ImportError as e:
+    import_error = e
+    pass
+
 logger = logging.getLogger(__name__)

@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
     def __init__(
         self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
     ):
-        try:
-            from mf_adapter import TransferEngine
-        except ImportError as e:
-            raise ImportError(
+        if import_error is not None:
+            logger.warning(
                 "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
-            ) from e
+            )
+            raise import_error

         self.engine = TransferEngine()
         self.hostname = hostname

@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
         self.initialize()

     def initialize(self) -> None:
+        from sglang.srt.layers.dp_attention import (
+            get_tensor_model_parallel_world_size,
+            get_tp_group,
+        )
+
+        transfer_protocol = self._get_transfer_protocol()
+        if transfer_protocol is None or transfer_protocol == "sdma":
+            trans_op_type = TransferEngine.TransDataOpType.SDMA
+        else:
+            trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
+            """with device RDMA for PD transfer"""
+            tmp_tensor = torch.zeros(1, device="npu")
+            output_tensor_list = [
+                torch.empty_like(tmp_tensor)
+                for _ in range(get_tensor_model_parallel_world_size())
+            ]
+            # Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
+            torch.distributed.all_gather(
+                output_tensor_list, tmp_tensor, group=get_tp_group().device_group
+            )
         """Initialize the ascend transfer instance."""
         ret_value = self.engine.initialize(
             self.store_url,
             self.session_id,
             self.role,
             self.npu_id,
+            trans_op_type,
         )
         if ret_value != 0:
             logger.error("Ascend Transfer Engine initialization failed.")

@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
             ret_value = -1
         if ret_value != 0:
             logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
+
+    @staticmethod
+    def _get_transfer_protocol():
+        protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
+        allowed_protocols = {"device_rdma", "sdma"}
+        if protocol and protocol.lower() in allowed_protocols:
+            return protocol.lower()
+        else:
+            logger.warning(
+                "Invalid or no transfer protocol specified, using default protocol."
+            )
+            return None
\ No newline at end of file
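Reviewer note: the hunk above selects the transfer op type from the `ASCEND_MF_TRANSFER_PROTOCOL` environment variable and falls back to SDMA otherwise. A standalone sketch of that selection logic, with plain strings standing in for `TransferEngine.TransDataOpType` (which this sketch does not import):

```python
import os

# Stand-ins for TransferEngine.TransDataOpType.{SDMA, DEVICE_RDMA}.
SDMA, DEVICE_RDMA = "SDMA", "DEVICE_RDMA"


def pick_trans_op_type(env=os.environ) -> str:
    """Mirror the selection: only 'device_rdma' opts into RDMA; anything else -> SDMA."""
    protocol = (env.get("ASCEND_MF_TRANSFER_PROTOCOL") or "").lower()
    if protocol == "device_rdma":
        return DEVICE_RDMA
    return SDMA


assert pick_trans_op_type({}) == SDMA
assert pick_trans_op_type({"ASCEND_MF_TRANSFER_PROTOCOL": "Device_RDMA"}) == DEVICE_RDMA
```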
python/sglang/srt/disaggregation/common/conn.py

@@ -95,6 +95,14 @@ class CommonKVManager(BaseKVManager):
     def _bind_server_socket(self):
         self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))

+    @cache
+    def _connect(self, endpoint: str, is_ipv6: bool = False):
+        socket = zmq.Context().socket(zmq.PUSH)
+        if is_ipv6:
+            socket.setsockopt(zmq.IPV6, 1)
+        socket.connect(endpoint)
+        return socket
+
     def _register_to_bootstrap(self):
         """Register KVSender to bootstrap server via HTTP POST."""
         if self.dist_init_addr:

@@ -148,33 +156,6 @@ class CommonKVManager(BaseKVManager):
         socket.connect(endpoint)
         return socket

-    def get_mha_kv_ptrs_with_pp(
-        self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
-    ) -> Tuple[List[int], List[int], List[int], List[int], int]:
-        # pp is not supported on the decode side yet
-        start_layer = self.kv_args.prefill_start_layer
-        num_kv_layers = len(src_kv_ptrs) // 2
-        end_layer = start_layer + num_kv_layers
-        dst_num_total_layers = len(dst_kv_ptrs) // 2
-        src_k_ptrs = src_kv_ptrs[:num_kv_layers]
-        src_v_ptrs = src_kv_ptrs[num_kv_layers:]
-        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
-        dst_v_ptrs = dst_kv_ptrs[
-            dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
-        ]
-        layers_current_pp_stage = len(src_k_ptrs)
-        return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
-
-    def get_mla_kv_ptrs_with_pp(
-        self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
-    ) -> Tuple[List[int], List[int], int]:
-        # pp is not supported on the decode side yet
-        start_layer = self.kv_args.prefill_start_layer
-        end_layer = start_layer + len(src_kv_ptrs)
-        sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
-        layers_current_pp_stage = len(src_kv_ptrs)
-        return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
-
 class CommonKVSender(BaseKVSender):
...
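Reviewer note: a self-contained sketch of the memoized `_connect` helper added above, assuming pyzmq is installed; `functools.cache` keeps one PUSH socket per (endpoint, is_ipv6) pair instead of reconnecting on every send (the sketch uses a shared `Context.instance()` rather than a fresh context).

```python
from functools import cache

import zmq


@cache
def connect_push(endpoint: str, is_ipv6: bool = False) -> zmq.Socket:
    """Return a memoized ZMQ PUSH socket connected to `endpoint`."""
    sock = zmq.Context.instance().socket(zmq.PUSH)
    if is_ipv6:
        sock.setsockopt(zmq.IPV6, 1)
    sock.connect(endpoint)
    return sock


# Repeated calls with the same arguments reuse the same socket object.
s1 = connect_push("tcp://127.0.0.1:5555")
s2 = connect_push("tcp://127.0.0.1:5555")
assert s1 is s2
```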
python/sglang/srt/disaggregation/decode.py

@@ -609,21 +609,15 @@ class DecodeTransferQueue:
                 idx = decode_req.metadata_buffer_index
                 (
                     output_id,
-                    cached_tokens,
                     output_token_logprobs_val,
                     output_token_logprobs_idx,
                     output_top_logprobs_val,
                     output_top_logprobs_idx,
-                    output_topk_p,
-                    output_topk_index,
                     output_hidden_states,
                 ) = self.metadata_buffers.get_buf(idx)

                 decode_req.req.output_ids.append(output_id[0].item())
-                decode_req.req.cached_tokens = cached_tokens[0].item()
                 if not self.spec_algorithm.is_none():
-                    decode_req.req.output_topk_p = output_topk_p
-                    decode_req.req.output_topk_index = output_topk_index
                     decode_req.req.hidden_states_tensor = output_hidden_states
                 if decode_req.req.return_logprob:
                     decode_req.req.output_token_logprobs_val.append(

@@ -713,15 +707,12 @@ class SchedulerDisaggregationDecodeMixin:
             elif prepare_mlp_sync_flag:
                 batch, _ = self._prepare_idle_batch_and_run(None)

-            queue_size = (
-                len(self.waiting_queue)
-                + len(self.disagg_decode_transfer_queue.queue)
-                + len(self.disagg_decode_prealloc_queue.queue)
-            )
-            if self.server_args.disaggregation_decode_enable_offload_kvcache:
-                queue_size += len(self.decode_offload_manager.ongoing_offload)
-            if batch is None and queue_size == 0:
+            if batch is None and (
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
                 self.self_check_during_idle()

             self.last_batch = batch

@@ -790,15 +781,12 @@ class SchedulerDisaggregationDecodeMixin:
                 )
                 self.process_batch_result(tmp_batch, tmp_result)

-            queue_size = (
-                len(self.waiting_queue)
-                + len(self.disagg_decode_transfer_queue.queue)
-                + len(self.disagg_decode_prealloc_queue.queue)
-            )
-            if self.server_args.disaggregation_decode_enable_offload_kvcache:
-                queue_size += len(self.decode_offload_manager.ongoing_offload)
-            if batch is None and queue_size == 0:
+            if batch is None and (
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
                 self.self_check_during_idle()

             self.last_batch = batch

@@ -917,6 +905,3 @@ class SchedulerDisaggregationDecodeMixin:
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
         self.waiting_queue.extend(alloc_reqs)
-
-        if self.server_args.disaggregation_decode_enable_offload_kvcache:
-            self.decode_offload_manager.check_offload_progress()
python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py  deleted 100644 → 0

import logging
import threading
import time

import torch

from sglang.srt.server_args import ServerArgs
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)

logger = logging.getLogger(__name__)


class DecodeKVCacheOffloadManager:
    """Manage decode-side KV cache offloading lifecycle and operations."""

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_group: torch.distributed.ProcessGroup,
        tree_cache: BasePrefixCache,
        server_args: ServerArgs,
    ) -> None:
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = server_args.page_size
        self.server_args = server_args
        self.request_counter = 0
        self.tree_cache = tree_cache
        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
        if isinstance(kv_cache, MHATokenToKVPool):
            self.decode_host_mem_pool = MHATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        elif isinstance(kv_cache, MLATokenToKVPool):
            self.decode_host_mem_pool = MLATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        else:
            raise ValueError("Unsupported KV cache type for decode offload")

        self.tp_group = tp_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            mem_pool_host=self.decode_host_mem_pool,
            page_size=self.page_size,
            tp_group=tp_group,
            io_backend=server_args.hicache_io_backend,
            load_cache_event=threading.Event(),
            storage_backend=server_args.hicache_storage_backend,
            model_name=server_args.served_model_name,
            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
        )

        self.ongoing_offload = {}
        self.ongoing_backup = {}
        logger.info("Enable offload kv cache for decode side")

    def offload_kv_cache(self, req) -> bool:
        """Offload a finished request's KV cache to storage."""

        if self.cache_controller is None or self.decode_host_mem_pool is None:
            return False

        if req.req_pool_idx == -1:
            return False

        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        if token_indices.dim() == 0 or token_indices.numel() == 0:
            logger.debug(f"Request {req.rid} has invalid token_indices: {token_indices}")
            return False

        tokens = req.origin_input_ids + req.output_ids
        aligned_len = (len(tokens) // self.page_size) * self.page_size
        if aligned_len == 0:
            return False

        token_indices = token_indices[:aligned_len]
        tokens = tokens[:aligned_len]

        # Asynchronously offload KV cache from device to host by cache controller
        self.request_counter += 1
        ack_id = self.request_counter
        host_indices = self.cache_controller.write(
            device_indices=token_indices.long(),
            node_id=ack_id,
        )
        if host_indices is None:
            logger.error(f"Not enough host memory for request {req.rid}")
            return False

        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
        return True

    def check_offload_progress(self):
        """Check the progress of offload from device to host and backup from host to storage."""
        cc = self.cache_controller

        qsizes = torch.tensor(
            [
                len(cc.ack_write_queue),
                cc.ack_backup_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_write, n_backup = map(int, qsizes.tolist())
        self._check_offload_progress(n_write)
        self._check_backup_progress(n_backup)

    def _check_offload_progress(self, finish_count):
        """Check the progress of offload from device to host."""
        while finish_count > 0:
            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
            finish_event.synchronize()
            for ack_id in ack_list:
                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)

                # Release device
                self.tree_cache.cache_finished_req(req)

                # Trigger async backup from host to storage by cache controller
                self._trigger_backup(req.rid, host_indices, tokens, start_time)
            finish_count -= 1

    def _check_backup_progress(self, finish_count):
        """Check the progress of backup from host to storage."""
        for _ in range(finish_count):
            storage_operation = self.cache_controller.ack_backup_queue.get()
            ack_id = storage_operation.id
            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)

            # Release host memory
            self.decode_host_mem_pool.free(host_indices)

            logger.debug(
                f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
            )

    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
        """Trigger async backup from host to storage by cache controller."""

        # Generate page hashes and write to storage
        page_hashes = self._compute_prefix_hash(tokens)
        ack_id = self.cache_controller.write_storage(
            host_indices,
            tokens,
            hash_value=page_hashes,
        )
        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)

    def _compute_prefix_hash(self, tokens):
        last_hash = ""
        page_hashes = []
        for offset in range(0, len(tokens), self.page_size):
            page_tokens = tokens[offset : offset + self.page_size]
            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
            page_hashes.append(last_hash)
        return page_hashes
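Reviewer note on the deleted `_compute_prefix_hash`: page hashes are chained, so requests sharing a prefix produce identical leading hashes and can deduplicate storage writes. A standalone illustration with `hashlib` standing in for the cache controller's `get_hash_str` (whose exact hashing may differ):

```python
import hashlib


def chained_page_hashes(tokens, page_size):
    """Hash each full page together with the previous page's hash (prefix chaining)."""
    last_hash, hashes = "", []
    for off in range(0, len(tokens) // page_size * page_size, page_size):
        page = tokens[off : off + page_size]
        last_hash = hashlib.sha256(
            (last_hash + ",".join(map(str, page))).encode()
        ).hexdigest()
        hashes.append(last_hash)
    return hashes


a = chained_page_hashes(list(range(8)), page_size=4)
b = chained_page_hashes([0, 1, 2, 3, 4, 5, 99, 100], page_size=4)
assert a[0] == b[0] and a[1] != b[1]  # shared first page, diverging second page
```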
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

@@ -125,33 +125,25 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req.grammar.finished = req.finished()

         self.output_ids = torch.tensor(self.output_ids, device=self.device)

-        # Simulate the eagle run.
-        if self.spec_algorithm.is_eagle():
+        # Simulate the eagle run. We add mock data to hidden states for the
+        # ease of implementation now meaning the first token will have acc rate
+        # of 0.
+        if not self.spec_algorithm.is_none():
             b = len(self.reqs)
-            topk = server_args.speculative_eagle_topk
-            topk_p = torch.stack(
-                [
-                    torch.as_tensor(
-                        req.output_topk_p[:topk],
-                        device=self.device,
-                        dtype=torch.float32,
-                    )
-                    for req in self.reqs
-                ],
-                dim=0,
-            )
-            topk_index = torch.stack(
-                [
-                    torch.as_tensor(
-                        req.output_topk_index[:topk],
-                        device=self.device,
-                        dtype=torch.int64,
-                    )
-                    for req in self.reqs
-                ],
-                dim=0,
-            )
+            topk_p = torch.arange(
+                b * server_args.speculative_eagle_topk,
+                0,
+                -1,
+                device=self.device,
+                dtype=torch.float32,
+            )
+            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
+            topk_p /= b * server_args.speculative_eagle_topk
+            topk_index = torch.arange(
+                b * server_args.speculative_eagle_topk,
+                device=self.device,
+                dtype=torch.int64,
+            )
+            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
             hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
             hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
...
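Reviewer note: a quick check (not in the commit) of what the mock `topk_p` / `topk_index` on the new side evaluate to, assuming `b = 2` requests and `speculative_eagle_topk = 4`: probabilities decrease row-major from 1.0 and indices are simply 0..b*topk-1 reshaped.

```python
import torch

b, topk = 2, 4
topk_p = torch.arange(b * topk, 0, -1, dtype=torch.float32).reshape(b, topk) / (b * topk)
topk_index = torch.arange(b * topk, dtype=torch.int64).reshape(b, topk)

print(topk_p)
# tensor([[1.0000, 0.8750, 0.7500, 0.6250],
#         [0.5000, 0.3750, 0.2500, 0.1250]])
print(topk_index)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])
```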
python/sglang/srt/disaggregation/mooncake/conn.py

@@ -264,10 +264,12 @@ class MooncakeKVManager(CommonKVManager):
         layers_params = None

         # pp is not supported on the decode side yet
+        start_layer = self.kv_args.prefill_start_layer
+        end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
         if self.is_mla_backend:
-            src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
-                self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
-            )
+            src_kv_ptrs = self.kv_args.kv_data_ptrs
+            layers_per_pp_stage = len(src_kv_ptrs)
+            dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
             kv_item_len = self.kv_args.kv_item_lens[0]
             layers_params = [
                 (

@@ -275,12 +277,18 @@ class MooncakeKVManager(CommonKVManager):
                     dst_kv_ptrs[layer_id],
                     kv_item_len,
                 )
-                for layer_id in range(layers_current_pp_stage)
+                for layer_id in range(layers_per_pp_stage)
             ]
         else:
-            src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
-                self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
-            )
+            num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+            dst_num_total_layers = num_kv_layers * self.pp_size
+            src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+            src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+            layers_per_pp_stage = len(src_k_ptrs)
+            dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            dst_v_ptrs = dst_kv_ptrs[
+                dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
+            ]
             kv_item_len = self.kv_args.kv_item_lens[0]
             layers_params = [
                 (

@@ -288,14 +296,14 @@ class MooncakeKVManager(CommonKVManager):
                     dst_k_ptrs[layer_id],
                     kv_item_len,
                 )
-                for layer_id in range(layers_current_pp_stage)
+                for layer_id in range(layers_per_pp_stage)
             ] + [
                 (
                     src_v_ptrs[layer_id],
                     dst_v_ptrs[layer_id],
                     kv_item_len,
                 )
-                for layer_id in range(layers_current_pp_stage)
+                for layer_id in range(layers_per_pp_stage)
             ]

         assert layers_params is not None

@@ -393,9 +401,18 @@ class MooncakeKVManager(CommonKVManager):
             num_heads_to_send = dst_heads_per_rank
             dst_head_start_offset = 0

-        src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
-            self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
-        )
+        # pp is not supported on the decode side yet
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        dst_num_total_layers = num_kv_layers * self.pp_size
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        layers_per_pp_stage = len(src_k_ptrs)
+        start_layer = self.pp_rank * layers_per_pp_stage
+        end_layer = start_layer + layers_per_pp_stage
+        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+        dst_v_ptrs = dst_kv_ptrs[
+            dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
+        ]

         # Calculate precise byte offset and length for the sub-slice within the token
         src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send

@@ -421,7 +438,7 @@ class MooncakeKVManager(CommonKVManager):
                 dst_head_slice_offset,
                 heads_bytes_per_token_to_send,
             )
-            for layer_id in range(layers_current_pp_stage)
+            for layer_id in range(layers_per_pp_stage)
         ] + [
             (
                 src_v_ptrs[layer_id],

@@ -432,7 +449,7 @@ class MooncakeKVManager(CommonKVManager):
                 dst_head_slice_offset,
                 heads_bytes_per_token_to_send,
             )
-            for layer_id in range(layers_current_pp_stage)
+            for layer_id in range(layers_per_pp_stage)
         ]

         def process_layer_tp_aware(layer_params):
...
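Reviewer note: a sketch of the inlined MHA pointer bookkeeping above, with plain ints standing in for device addresses. The sender's `kv_data_ptrs` holds the K pointers of its pipeline stage followed by the V pointers, while the decode side holds all stages, so the K and V destinations are sliced at `start_layer` and `dst_num_total_layers + start_layer`. The concrete sizes below are illustrative.

```python
# Illustrative: 2 PP stages, 4 layers per stage on the prefill side.
pp_size, pp_rank = 2, 1
src_kv_ptrs = list(range(100, 108))  # [K0..K3, V0..V3] for this stage
dst_kv_ptrs = list(range(200, 216))  # [K0..K7, V0..V7] for the whole model

num_kv_layers = len(src_kv_ptrs) // 2
dst_num_total_layers = num_kv_layers * pp_size
src_k_ptrs, src_v_ptrs = src_kv_ptrs[:num_kv_layers], src_kv_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
start_layer = pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
    dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]

assert dst_k_ptrs == [204, 205, 206, 207]
assert dst_v_ptrs == [212, 213, 214, 215]
```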
python/sglang/srt/disaggregation/prefill.py

@@ -421,8 +421,6 @@ class SchedulerDisaggregationPrefillMixin:
                     last_hidden_index = (
                         hidden_state_offset + extend_input_len_per_req[i] - 1
                     )
-                    req.output_topk_p = batch.spec_info.topk_p[i]
-                    req.output_topk_index = batch.spec_info.topk_index[i]
                     if self.spec_algorithm.is_eagle3():
                         req.hidden_states_tensor = (
                             batch.spec_info.hidden_states[i].cpu().clone()
...
python/sglang/srt/disaggregation/utils.py

@@ -85,7 +85,7 @@ class MetadataBuffers:
         self,
         size: int,
         hidden_size: int,
-        hidden_states_dtype: torch.dtype,
+        dtype: torch.dtype,
         max_top_logprobs_num: int = 128,
         custom_mem_pool: torch.cuda.MemPool = None,
     ):

@@ -107,9 +107,7 @@ class MetadataBuffers:
             # We transfer the metadata of first output token to decode
             # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
             self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
-            self.cached_tokens = torch.zeros(
-                (size, 16), dtype=torch.int32, device=device
-            )
             self.output_token_logprobs_val = torch.zeros(
                 (size, 16), dtype=torch.float32, device=device
             )

@@ -122,49 +120,33 @@ class MetadataBuffers:
             self.output_top_logprobs_idx = torch.zeros(
                 (size, max_top_logprobs_num), dtype=torch.int32, device=device
             )
-            # For PD + spec decode
-            self.output_topk_p = torch.zeros((size, 16), dtype=torch.float32, device=device)
-            self.output_topk_index = torch.zeros((size, 16), dtype=torch.int64, device=device)
             self.output_hidden_states = torch.zeros(
-                (size, hidden_size), dtype=hidden_states_dtype, device=device
+                (size, hidden_size), dtype=dtype, device=device
             )

     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
-            self.cached_tokens.data_ptr(),
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
-            self.output_topk_p.data_ptr(),
-            self.output_topk_index.data_ptr(),
             self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
-            self.cached_tokens.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
-            self.output_topk_p.nbytes,
-            self.output_topk_index.nbytes,
             self.output_hidden_states.nbytes,
         ]
         item_lens = [
             self.output_ids[0].nbytes,
-            self.cached_tokens[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
-            self.output_topk_p[0].nbytes,
-            self.output_topk_index[0].nbytes,
             self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens

@@ -172,20 +154,16 @@ class MetadataBuffers:
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
-            self.cached_tokens[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
-            self.output_topk_p[idx],
-            self.output_topk_index[idx],
             self.output_hidden_states[idx],
         )

     def set_buf(self, req: Req):
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
-        self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (

@@ -208,17 +186,8 @@ class MetadataBuffers:
                 ] = torch.tensor(
                     req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                 )
-        # For PD + spec decode
+        # for PD + spec decode
         if req.hidden_states_tensor is not None:
-            # speculative_eagle_topk should not be greater than 16 currently
-            topk = req.output_topk_p.size(0)
-            self.output_topk_p[req.metadata_buffer_index, :topk].copy_(req.output_topk_p)
-            self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
-                req.output_topk_index
-            )
             self.output_hidden_states[req.metadata_buffer_index].copy_(
                 req.hidden_states_tensor
             )
...
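Reviewer note: a small sanity check (not part of the change) of the `(size, 16)` padding comment above — each per-request row of `output_ids` is 16 int32 values, i.e. 64 bytes, which meets the stated RDMA minimum.

```python
import torch

output_ids = torch.zeros((4, 16), dtype=torch.int32)
assert output_ids[0].nbytes == 64  # 16 * 4 bytes, >= the 64-byte RDMA minimum
```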
python/sglang/srt/entrypoints/engine.py

@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.12",
+            "0.3.11",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
...