change / sglang · Commits · 852a49c5

Commit 852a49c5, authored Sep 30, 2025 by maxiao

    adapt to dsv32 on dcu

Parent: 8f7453e3
Changes: 159
Showing 20 changed files with 295 additions and 1379 deletions (+295 / -1379)
Changed files:

    python/pyproject.toml                                                  +4    -4
    python/pyproject_other.toml                                            +5    -5
    python/sglang/bench_one_batch.py                                       +8    -6
    python/sglang/bench_one_batch_server.py                                +32   -305
    python/sglang/bench_serving.py                                         +1    -7
    python/sglang/environ.py                                               +0    -2
    python/sglang/global_config.py                                         +2    -2
    python/sglang/launch_server.py                                         +0    -14
    python/sglang/srt/configs/load_config.py                               +0    -8
    python/sglang/srt/configs/model_config.py                              +131  -123
    python/sglang/srt/configs/qwen3_vl.py                                  +0    -586
    python/sglang/srt/disaggregation/ascend/transfer_engine.py             +47   -9
    python/sglang/srt/disaggregation/common/conn.py                        +8    -27
    python/sglang/srt/disaggregation/decode.py                             +6    -21
    python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py     +0    -185
    python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py        +15   -23
    python/sglang/srt/disaggregation/mooncake/conn.py                      +31   -14
    python/sglang/srt/disaggregation/prefill.py                            +0    -2
    python/sglang/srt/disaggregation/utils.py                              +4    -35
    python/sglang/srt/entrypoints/engine.py                                +1    -1
python/pyproject.toml

@@ -57,7 +57,7 @@ dependencies = [
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.24",
-    "sgl-kernel==0.3.13",
+    "sgl-kernel==0.3.11",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",

@@ -67,7 +67,7 @@ dependencies = [
     "tiktoken",
     "anthropic>=0.20.0",
     "torch_memory_saver==0.0.8",
-    "nvidia-cutlass-dsl==4.2.1",
+    "nvidia-cutlass-dsl==4.2.0",
 ]

 [project.optional-dependencies]

@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"]
     "srt/layers/moe/fused_moe_triton/configs/*/*.json",
     "srt/layers/quantization/configs/*.json",
     "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
-    "srt/speculative/cpp_ngram/*.cpp",
-    "srt/speculative/cpp_ngram/*.h",
+    "srt/speculative/cpp_lookahead/*.cpp",
+    "srt/speculative/cpp_lookahead/*.h",
 ]

 [tool.setuptools.packages.find]
python/pyproject_other.toml

@@ -65,23 +65,23 @@ tracing = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.13",
+    "sgl-kernel==0.3.11",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
+    "flashinfer_python==0.3.1",
 ]

 blackwell = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.13",
+    "sgl-kernel==0.3.11",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
-    "nvidia-cutlass-dsl==4.2.1",
+    "flashinfer_python==0.3.1",
+    "nvidia-cutlass-dsl==4.2.0",
 ]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
python/sglang/bench_one_batch.py

@@ -443,9 +443,11 @@ def latency_test_run_once(
     if profile:
         profiler.stop()
-        trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
-        _save_profile_trace_results(profiler, trace_filename)
-        rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, profile_filename)
+        rank_print(
+            f"torch profiler chrome trace for prefill saved to {profile_filename}"
+        )

     # Decode
     decode_latencies = []

@@ -477,10 +479,10 @@ def latency_test_run_once(
         if profile and i == output_len / 2:
             profiler.stop()
-            trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
-            _save_profile_trace_results(profiler, trace_filename)
+            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, profile_filename)
             rank_print(
-                f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
+                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
             )

     # Record decode timing from 2nd output
python/sglang/bench_one_batch_server.py

@@ -9,7 +9,6 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --
 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
-python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
 """

 import argparse

@@ -20,17 +19,12 @@ import multiprocessing
 import os
 import random
 import time
-from typing import List, Optional, Tuple
+from typing import List, Tuple

 import numpy as np
 import requests
-from pydantic import BaseModel

-from sglang.bench_serving import (
-    get_tokenizer,
-    sample_mmmu_requests,
-    sample_random_requests,
-)
+from sglang.bench_serving import get_tokenizer, sample_random_requests
 from sglang.profiler import run_profile
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs

@@ -38,108 +32,6 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
 from sglang.test.test_utils import is_in_ci, write_github_step_summary


-class ProfileLinks(BaseModel):
-    """Pydantic model for profile trace links."""
-
-    extend: Optional[str] = None
-    decode: Optional[str] = None
-
-
-class BenchmarkResult(BaseModel):
-    """Pydantic model for benchmark results table data, for a single isl and osl"""
-
-    model_path: str
-    run_name: str
-    batch_size: int
-    input_len: int
-    output_len: int
-    latency: float
-    ttft: float
-    input_throughput: float
-    output_throughput: float
-    overall_throughput: float
-    last_gen_throughput: float
-    acc_length: Optional[float] = None
-    profile_links: Optional[ProfileLinks] = None
-
-    @staticmethod
-    def help_str() -> str:
-        return f"""
-Note: To view the traces through perfetto-ui, please:
-1. open with Google Chrome
-2. allow popup
-"""
-
-    def to_markdown_row(
-        self, trace_dir, base_url: str = "", relay_base: str = ""
-    ) -> str:
-        """Convert this benchmark result to a markdown table row."""
-        # Calculate costs (assuming H100 pricing for now)
-        hourly_cost_per_gpu = 2  # $2/hour for one H100
-        hourly_cost = hourly_cost_per_gpu * 1  # Assuming tp_size = 1 for simplicity
-        input_util = 0.7
-        accept_length = (
-            round(self.acc_length, 2) if self.acc_length is not None else "n/a"
-        )
-        itl = 1 / (self.output_throughput / self.batch_size) * 1000
-        input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
-        output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost
-
-        def get_perfetto_relay_link_from_trace_file(trace_file: str):
-            import os
-            from urllib.parse import quote
-
-            rel_path = os.path.relpath(trace_file, trace_dir)
-            raw_file_link = f"{base_url}/{rel_path}"
-            relay_link = (
-                f"{relay_base}?src={quote(raw_file_link, safe='')}"
-                if relay_base and quote
-                else raw_file_link
-            )
-            return relay_link
-
-        # Handle profile links
-        profile_link = "NA | NA"
-        if self.profile_links:
-            if self.profile_links.extend or self.profile_links.decode:
-                # Create a combined link or use the first available one
-                trace_files = [self.profile_links.extend, self.profile_links.decode]
-                trace_files_relay_links = [
-                    f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
-                    for trace_file in trace_files
-                ]
-                profile_link = " | ".join(trace_files_relay_links)
-
-        # Build the row
-        return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"
-
-    @classmethod
-    def generate_markdown_report(
-        cls, trace_dir, results: List["BenchmarkResult"]
-    ) -> str:
-        """Generate a markdown report from a list of BenchmarkResult object from a single run."""
-        import os
-
-        summary = f"### {results[0].model_path}\n"
-
-        # summary += (
-        #     f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
-        # )
-
-        summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
-        summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"
-
-        # all results should share the same isl & osl
-        for result in results:
-            base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
-            relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
-            relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
-            # base_url = "https://github.com/sgl-project/ci-data/traces"
-            summary += result.to_markdown_row(trace_dir, base_url, relay_base)
-
-        return summary
-
-
 @dataclasses.dataclass
 class BenchArgs:
     run_name: str = "default"

@@ -158,12 +50,8 @@ class BenchArgs:
     profile: bool = False
     profile_steps: int = 3
     profile_by_stage: bool = False
-    profile_filename_prefix: str = None
-    append_to_github_summary: bool = True
     dataset_path: str = ""
     parallel_batch: bool = False
-    dataset_name: str = "random"
-    output_path: Optional[str] = None

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):

@@ -179,13 +67,6 @@ class BenchArgs:
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
-        parser.add_argument(
-            "--dataset-name",
-            type=str,
-            default=BenchArgs.dataset_name,
-            choices=["mmmu", "random"],
-            help="Name of the dataset to benchmark on.",
-        )
         parser.add_argument("--return-logprob", action="store_true")
         parser.add_argument(
             "--client-stream-interval",

@@ -215,36 +96,14 @@ class BenchArgs:
             help="Path to the dataset.",
         )
         parser.add_argument("--parallel-batch", action="store_true")
-        parser.add_argument(
-            "--profile-filename-prefix",
-            type=str,
-            default=BenchArgs.profile_filename_prefix,
-        )
-        parser.add_argument(
-            "--no-append-to-github-summary",
-            action="store_false",
-            dest="append_to_github_summary",
-            help="Disable appending the output of this run to github ci summary",
-        )
-        parser.add_argument(
-            "--output-path",
-            type=str,
-            default=BenchArgs.output_path,
-            help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         # use the default value's type to cast the args into correct types.
         attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
-        kwargs = {}
-        for attr, attr_type in attrs:
-            val = getattr(args, attr)
-            if attr_type is type(None):
-                kwargs[attr] = val
-            else:
-                kwargs[attr] = attr_type(val)
-        return cls(**kwargs)
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )


 def launch_server_internal(server_args):

@@ -289,35 +148,23 @@ def run_one_case(
     run_name: str,
     result_filename: str,
     tokenizer,
-    dataset_name="",
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
-    profile_filename_prefix: str = None,
     dataset_path: str = "",
     parallel_batch: bool = False,
 ):
     requests.post(url + "/flush_cache")
-    # TODO: reuse bench_serving.get_dataset ?
-    if dataset_name == "mmmu":
-        input_requests = sample_mmmu_requests(
-            num_requests=batch_size,
-            tokenizer=tokenizer,
-            fixed_output_len=output_len,
-            apply_chat_template=True,
-            random_sample=False,
-        )
-    elif dataset_name == "random":
-        input_requests = sample_random_requests(
-            input_len=input_len,
-            output_len=output_len,
-            num_prompts=batch_size,
-            range_ratio=1.0,
-            tokenizer=tokenizer,
-            dataset_path=dataset_path,
-            random_sample=True,
-            return_text=False,
-        )
+    input_requests = sample_random_requests(
+        input_len=input_len,
+        output_len=output_len,
+        num_prompts=batch_size,
+        range_ratio=1.0,
+        tokenizer=tokenizer,
+        dataset_path=dataset_path,
+        random_sample=True,
+        return_text=False,
+    )

     use_structured_outputs = False
     if use_structured_outputs:

@@ -334,48 +181,26 @@ def run_one_case(
     profile_link = None
     if profile:
-        output_dir, profile_name = None, None
-        if profile_filename_prefix:
-            output_dir = os.path.dirname(profile_filename_prefix)
-            profile_name = os.path.basename(profile_filename_prefix)
-
         profile_link: str = run_profile(
             url,
-            profile_steps, ["CPU", "GPU"], output_dir, profile_name, profile_by_stage,
+            profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
         )

     tic = time.perf_counter()
-
-    payload = {
-        "sampling_params": {
-            "temperature": temperature,
-            "max_new_tokens": output_len,
-            "ignore_eos": True,
-            "json_schema": json_schema,
-            "stream_interval": stream_interval,
-        },
-        "return_logprob": return_logprob,
-        "stream": True,
-        **({"parallel_batch": parallel_batch} if parallel_batch else {}),
-    }
-    if dataset_name == "mmmu":
-        # vlm
-        input_ids = []
-        for input_req in input_requests:
-            input_ids += [tokenizer.encode(input_req.prompt)]
-        payload["image_data"] = [req.image_data for req in input_requests]
-    else:
-        input_ids = [req.prompt for req in input_requests]
-    payload["input_ids"] = input_ids
-
     response = requests.post(
         url + "/generate",
-        json=payload,
+        json={
+            "input_ids": [req.prompt for req in input_requests],
+            "sampling_params": {
+                "temperature": temperature,
+                "max_new_tokens": output_len,
+                "ignore_eos": True,
+                "json_schema": json_schema,
+                "stream_interval": stream_interval,
+            },
+            "return_logprob": return_logprob,
+            "stream": True,
+            **({"parallel_batch": parallel_batch} if parallel_batch else {}),
+        },
         stream=True,
     )

@@ -439,100 +264,10 @@ def run_one_case(
         overall_throughput,
         last_gen_throughput,
         acc_length,
-        profile_link,
+        profile_link if profile else None,
     )


-def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
-    """Save benchmark results as JSON using Pydantic models."""
-    json_results = []
-
-    # Generate all parameter combinations to match with results
-    param_combinations = list(
-        itertools.product(
-            bench_args.batch_size, bench_args.input_len, bench_args.output_len
-        )
-    )
-
-    for i, (
-        batch_size,
-        latency,
-        ttft,
-        input_throughput,
-        output_throughput,
-        overall_throughput,
-        last_gen_throughput,
-        acc_length,
-        profile_link,
-    ) in enumerate(result):
-        # Get the corresponding parameters for this result
-        bs, input_len, output_len = param_combinations[i]
-
-        # Parse profile links if available
-        profile_links = None
-        if profile_link:
-            profile_links = parse_profile_links(
-                profile_link, batch_size, input_len, output_len
-            )
-
-        benchmark_result = BenchmarkResult(
-            model_path=model,
-            run_name=bench_args.run_name,
-            batch_size=batch_size,
-            input_len=input_len,
-            output_len=output_len,
-            latency=latency,
-            ttft=ttft,
-            input_throughput=input_throughput,
-            output_throughput=output_throughput,
-            overall_throughput=overall_throughput,
-            last_gen_throughput=last_gen_throughput,
-            acc_length=acc_length,
-            profile_links=profile_links,
-        )
-        json_results.append(benchmark_result.model_dump())
-
-    # Save to JSON file
-    with open(bench_args.output_path, "w", encoding="utf-8") as f:
-        json.dump(json_results, f, indent=2, ensure_ascii=False)
-    print(f"Results saved as JSON to {bench_args.output_path}")
-
-
-def parse_profile_links(
-    profile_dir: str, batch_size: int, input_len: int, output_len: int
-) -> Optional[ProfileLinks]:
-    """Parse profile directory to extract extend and decode trace file links."""
-    if not profile_dir or not os.path.exists(profile_dir):
-        return None
-
-    extend_link = None
-    decode_link = None
-
-    # Look for extend/prefill trace files
-    for file in os.listdir(profile_dir):
-        if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
-            if "extend" in file.lower() or "prefill" in file.lower():
-                extend_link = os.path.join(profile_dir, file)
-            elif "decode" in file.lower():
-                decode_link = os.path.join(profile_dir, file)
-
-    # If no specific extend/decode files found, try to find files with batch/input/output info
-    if not extend_link or not decode_link:
-        for file in os.listdir(profile_dir):
-            if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
-                if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
-                    if "prefill" in file.lower() or "extend" in file.lower():
-                        extend_link = os.path.join(profile_dir, file)
-                    elif "decode" in file.lower():
-                        decode_link = os.path.join(profile_dir, file)
-
-    if extend_link or decode_link:
-        return ProfileLinks(extend=extend_link, decode=decode_link)
-
-    return None
-
-
 def get_report_summary(
     result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
 ):

@@ -623,7 +358,6 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                     return_logprob=bench_args.return_logprob,
                     stream_interval=bench_args.client_stream_interval,
                     input_len_step_percentage=bench_args.input_len_step_percentage,
-                    dataset_name=bench_args.dataset_name,
                     run_name="",
                     result_filename="",
                     tokenizer=tokenizer,

@@ -650,12 +384,10 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
-                dataset_name=bench_args.dataset_name,
                 result_filename=bench_args.result_filename,
                 tokenizer=tokenizer,
                 dataset_path=bench_args.dataset_path,
                 parallel_batch=bench_args.parallel_batch,
-                profile_filename_prefix=bench_args.profile_filename_prefix,
             )
         )

@@ -678,13 +410,11 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                     run_name=bench_args.run_name,
                     result_filename=bench_args.result_filename,
                     tokenizer=tokenizer,
-                    dataset_name=bench_args.dataset_name,
                     profile=bench_args.profile,
                     profile_steps=bench_args.profile_steps,
                     profile_by_stage=bench_args.profile_by_stage,
                     dataset_path=bench_args.dataset_path,
                     parallel_batch=bench_args.parallel_batch,
-                    profile_filename_prefix=bench_args.profile_filename_prefix,
                 )[-1],
             )

@@ -697,16 +427,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     print(f"\nResults are saved to {bench_args.result_filename}")

-    # Save results as JSON if output_path is specified
-    if bench_args.output_path:
-        save_results_as_json(result, bench_args, model=server_args.model_path)
-
     if not bench_args.show_report:
         return

     summary = get_report_summary(result, server_args, bench_args)
+    print(summary)

-    if is_in_ci() and bench_args.append_to_github_summary:
+    if is_in_ci():
         write_github_step_summary(summary)
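The new from_cli_args above keeps the pattern of casting each parsed value with the type of the corresponding field's default; the removed loop additionally special-cased fields whose default is None (where type(None) cannot be used as a cast). A small self-contained sketch of the retained pattern, using a hypothetical MiniArgs dataclass rather than the real BenchArgs:

# Sketch only: MiniArgs is a made-up stand-in for BenchArgs.
import argparse
import dataclasses

@dataclasses.dataclass
class MiniArgs:
    run_name: str = "default"
    batch_size: int = 1

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Cast each parsed value with the type of the field's default value.
        attrs = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
        return cls(**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs})

parser = argparse.ArgumentParser()
parser.add_argument("--run-name", dest="run_name", default=MiniArgs.run_name)
parser.add_argument("--batch-size", dest="batch_size", default=MiniArgs.batch_size)
print(MiniArgs.from_cli_args(parser.parse_args(["--batch-size", "8"])))
# MiniArgs(run_name='default', batch_size=8)
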
python/sglang/bench_serving.py

@@ -208,10 +208,6 @@ async def async_request_openai_completions(
         "ignore_eos": not args.disable_ignore_eos,
         **request_func_input.extra_request_body,
     }
-
-    if request_func_input.image_data:
-        payload.update({"image_data": request_func_input.image_data})
-
     headers = get_auth_headers()

     output = RequestFuncOutput.init_new(request_func_input)

@@ -1763,9 +1759,7 @@ async def benchmark(
         pbar.close()

     if "sglang" in backend:
-        server_info = requests.get(
-            base_url + "/get_server_info", headers=get_auth_headers()
-        )
+        server_info = requests.get(base_url + "/get_server_info")
         if server_info.status_code == 200:
             server_info_json = server_info.json()
             if "decode" in server_info_json:
python/sglang/srt/environ.py → python/sglang/environ.py

@@ -124,8 +124,6 @@ class Envs:
     SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
     SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
     SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
-    SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
-    SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")

     # Model Parallel
     SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
python/sglang/global_config.py

@@ -37,8 +37,8 @@ class GlobalConfig:
         )

         # Runtime constants: others
         self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = int(
-            os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
-        )
+        self.flashinfer_workspace_size = os.environ.get(
+            "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
+        )

         # Output tokenization configs
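One behavioural note on the hunk above: os.environ.get returns the raw string whenever FLASHINFER_WORKSPACE_SIZE is set, so without the int(...) wrapper the attribute can be an int (the default) or a str (the override), and any caller doing arithmetic on it would have to cast. A minimal sketch of that difference, outside of sglang:

# Sketch only; not part of the commit.
import os

os.environ["FLASHINFER_WORKSPACE_SIZE"] = "536870912"

old_style = int(os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024))
new_style = os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)

print(type(old_style), old_style)  # <class 'int'> 536870912
print(type(new_style), new_style)  # <class 'str'> 536870912
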
python/sglang/launch_server.py

@@ -7,23 +7,9 @@ from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_process_tree

-MOVE_ENVS_WARN = """
-########################################################################
-# For contributors and developers:                                     #
-# Please move environment variable definitions to sglang.srt.environ   #
-# using the following pattern:                                         #
-#     SGLANG_XXX = EnvBool(False)                                      #
-#                                                                      #
-########################################################################
-"""
-
 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])

-    from sglang.srt.server_args import print_deprecated_warning
-
-    print_deprecated_warning(MOVE_ENVS_WARN)
-
     try:
         launch_server(server_args)
     finally:
python/sglang/srt/configs/load_config.py

@@ -24,8 +24,6 @@ class LoadFormat(str, enum.Enum):
     JAX = "jax"
     REMOTE = "remote"
     REMOTE_INSTANCE = "remote_instance"
-    RDMA = "rdma"
-    LOCAL_CACHED = "local_cached"


 @dataclass

@@ -49,7 +47,6 @@ class LoadConfig:
             checkpoints.
         decryption_key_file: If set, decrypts the output files with a password read
             from this file (after PBKDF2).
-        decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
     """

     load_format: Union[str, LoadFormat] = LoadFormat.AUTO

@@ -57,11 +54,6 @@ class LoadConfig:
     model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
     ignore_patterns: Optional[Union[List[str], str]] = None
     decryption_key_file: Optional[str] = None
-    decrypt_max_concurrency: int = -1
-    tp_rank: Optional[int] = None
-    remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
-    remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
-    remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None

     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
python/sglang/srt/configs/model_config.py

@@ -31,7 +31,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, is_hip, retry
+from sglang.srt.utils import get_bool_env_var, is_hip
 from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)

@@ -48,6 +48,30 @@ class ModelImpl(str, Enum):
     TRANSFORMERS = "transformers"


+def is_deepseek_nsa(config: PretrainedConfig) -> bool:
+    return (
+        config.architectures is not None
+        and config.architectures[0]
+        in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
+        and getattr(config, "index_topk", None) is not None
+    )
+
+
+def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
+    assert is_deepseek_nsa(config)
+    return config.index_head_dim
+
+
+def get_nsa_index_topk(config: PretrainedConfig) -> int:
+    assert is_deepseek_nsa(config)
+    return config.index_topk
+
+
+def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
+    assert is_deepseek_nsa(config)
+    return config.index_n_heads
+
+
 class ModelConfig:
     def __init__(
         self,

@@ -64,20 +88,35 @@ class ModelConfig:
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+        tp_rank: Optional[int] = None,
+        remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
+        remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
+        remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None,
     ) -> None:
         # Parse args
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
+        self.is_draft_model = is_draft_model
         self.model_impl = model_impl
+        self.tp_rank = tp_rank
+        self.remote_instance_weight_loader_seed_instance_ip = (
+            remote_instance_weight_loader_seed_instance_ip
+        )
+        self.remote_instance_weight_loader_seed_instance_service_port = (
+            remote_instance_weight_loader_seed_instance_service_port
+        )
+        self.remote_instance_weight_loader_send_weights_group_ports = (
+            remote_instance_weight_loader_send_weights_group_ports
+        )

-        # Get hf config
-        self._maybe_pull_model_tokenizer_from_remote()
+        self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
         kwargs = {}
         if override_config_file and override_config_file.strip():
             kwargs["_configuration_file"] = override_config_file.strip()
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=trust_remote_code,

@@ -85,7 +124,7 @@ class ModelConfig:
             model_override_args=self.model_override_args,
             **kwargs,
         )
-        self.hf_text_config = get_hf_text_config(self.hf_config)
         self.hf_generation_config = get_generation_config(
             self.model_path,
             trust_remote_code=trust_remote_code,

@@ -93,25 +132,7 @@ class ModelConfig:
             **kwargs,
         )

-        # Set enable_multimodal
-        if enable_multimodal is None:
-            mm_disabled_models = [
-                "Gemma3ForConditionalGeneration",
-                "Llama4ForConditionalGeneration",
-                "Step3VLForConditionalGeneration",
-            ]
-            if self.hf_config.architectures[0] in mm_disabled_models:
-                enable_multimodal = False
-                logger.info(
-                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
-                )
-            else:
-                enable_multimodal = True
-
-        # Config draft model
-        self._config_draft_model()
-
-        # Check model type
+        self.hf_text_config = get_hf_text_config(self.hf_config)
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
         )

@@ -127,70 +148,20 @@ class ModelConfig:
                 self.hf_config.architectures, self.hf_text_config.num_hidden_layers
             )
         )
-        self.is_generation = is_generation_model(
-            self.hf_config.architectures, is_embedding
-        )
-        self.is_multimodal = enable_multimodal and is_multimodal_model(
-            self.hf_config.architectures
-        )
-        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
-            self.hf_config.architectures
-        )
-        self.is_image_gen = enable_multimodal and is_image_gen_model(
-            self.hf_config.architectures
-        )
-        self.is_audio_model = enable_multimodal and is_audio_model(
-            self.hf_config.architectures
-        )
-        self.is_multimodal_chunked_prefill_supported = (
-            enable_multimodal
-            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
-        )
-        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
-        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-
-        # Derive context length and model shapes
-        self._derive_context_length(context_length)
-        self._derive_model_shapes()
-
-        # Verify quantization
-        self._verify_quantization()
-
-        # Verify dual-chunk attention config
-        self._verify_dual_chunk_attention_config()
-
-        # Cache attributes
-        self.hf_eos_token_id = self._get_hf_eos_token_id()
-
-        # multimodal
-        self.image_token_id = getattr(
-            self.hf_config, "image_token_id", None
-        ) or getattr(self.hf_config, "image_token_index", None)
-
-    @staticmethod
-    def from_server_args(
-        server_args: ServerArgs,
-        model_path: str = None,
-        model_revision: str = None,
-        **kwargs,
-    ):
-        return ModelConfig(
-            model_path=model_path or server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=model_revision or server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
-            model_impl=server_args.model_impl,
-            **kwargs,
-        )
-
-    def _config_draft_model(self):
-        is_draft_model = self.is_draft_model
+        if enable_multimodal is None:
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+                "Step3VLForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
+                enable_multimodal = False
+                logger.info(
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
+                )
+            else:
+                enable_multimodal = True

         if (
             is_draft_model

@@ -225,10 +196,31 @@ class ModelConfig:
             self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
             self.hf_config.num_nextn_predict_layers = 1

-    def _derive_context_length(self, context_length: int):
-        is_draft_model = self.is_draft_model
-        derived_context_len = get_context_length(self.hf_text_config)
+        # Check model type
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, is_embedding
+        )
+        self.is_multimodal = enable_multimodal and is_multimodal_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_image_gen = enable_multimodal and is_image_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_audio_model = enable_multimodal and is_audio_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_chunked_prefill_supported = (
+            enable_multimodal
+            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
+        )
+        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+
+        # Derive context length
+        derived_context_len = get_context_length(self.hf_text_config)
         if context_length is not None:
             if context_length > derived_context_len:
                 reason = "Target model's" if is_draft_model else "User-specified"

@@ -242,11 +234,6 @@ class ModelConfig:
             ):
                 logger.warning(msg)
                 self.context_len = context_length
-                if is_draft_model:
-                    self.hf_text_config.max_position_embeddings = context_length
-                    logger.warning(
-                        f"Overriding the draft model's max_position_embeddings to {context_length}."
-                    )
             else:
                 raise ValueError(
                     f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"

@@ -256,10 +243,6 @@ class ModelConfig:
         else:
             self.context_len = derived_context_len
-        # Transfer context_len to HuggingFace config so models can access it
-        self.hf_config.context_len = self.context_len
-
-    def _derive_model_shapes(self):
         # Unify the config keys for hf_text_config
         self.head_dim = getattr(
             self.hf_text_config,

@@ -270,6 +253,7 @@ class ModelConfig:
         # FIXME: temporary special judge for MLA architecture
         if (
             "DeepseekV2ForCausalLM" in self.hf_config.architectures
+            or "DeepseekV32ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
             or "LongcatFlashForCausalLM" in self.hf_config.architectures

@@ -282,6 +266,11 @@ class ModelConfig:
             self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
             self.v_head_dim = self.hf_config.v_head_dim
+            self.index_head_dim = (
+                get_nsa_index_head_dim(self.hf_config)
+                if is_deepseek_nsa(self.hf_config)
+                else None
+            )

             # Handle rope scaling with yarn
             self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)

@@ -354,6 +343,45 @@ class ModelConfig:
         )
         self.vocab_size = self.hf_text_config.vocab_size

+        # Verify quantization
+        self._verify_quantization()
+
+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
+        # Cache attributes
+        self.hf_eos_token_id = self.get_hf_eos_token_id()
+
+        # multimodal
+        self.image_token_id = getattr(
+            self.hf_config, "image_token_id", None
+        ) or getattr(self.hf_config, "image_token_index", None)
+
+    @staticmethod
+    def from_server_args(
+        server_args: ServerArgs,
+        model_path: str = None,
+        model_revision: str = None,
+        **kwargs,
+    ):
+        return ModelConfig(
+            model_path=model_path or server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=model_revision or server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
+            model_impl=server_args.model_impl,
+            remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
+            remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
+            remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
+            **kwargs,
+        )
+
     def get_total_num_attention_heads(self) -> int:
         return self.num_attention_heads

@@ -454,31 +482,13 @@ class ModelConfig:
                 from huggingface_hub import HfApi

                 hf_api = HfApi()
-
-                def check_hf_quant_config():
-                    return hf_api.file_exists(self.model_path, "hf_quant_config.json")
-
-                # Retry HF API call up to 3 times
-                file_exists = retry(
-                    check_hf_quant_config,
-                    max_retry=2,
-                    initial_delay=1.0,
-                    max_delay=5.0,
-                )
-
-                if file_exists:
+                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                     quant_cfg = modelopt_quant_config
             except huggingface_hub.errors.OfflineModeIsEnabled:
                 logger.warning(
                     "Offline mode is enabled, skipping hf_quant_config.json check"
                 )
-            except Exception as e:
-                logger.warning(
-                    f"Failed to check hf_quant_config.json: {self.model_path} {e}"
-                )
+            except Exception as e:
+                pass
         elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
             quant_config_file = os.path.join(

@@ -606,7 +616,7 @@ class ModelConfig:
                     "sparse_attention_enabled"
                 ] = True

-    def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
+    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
             # it can be either int or list of int

@@ -626,7 +636,7 @@ class ModelConfig:
             eos_ids = eos_ids | generation_eos_ids
         return eos_ids

-    def _maybe_pull_model_tokenizer_from_remote(self) -> None:
+    def maybe_pull_model_tokenizer_from_remote(self) -> None:
         """
         Pull the model config files to a temporary
         directory in case of remote.

@@ -769,8 +779,6 @@ multimodal_model_archs = [
     "Qwen2AudioForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
-    "Qwen3VLForConditionalGeneration",
-    "Qwen3VLMoeForConditionalGeneration",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
     "InternS1ForConditionalGeneration",
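For orientation, the NSA helpers added at the top of model_config.py only inspect plain attributes on the HF config (architectures, index_topk, index_head_dim, index_n_heads). A minimal sketch of how they behave, using a SimpleNamespace stand-in and hypothetical values instead of a real transformers PretrainedConfig:

# Sketch only: cfg and its values are made up for illustration.
from types import SimpleNamespace

cfg = SimpleNamespace(
    architectures=["DeepseekV32ForCausalLM"],
    index_topk=2048,
    index_head_dim=128,
    index_n_heads=64,
)

def is_deepseek_nsa(config) -> bool:
    # Same condition as the helper introduced in the hunk above.
    return (
        config.architectures is not None
        and config.architectures[0]
        in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
        and getattr(config, "index_topk", None) is not None
    )

# Mirrors `self.index_head_dim = ... if is_deepseek_nsa(...) else None`:
index_head_dim = cfg.index_head_dim if is_deepseek_nsa(cfg) else None
print(index_head_dim)  # -> 128
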
python/sglang/srt/configs/qwen3_vl.py
deleted 100644 → 0 (586 deletions; diff collapsed)
python/sglang/srt/disaggregation/ascend/transfer_engine.py

@@ -2,9 +2,19 @@ import logging
 import os
 from typing import List, Optional

+import torch
+
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode

+try:
+    from mf_adapter import TransferEngine
+
+    import_error = None
+except ImportError as e:
+    import_error = e
+    pass
+
 logger = logging.getLogger(__name__)

@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
     def __init__(
         self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
     ):
-        try:
-            from mf_adapter import TransferEngine
-        except ImportError as e:
-            raise ImportError(
+        if import_error is not None:
+            logger.warning(
                 "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
-            ) from e
+            )
+            raise import_error

         self.engine = TransferEngine()
         self.hostname = hostname

@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
         self.initialize()

     def initialize(self) -> None:
+        from sglang.srt.layers.dp_attention import (
+            get_tensor_model_parallel_world_size,
+            get_tp_group,
+        )
+
+        transfer_protocol = self._get_transfer_protocol()
+        if transfer_protocol is None or transfer_protocol == "sdma":
+            trans_op_type = TransferEngine.TransDataOpType.SDMA
+        else:
+            trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
+            """with device RDMA for PD transfer"""
+            tmp_tensor = torch.zeros(1, device="npu")
+            output_tensor_list = [
+                torch.empty_like(tmp_tensor)
+                for _ in range(get_tensor_model_parallel_world_size())
+            ]
+            # Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
+            torch.distributed.all_gather(
+                output_tensor_list, tmp_tensor, group=get_tp_group().device_group
+            )
+
         """Initialize the ascend transfer instance."""
         ret_value = self.engine.initialize(
             self.store_url,
-            self.session_id, self.role, self.npu_id,
+            self.session_id, self.role, self.npu_id, trans_op_type
         )
         if ret_value != 0:
             logger.error("Ascend Transfer Engine initialization failed.")

@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
             ret_value = -1
         if ret_value != 0:
             logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
+
+    @staticmethod
+    def _get_transfer_protocol():
+        protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
+        allowed_protocols = {"device_rdma", "sdma"}
+        if protocol and protocol.lower() in allowed_protocols:
+            return protocol.lower()
+        else:
+            logger.warning(
+                "Invalid or no transfer protocol specified, using default protocol."
+            )
+            return None
\ No newline at end of file
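A self-contained sketch of the protocol selection added above: ASCEND_MF_TRANSFER_PROTOCOL chooses between "sdma" (the default) and "device_rdma", and anything else falls back to the default. The TransferEngine enum from mf_adapter is replaced here by plain strings, since mf_adapter is assumed to be unavailable outside an Ascend/DCU environment:

# Sketch only; not the file's exact code.
import os

def pick_transfer_protocol() -> str:
    protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
    if protocol and protocol.lower() in {"device_rdma", "sdma"}:
        return protocol.lower()
    return "sdma"  # default path, standing in for TransDataOpType.SDMA

os.environ["ASCEND_MF_TRANSFER_PROTOCOL"] = "device_rdma"
print(pick_transfer_protocol())  # -> "device_rdma"
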
python/sglang/srt/disaggregation/common/conn.py

@@ -95,6 +95,14 @@ class CommonKVManager(BaseKVManager):
     def _bind_server_socket(self):
         self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))

+    @cache
+    def _connect(self, endpoint: str, is_ipv6: bool = False):
+        socket = zmq.Context().socket(zmq.PUSH)
+        if is_ipv6:
+            socket.setsockopt(zmq.IPV6, 1)
+        socket.connect(endpoint)
+        return socket
+
     def _register_to_bootstrap(self):
         """Register KVSender to bootstrap server via HTTP POST."""
         if self.dist_init_addr:

@@ -148,33 +156,6 @@ class CommonKVManager(BaseKVManager):
         socket.connect(endpoint)
         return socket

-    def get_mha_kv_ptrs_with_pp(
-        self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
-    ) -> Tuple[List[int], List[int], List[int], List[int], int]:
-        # pp is not supported on the decode side yet
-        start_layer = self.kv_args.prefill_start_layer
-        num_kv_layers = len(src_kv_ptrs) // 2
-        end_layer = start_layer + num_kv_layers
-        dst_num_total_layers = len(dst_kv_ptrs) // 2
-        src_k_ptrs = src_kv_ptrs[:num_kv_layers]
-        src_v_ptrs = src_kv_ptrs[num_kv_layers:]
-        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
-        dst_v_ptrs = dst_kv_ptrs[
-            dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
-        ]
-        layers_current_pp_stage = len(src_k_ptrs)
-        return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
-
-    def get_mla_kv_ptrs_with_pp(
-        self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
-    ) -> Tuple[List[int], List[int], int]:
-        # pp is not supported on the decode side yet
-        start_layer = self.kv_args.prefill_start_layer
-        end_layer = start_layer + len(src_kv_ptrs)
-        sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
-        layers_current_pp_stage = len(src_kv_ptrs)
-        return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
-

 class CommonKVSender(BaseKVSender):
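The @cache-decorated _connect added above memoises one PUSH socket per endpoint, so repeated sends to the same peer reuse a single connection. A rough standalone sketch of the same idea using functools.lru_cache and pyzmq (the decorator and names here are assumptions for illustration, not the file's exact code):

# Sketch only.
import functools
import zmq

@functools.lru_cache(maxsize=None)
def connect(endpoint: str, is_ipv6: bool = False) -> zmq.Socket:
    # One PUSH socket is created per distinct endpoint and then reused.
    sock = zmq.Context.instance().socket(zmq.PUSH)
    if is_ipv6:
        sock.setsockopt(zmq.IPV6, 1)
    sock.connect(endpoint)
    return sock

s1 = connect("tcp://127.0.0.1:5555")
s2 = connect("tcp://127.0.0.1:5555")
assert s1 is s2  # cached: same socket object for the same endpoint
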
python/sglang/srt/disaggregation/decode.py

@@ -609,21 +609,15 @@ class DecodeTransferQueue:
                 idx = decode_req.metadata_buffer_index
                 (
                     output_id,
-                    cached_tokens,
                     output_token_logprobs_val,
                     output_token_logprobs_idx,
                     output_top_logprobs_val,
                     output_top_logprobs_idx,
-                    output_topk_p,
-                    output_topk_index,
                     output_hidden_states,
                 ) = self.metadata_buffers.get_buf(idx)
                 decode_req.req.output_ids.append(output_id[0].item())
-                decode_req.req.cached_tokens = cached_tokens[0].item()
                 if not self.spec_algorithm.is_none():
-                    decode_req.req.output_topk_p = output_topk_p
-                    decode_req.req.output_topk_index = output_topk_index
                     decode_req.req.hidden_states_tensor = output_hidden_states
                 if decode_req.req.return_logprob:
                     decode_req.req.output_token_logprobs_val.append(

@@ -713,15 +707,12 @@ class SchedulerDisaggregationDecodeMixin:
         elif prepare_mlp_sync_flag:
             batch, _ = self._prepare_idle_batch_and_run(None)

-        queue_size = (
-            len(self.waiting_queue)
-            + len(self.disagg_decode_transfer_queue.queue)
-            + len(self.disagg_decode_prealloc_queue.queue)
-        )
-        if self.server_args.disaggregation_decode_enable_offload_kvcache:
-            queue_size += len(self.decode_offload_manager.ongoing_offload)
-        if batch is None and queue_size == 0:
+        if batch is None and (
+            len(self.waiting_queue)
+            + len(self.disagg_decode_transfer_queue.queue)
+            + len(self.disagg_decode_prealloc_queue.queue)
+            == 0
+        ):
             self.self_check_during_idle()

         self.last_batch = batch

@@ -790,15 +781,12 @@ class SchedulerDisaggregationDecodeMixin:
             )
             self.process_batch_result(tmp_batch, tmp_result)

-        queue_size = (
-            len(self.waiting_queue)
-            + len(self.disagg_decode_transfer_queue.queue)
-            + len(self.disagg_decode_prealloc_queue.queue)
-        )
-        if self.server_args.disaggregation_decode_enable_offload_kvcache:
-            queue_size += len(self.decode_offload_manager.ongoing_offload)
-        if batch is None and queue_size == 0:
+        if batch is None and (
+            len(self.waiting_queue)
+            + len(self.disagg_decode_transfer_queue.queue)
+            + len(self.disagg_decode_prealloc_queue.queue)
+            == 0
+        ):
             self.self_check_during_idle()

         self.last_batch = batch

@@ -917,6 +905,3 @@ class SchedulerDisaggregationDecodeMixin:
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
         self.waiting_queue.extend(alloc_reqs)
-
-        if self.server_args.disaggregation_decode_enable_offload_kvcache:
-            self.decode_offload_manager.check_offload_progress()
python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
deleted
100644 → 0
View file @
8f7453e3
import logging
import threading
import time

import torch

from sglang.srt.server_args import ServerArgs
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)

logger = logging.getLogger(__name__)


class DecodeKVCacheOffloadManager:
    """Manage decode-side KV cache offloading lifecycle and operations."""

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_group: torch.distributed.ProcessGroup,
        tree_cache: BasePrefixCache,
        server_args: ServerArgs,
    ) -> None:
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = server_args.page_size
        self.server_args = server_args
        self.request_counter = 0
        self.tree_cache = tree_cache
        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
        if isinstance(kv_cache, MHATokenToKVPool):
            self.decode_host_mem_pool = MHATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        elif isinstance(kv_cache, MLATokenToKVPool):
            self.decode_host_mem_pool = MLATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        else:
            raise ValueError("Unsupported KV cache type for decode offload")

        self.tp_group = tp_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            mem_pool_host=self.decode_host_mem_pool,
            page_size=self.page_size,
            tp_group=tp_group,
            io_backend=server_args.hicache_io_backend,
            load_cache_event=threading.Event(),
            storage_backend=server_args.hicache_storage_backend,
            model_name=server_args.served_model_name,
            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
        )

        self.ongoing_offload = {}
        self.ongoing_backup = {}
        logger.info("Enable offload kv cache for decode side")

    def offload_kv_cache(self, req) -> bool:
        """Offload a finished request's KV cache to storage."""

        if self.cache_controller is None or self.decode_host_mem_pool is None:
            return False

        if req.req_pool_idx == -1:
            return False

        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        if token_indices.dim() == 0 or token_indices.numel() == 0:
            logger.debug(
                f"Request {req.rid} has invalid token_indices: {token_indices}"
            )
            return False

        tokens = req.origin_input_ids + req.output_ids
        aligned_len = (len(tokens) // self.page_size) * self.page_size
        if aligned_len == 0:
            return False

        token_indices = token_indices[:aligned_len]
        tokens = tokens[:aligned_len]

        # Asynchronously offload KV cache from device to host by cache controller
        self.request_counter += 1
        ack_id = self.request_counter
        host_indices = self.cache_controller.write(
            device_indices=token_indices.long(),
            node_id=ack_id,
        )
        if host_indices is None:
            logger.error(f"Not enough host memory for request {req.rid}")
            return False

        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
        return True

    def check_offload_progress(self):
        """Check the progress of offload from device to host and backup from host to storage."""
        cc = self.cache_controller

        qsizes = torch.tensor(
            [
                len(cc.ack_write_queue),
                cc.ack_backup_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_write, n_backup = map(int, qsizes.tolist())
        self._check_offload_progress(n_write)
        self._check_backup_progress(n_backup)

    def _check_offload_progress(self, finish_count):
        """Check the progress of offload from device to host."""
        while finish_count > 0:
            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
            finish_event.synchronize()
            for ack_id in ack_list:
                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)

                # Release device
                self.tree_cache.cache_finished_req(req)

                # Trigger async backup from host to storage by cache controller
                self._trigger_backup(req.rid, host_indices, tokens, start_time)
            finish_count -= 1

    def _check_backup_progress(self, finish_count):
        """Check the progress of backup from host to storage."""
        for _ in range(finish_count):
            storage_operation = self.cache_controller.ack_backup_queue.get()
            ack_id = storage_operation.id
            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)

            # Release host memory
            self.decode_host_mem_pool.free(host_indices)

            logger.debug(
                f"Finished backup request {req_id}, free host memory, len: {len(host_indices)}, "
                f"cost time: {time.time() - start_time:.2f} seconds."
            )

    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
        """Trigger async backup from host to storage by cache controller."""

        # Generate page hashes and write to storage
        page_hashes = self._compute_prefix_hash(tokens)
        ack_id = self.cache_controller.write_storage(
            host_indices,
            tokens,
            hash_value=page_hashes,
        )
        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)

    def _compute_prefix_hash(self, tokens):
        last_hash = ""
        page_hashes = []
        for offset in range(0, len(tokens), self.page_size):
            page_tokens = tokens[offset : offset + self.page_size]
            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
            page_hashes.append(last_hash)
        return page_hashes
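The deleted manager's _compute_prefix_hash chains page hashes: each page's key folds in the hash of the previous page, so a page is only addressable when its whole prefix matches. Below is a standalone sketch of that chaining, with hashlib.sha256 standing in for the controller's get_hash_str (the real helper lives on HiCacheController and may hash differently).

import hashlib
from typing import List


def chained_page_hashes(tokens: List[int], page_size: int) -> List[str]:
    # Hash page by page; seed each page's hash with the previous page's hash
    # so identical pages under different prefixes get different keys.
    last_hash = ""
    hashes = []
    for offset in range(0, len(tokens), page_size):
        page = tokens[offset : offset + page_size]
        payload = (last_hash + ",".join(map(str, page))).encode()
        last_hash = hashlib.sha256(payload).hexdigest()
        hashes.append(last_hash)
    return hashes


print(chained_page_hashes(list(range(8)), page_size=4))  # 2 pages -> 2 chained keys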
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
View file @
852a49c5
...
@@ -125,33 +125,25 @@ class ScheduleBatchDisaggregationDecodeMixin:
                req.grammar.finished = req.finished()
 
        self.output_ids = torch.tensor(self.output_ids, device=self.device)
 
-        # Simulate the eagle run.
-        if self.spec_algorithm.is_eagle():
+        # Simulate the eagle run. We add mock data to hidden states for the
+        # ease of implementation now meaning the first token will have acc rate
+        # of 0.
+        if not self.spec_algorithm.is_none():
            b = len(self.reqs)
-            topk = server_args.speculative_eagle_topk
-            topk_p = torch.stack(
-                [
-                    torch.as_tensor(
-                        req.output_topk_p[:topk],
-                        device=self.device,
-                        dtype=torch.float32,
-                    )
-                    for req in self.reqs
-                ],
-                dim=0,
-            )
-            topk_index = torch.stack(
-                [
-                    torch.as_tensor(
-                        req.output_topk_index[:topk],
-                        device=self.device,
-                        dtype=torch.int64,
-                    )
-                    for req in self.reqs
-                ],
-                dim=0,
-            )
+            topk_p = torch.arange(
+                b * server_args.speculative_eagle_topk,
+                0,
+                -1,
+                device=self.device,
+                dtype=torch.float32,
+            )
+            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
+            topk_p /= b * server_args.speculative_eagle_topk
+            topk_index = torch.arange(
+                b * server_args.speculative_eagle_topk,
+                device=self.device,
+                dtype=torch.int64,
+            )
+            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
            hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
            hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
...
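The replacement branch above no longer carries real top-k data over from prefill; it fabricates deterministic placeholders so the decode-side EAGLE path has tensors of the right shape. A small sketch of what those mock tensors look like for a batch of b requests (torch only; shapes mirror the hunk above, the concrete values are arbitrary by design):

import torch

b, topk = 2, 4  # batch size and speculative_eagle_topk

# Descending scores reshaped to (b, topk), then scaled into (0, 1].
topk_p = torch.arange(b * topk, 0, -1, dtype=torch.float32).reshape(b, topk)
topk_p /= b * topk

# Placeholder token indices of matching shape.
topk_index = torch.arange(b * topk, dtype=torch.int64).reshape(b, topk)

print(topk_p)      # [[1.000, 0.875, 0.750, 0.625], [0.500, 0.375, 0.250, 0.125]]
print(topk_index)  # [[0, 1, 2, 3], [4, 5, 6, 7]]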
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
852a49c5
...
@@ -264,10 +264,12 @@ class MooncakeKVManager(CommonKVManager):
        layers_params = None
 
        # pp is not supported on the decode side yet
+        start_layer = self.kv_args.prefill_start_layer
+        end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
        if self.is_mla_backend:
-            src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
-                self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
-            )
+            src_kv_ptrs = self.kv_args.kv_data_ptrs
+            layers_per_pp_stage = len(src_kv_ptrs)
+            dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
            kv_item_len = self.kv_args.kv_item_lens[0]
            layers_params = [
                (
...
@@ -275,12 +277,18 @@ class MooncakeKVManager(CommonKVManager):
                    dst_kv_ptrs[layer_id],
                    kv_item_len,
                )
-                for layer_id in range(layers_current_pp_stage)
+                for layer_id in range(layers_per_pp_stage)
            ]
        else:
-            src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
-                self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
-            )
+            num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+            dst_num_total_layers = num_kv_layers * self.pp_size
+            src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+            src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+            layers_per_pp_stage = len(src_k_ptrs)
+            dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            dst_v_ptrs = dst_kv_ptrs[
+                dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
+            ]
            kv_item_len = self.kv_args.kv_item_lens[0]
            layers_params = [
                (
...
@@ -288,14 +296,14 @@ class MooncakeKVManager(CommonKVManager):
                    dst_k_ptrs[layer_id],
                    kv_item_len,
                )
-                for layer_id in range(layers_current_pp_stage)
+                for layer_id in range(layers_per_pp_stage)
            ] + [
                (
                    src_v_ptrs[layer_id],
                    dst_v_ptrs[layer_id],
                    kv_item_len,
                )
-                for layer_id in range(layers_current_pp_stage)
+                for layer_id in range(layers_per_pp_stage)
            ]
 
        assert layers_params is not None
...
@@ -393,9 +401,18 @@ class MooncakeKVManager(CommonKVManager):
            num_heads_to_send = dst_heads_per_rank
            dst_head_start_offset = 0
 
-        src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
-            self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
-        )
+        # pp is not supported on the decode side yet
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        dst_num_total_layers = num_kv_layers * self.pp_size
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        layers_per_pp_stage = len(src_k_ptrs)
+        start_layer = self.pp_rank * layers_per_pp_stage
+        end_layer = start_layer + layers_per_pp_stage
+        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+        dst_v_ptrs = dst_kv_ptrs[
+            dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
+        ]
 
        # Calculate precise byte offset and length for the sub-slice within the token
        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
...
@@ -421,7 +438,7 @@ class MooncakeKVManager(CommonKVManager):
                dst_head_slice_offset,
                heads_bytes_per_token_to_send,
            )
-            for layer_id in range(layers_current_pp_stage)
+            for layer_id in range(layers_per_pp_stage)
        ] + [
            (
                src_v_ptrs[layer_id],
...
@@ -432,7 +449,7 @@ class MooncakeKVManager(CommonKVManager):
                dst_head_slice_offset,
                heads_bytes_per_token_to_send,
            )
-            for layer_id in range(layers_current_pp_stage)
+            for layer_id in range(layers_per_pp_stage)
        ]
 
        def process_layer_tp_aware(layer_params):
...
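In both branches above the transfer plan ends up as a flat list of (src_ptr, dst_ptr, item_len) tuples, K layers first and then V layers, with the destination pointers taken from the prefill stage's layer window. A pure-Python sketch of assembling that list follows; the pointer values and the helper name are made up for illustration, the real pointers come from kv_args.kv_data_ptrs.

from typing import List, Tuple


def build_layer_params(
    src_k: List[int], src_v: List[int],
    dst_k: List[int], dst_v: List[int],
    kv_item_len: int,
) -> List[Tuple[int, int, int]]:
    # One (src, dst, length) entry per K layer, followed by one per V layer.
    layers = len(src_k)
    return [(src_k[i], dst_k[i], kv_item_len) for i in range(layers)] + [
        (src_v[i], dst_v[i], kv_item_len) for i in range(layers)
    ]


# Two layers in this pp stage, 4 KiB of KV data per layer item.
params = build_layer_params([100, 101], [110, 111], [204, 205], [212, 213], 4096)
print(params)  # K pairs first, then V pairs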
python/sglang/srt/disaggregation/prefill.py
View file @
852a49c5
...
@@ -421,8 +421,6 @@ class SchedulerDisaggregationPrefillMixin:
                    last_hidden_index = (
                        hidden_state_offset + extend_input_len_per_req[i] - 1
                    )
-                    req.output_topk_p = batch.spec_info.topk_p[i]
-                    req.output_topk_index = batch.spec_info.topk_index[i]
                    if self.spec_algorithm.is_eagle3():
                        req.hidden_states_tensor = (
                            batch.spec_info.hidden_states[i].cpu().clone()
...
python/sglang/srt/disaggregation/utils.py
View file @
852a49c5
...
@@ -85,7 +85,7 @@ class MetadataBuffers:
        self,
        size: int,
        hidden_size: int,
-        hidden_states_dtype: torch.dtype,
+        dtype: torch.dtype,
        max_top_logprobs_num: int = 128,
        custom_mem_pool: torch.cuda.MemPool = None,
    ):
...
@@ -107,9 +107,7 @@ class MetadataBuffers:
        # We transfer the metadata of first output token to decode
        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
-        self.cached_tokens = torch.zeros(
-            (size, 16), dtype=torch.int32, device=device
-        )
        self.output_token_logprobs_val = torch.zeros(
            (size, 16), dtype=torch.float32, device=device
        )
...
@@ -122,49 +120,33 @@ class MetadataBuffers:
        self.output_top_logprobs_idx = torch.zeros(
            (size, max_top_logprobs_num), dtype=torch.int32, device=device
        )
-        # For PD + spec decode
-        self.output_topk_p = torch.zeros(
-            (size, 16), dtype=torch.float32, device=device
-        )
-        self.output_topk_index = torch.zeros(
-            (size, 16), dtype=torch.int64, device=device
-        )
        self.output_hidden_states = torch.zeros(
-            (size, hidden_size), dtype=hidden_states_dtype, device=device
+            (size, hidden_size), dtype=dtype, device=device
        )
 
    def get_buf_infos(self):
        ptrs = [
            self.output_ids.data_ptr(),
-            self.cached_tokens.data_ptr(),
            self.output_token_logprobs_val.data_ptr(),
            self.output_token_logprobs_idx.data_ptr(),
            self.output_top_logprobs_val.data_ptr(),
            self.output_top_logprobs_idx.data_ptr(),
-            self.output_topk_p.data_ptr(),
-            self.output_topk_index.data_ptr(),
            self.output_hidden_states.data_ptr(),
        ]
        data_lens = [
            self.output_ids.nbytes,
-            self.cached_tokens.nbytes,
            self.output_token_logprobs_val.nbytes,
            self.output_token_logprobs_idx.nbytes,
            self.output_top_logprobs_val.nbytes,
            self.output_top_logprobs_idx.nbytes,
-            self.output_topk_p.nbytes,
-            self.output_topk_index.nbytes,
            self.output_hidden_states.nbytes,
        ]
        item_lens = [
            self.output_ids[0].nbytes,
-            self.cached_tokens[0].nbytes,
            self.output_token_logprobs_val[0].nbytes,
            self.output_token_logprobs_idx[0].nbytes,
            self.output_top_logprobs_val[0].nbytes,
            self.output_top_logprobs_idx[0].nbytes,
-            self.output_topk_p[0].nbytes,
-            self.output_topk_index[0].nbytes,
            self.output_hidden_states[0].nbytes,
        ]
        return ptrs, data_lens, item_lens
...
@@ -172,20 +154,16 @@ class MetadataBuffers:
    def get_buf(self, idx: int):
        return (
            self.output_ids[idx],
-            self.cached_tokens[idx],
            self.output_token_logprobs_val[idx],
            self.output_token_logprobs_idx[idx],
            self.output_top_logprobs_val[idx],
            self.output_top_logprobs_idx[idx],
-            self.output_topk_p[idx],
-            self.output_topk_index[idx],
            self.output_hidden_states[idx],
        )
 
    def set_buf(self, req: Req):
        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
-        self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
        if req.return_logprob:
            if req.output_token_logprobs_val:  # not none or empty list
                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
...
@@ -208,17 +186,8 @@ class MetadataBuffers:
                ] = torch.tensor(
                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                )
-        # For PD + spec decode
+        # for PD + spec decode
        if req.hidden_states_tensor is not None:
-            # speculative_eagle_topk should not be greater than 16 currently
-            topk = req.output_topk_p.size(0)
-            self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
-                req.output_topk_p
-            )
-            self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
-                req.output_topk_index
-            )
            self.output_hidden_states[req.metadata_buffer_index].copy_(
                req.hidden_states_tensor
            )
...
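The buffers above keep their (size, 16) shape even after the cached_tokens and top-k fields are dropped: one int32 row of 16 elements is 64 bytes, which is the minimum RDMA transfer size the comment refers to. A quick check of that arithmetic (torch only; the value of size here is an arbitrary example):

import torch

size = 8  # number of metadata slots; illustrative only
output_ids = torch.zeros((size, 16), dtype=torch.int32)

# One slot (one row) is what gets transferred per request.
print(output_ids[0].nbytes)  # 16 * 4 bytes = 64 bytes, the RDMA minimum
print(output_ids.nbytes)     # size * 64 bytes for the whole buffer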
python/sglang/srt/entrypoints/engine.py
View file @
852a49c5
...
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
    if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
        assert_pkg_version(
            "sgl-kernel",
-            "0.3.12",
+            "0.3.11",
            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
        )
...
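The hunk above only lowers the minimum sgl-kernel version that the runtime check accepts. As a rough illustration of what such a guard does, here is a generic minimum-version check using importlib.metadata; it is a sketch, not sglang's actual assert_pkg_version helper, and the final call is just an example invocation.

from importlib.metadata import PackageNotFoundError, version


def check_min_version(pkg: str, min_version: str, hint: str) -> None:
    # Compare dotted numeric versions component by component.
    def as_tuple(v: str):
        return tuple(int(x) for x in v.split(".") if x.isdigit())

    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if as_tuple(installed) < as_tuple(min_version):
        raise RuntimeError(f"{pkg} {installed} < required {min_version}. {hint}")


check_min_version("pip", "1.0", "Please upgrade pip.")  # example invocation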