sglang / Commits / 22352d47

Unverified commit 22352d47, authored Jun 29, 2025 by Lianmin Zheng, committed by GitHub on Jun 29, 2025.

Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)

Co-authored-by: Kan Wu <wukanustc@gmail.com>
Parent: c5131f7a
Showing 20 changed files with 461 additions and 155 deletions (+461 / -155).
docs/backend/server_arguments.md                                  +1    -1
python/sglang/bench_one_batch_server.py                           +14   -2
python/sglang/bench_serving.py                                    +0    -1
python/sglang/srt/configs/internvl.py                             +2    -3
python/sglang/srt/entrypoints/http_server.py                      +23   -6
python/sglang/srt/layers/elementwise.py                           +76   -12
python/sglang/srt/layers/moe/router.py                            +60   -22
python/sglang/srt/managers/configure_logging.py                   +1    -1
python/sglang/srt/managers/io_struct.py                           +3    -3
python/sglang/srt/managers/mm_utils.py                            +52   -0
python/sglang/srt/managers/multimodal_processor.py                +0    -1
python/sglang/srt/managers/schedule_batch.py                      +5    -51
python/sglang/srt/managers/scheduler.py                           +36   -20
python/sglang/srt/managers/scheduler_output_processor_mixin.py    +8    -2
python/sglang/srt/managers/tokenizer_manager.py                   +124  -25
python/sglang/srt/mem_cache/memory_pool.py                        +3    -0
python/sglang/srt/model_executor/model_runner.py                  +5    -2
python/sglang/srt/model_loader/loader.py                          +38   -0
python/sglang/srt/models/mistral.py                               +1    -1
python/sglang/srt/server_args.py                                  +9    -2
docs/backend/server_arguments.md

@@ -116,7 +116,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--log-level` | The logging level of all loggers. | info |
 | `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None |
 | `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False |
-| `--log-requests-level` | 0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output. | 0 |
+| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 |
 | `--show-time-cost` | Show time cost of custom marks. | False |
 | `--enable-metrics` | Enable log prometheus metrics. | False |
 | `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None |
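For context, a hedged usage sketch of these logging flags (not part of the diff); the model path is a placeholder and only the flags listed in the table above are taken from this commit:

```python
# Launch a server that logs request metadata, sampling parameters, and partial
# input/output (level 2 under the new scheme). The model path is a placeholder.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder
        "--log-level", "info",
        "--log-requests",
        "--log-requests-level", "2",
    ],
    check=True,
)
```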
python/sglang/bench_one_batch_server.py

@@ -38,6 +38,7 @@ class BenchArgs:
     output_len: Tuple[int] = (16,)
     temperature: float = 0.0
     return_logprob: bool = False
+    client_stream_interval: int = 1
     input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
@@ -60,6 +61,11 @@ class BenchArgs:
     parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
     parser.add_argument("--return-logprob", action="store_true")
+    parser.add_argument(
+        "--client-stream-interval",
+        type=int,
+        default=BenchArgs.client_stream_interval,
+    )
     parser.add_argument(
         "--input-len-step-percentage",
         type=float,
@@ -120,6 +126,7 @@ def run_one_case(
     output_len: int,
     temperature: float,
     return_logprob: bool,
+    stream_interval: int,
     input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
@@ -168,6 +175,7 @@ def run_one_case(
             "max_new_tokens": output_len,
             "ignore_eos": True,
             "json_schema": json_schema,
+            "stream_interval": stream_interval,
         },
         "return_logprob": return_logprob,
         "stream": True,
@@ -245,8 +253,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)

-    tokenizer_id = server_args.tokenizer_path or server_args.model_path
-    tokenizer = get_tokenizer(tokenizer_id)
+    server_info = requests.get(base_url + "/get_server_info")
+    tokenizer_path = server_info.json()["tokenizer_path"]
+    tokenizer = get_tokenizer(tokenizer_path)

     # warmup
     if not bench_args.skip_warmup:
@@ -258,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         output_len=16,
         temperature=bench_args.temperature,
         return_logprob=bench_args.return_logprob,
+        stream_interval=bench_args.client_stream_interval,
         input_len_step_percentage=bench_args.input_len_step_percentage,
         run_name="",
         result_filename="",
@@ -280,6 +290,7 @@ and @@ -301,6 +312,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): (the same argument is added at two more run_one_case call sites)
         ol,
         temperature=bench_args.temperature,
         return_logprob=bench_args.return_logprob,
+        stream_interval=bench_args.client_stream_interval,
         input_len_step_percentage=bench_args.input_len_step_percentage,
         run_name=bench_args.run_name,
         result_filename=bench_args.result_filename,
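A usage sketch (not part of the diff) that exercises the new client-side streaming interval; only `--client-stream-interval` and `--return-logprob` come from this commit, and the remaining flags and the model path are assumptions about the script's usual interface:

```python
# Drive the modified benchmark script; flags other than --client-stream-interval
# and --return-logprob are assumed, and the model path is a placeholder.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.bench_one_batch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder
        "--batch-size", "1",                # assumed existing flag
        "--input-len", "1024",              # assumed existing flag
        "--output-len", "128",              # assumed existing flag
        "--client-stream-interval", "4",    # new in this commit
        "--return-logprob",
    ],
    check=True,
)
```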
python/sglang/bench_serving.py

@@ -1678,7 +1678,6 @@ def run_benchmark(args_: argparse.Namespace):
         if args.base_url
         else f"http://{args.host}:{args.port}/generate"
     )
-    args.apply_chat_template = True
 elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
     api_url = (
         f"{args.base_url}/v1/completions"
python/sglang/srt/configs/internvl.py

@@ -147,12 +147,11 @@ class InternLM2Config(PretrainedConfig):
         )
         if (
             rope_scaling_factor is None
-            or not isinstance(rope_scaling_factor, float)
-            or not isinstance(rope_scaling_factor, int)
+            or not isinstance(rope_scaling_factor, (float, int))
             or rope_scaling_factor < 1.0
         ):
             raise ValueError(
-                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor}"
+                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
             )
         if isinstance(rope_scaling_factor, int):
             rope_scaling_factor = float(rope_scaling_factor)
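A minimal sketch (not from the diff) of the relaxed check above: integer factors coming from a JSON config are now accepted and normalized to float, instead of being rejected by the old contradictory isinstance checks:

```python
# Mirrors the new validation logic shown in the diff above.
def normalize_rope_scaling_factor(rope_scaling_factor):
    if (
        rope_scaling_factor is None
        or not isinstance(rope_scaling_factor, (float, int))
        or rope_scaling_factor < 1.0
    ):
        raise ValueError(
            f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}"
        )
    return float(rope_scaling_factor)

assert normalize_rope_scaling_factor(2) == 2.0    # int is now accepted
assert normalize_rope_scaling_factor(1.5) == 1.5
```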
python/sglang/srt/entrypoints/http_server.py

@@ -126,8 +126,6 @@ def set_global_state(global_state: _GlobalState):
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args: ServerArgs = fast_api_app.server_args
-
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
         _global_state.tokenizer_manager, _global_state.template_manager
@@ -145,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
         _global_state.tokenizer_manager
     )

+    server_args: ServerArgs = fast_api_app.server_args
     if server_args.warmups is not None:
         await execute_warmups(
-            server_args.warmups.split(","), _global_state.tokenizer_manager
+            server_args.disaggregation_mode,
+            server_args.warmups.split(","),
+            _global_state.tokenizer_manager,
         )
         logger.info("Warmup ended")
@@ -280,13 +281,17 @@ async def get_model_info():
         "model_path": _global_state.tokenizer_manager.model_path,
         "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
         "is_generation": _global_state.tokenizer_manager.is_generation,
+        "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
     }
     return result


 @app.get("/get_server_info")
 async def get_server_info():
-    internal_states = await _global_state.tokenizer_manager.get_internal_state()
+    # Returns internal states per DP.
+    internal_states: List[Dict[Any, Any]] = (
+        await _global_state.tokenizer_manager.get_internal_state()
+    )
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
@@ -300,6 +305,8 @@ async def get_load():
     return await _global_state.tokenizer_manager.get_load()


+# example usage:
+# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
 @app.api_route("/set_internal_state", methods=["POST", "PUT"])
 async def set_internal_state(obj: SetInternalStateReq, request: Request):
     res = await _global_state.tokenizer_manager.set_internal_state(obj)
@@ -886,6 +893,15 @@ def launch_server(
     add_prometheus_middleware(app)
     enable_func_timer()

+    image_token_text = None
+    if (
+        tokenizer_manager.image_token_id is not None
+        and not server_args.skip_tokenizer_init
+    ):
+        image_token_text = tokenizer_manager.tokenizer.decode(
+            [tokenizer_manager.image_token_id]
+        )
+
     # Send a warmup request - we will create the thread launch it
     # in the lifespan after all other warmups have fired.
     warmup_thread = threading.Thread(
@@ -893,7 +909,7 @@ def launch_server(
         args=(
             server_args,
             pipe_finish_writer,
-            _global_state.tokenizer_manager.image_token_id,
+            image_token_text,
             launch_callback,
         ),
     )
@@ -1022,9 +1038,10 @@ def _wait_and_warmup(
         return

     # Debug print
-    # logger.info(f"{res.json()=}")
+    # logger.info(f"warmup request returns: {res.json()=}")

     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
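A hedged sketch (not part of the diff): the benchmark script above now reads the tokenizer path from `/get_server_info`; `tokenizer_path` comes from the ServerArgs dict this endpoint merges into its response, and any other key used here would be an assumption:

```python
# Query a running server for its ServerArgs-derived info; only "tokenizer_path"
# is relied on by the modified benchmark script.
import requests

info = requests.get("http://localhost:30000/get_server_info").json()
print(info["tokenizer_path"])
```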
python/sglang/srt/layers/elementwise.py

@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip
 _is_hip = is_hip()

 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
         )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
         return self.rmsnorm2.forward_native(residual), residual


+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
         out_scales = scales
         static_scale = True

+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
         ),
     }
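A usage sketch (not part of the diff) for the new `experts_combine_triton` helper, assuming a CUDA device: it sums `combine_k` MoE outputs per token, adds the MLP output, and scales by 1/sqrt(2):

```python
# Compare the fused Triton path against a plain PyTorch reference.
import torch
from sglang.srt.layers.elementwise import experts_combine_triton

bs, combine_k, hidden_dim = 4, 2, 4096
moe_out = torch.randn(bs, combine_k, hidden_dim, device="cuda", dtype=torch.bfloat16)
mlp_out = torch.randn(bs, hidden_dim, device="cuda", dtype=torch.bfloat16)

combined = experts_combine_triton(moe_out, mlp_out)      # shape (bs, hidden_dim)
ref = (moe_out.sum(dim=1) + mlp_out) / 2**0.5
assert torch.allclose(combined.float(), ref.float(), atol=2e-2)
```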
python/sglang/srt/layers/moe/router.py

-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
     topk_ids_ptr,  # output (bs, topk)
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     num_experts: tl.constexpr,
     topk: tl.constexpr,
     moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
     bottom = exped + 1
     logits_softcapped = top / bottom * moe_softcapping

+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
+        logits_softcapped = logits_softcapped + bias
+
     # topk
     # assert 1 <= topk <= num_experts
@@ -109,6 +116,7 @@ def fused_moe_router_impl(
     router_weight: torch.Tensor,
     topk: int,
     moe_softcapping: float,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@ def fused_moe_router_impl(
     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None

-    grid = lambda meta: (bs,)
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }

-    fused_moe_router_kernel[grid](
+    fused_moe_router_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
         topk_ids,
+        correction_bias,
+        is_correction_bias=is_correction_bias,
         num_experts=num_experts,
         topk=topk,
         moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
     topk_ids_ptr,  # output (bs, topk)
     bs,
     num_experts: tl.constexpr,
-    topk: tl.constexpr,  # only support topk == 1
+    topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
     K: tl.constexpr,
@@ -204,23 +212,51 @@ def fused_moe_router_large_bs_kernel(
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

     # 5. top1
-    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
-    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
+    cond_top1 = arange_block_size_n < num_experts
+    top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
     top1_v = tl.max(
-        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+        tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
     )
-    invsumexp = 1.0 / tl.sum(
-        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    top1_invsumexp = 1.0 / tl.sum(
+        tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
     )

-    # 6. store to output
-    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    topk_mask = offs_topk < bs
-    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
-    tl.store(
-        topk_weights_ptr + offs_topk,
-        invsumexp,
-        mask=topk_mask,
-    )
+    # 6. store top1 to output
+    offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
+    top1_mask = offs_top1 < bs * topk
+    tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
+    tl.store(
+        topk_weights_ptr + offs_top1,
+        top1_invsumexp,
+        mask=top1_mask,
+    )
+
+    # 7. handle topk == 2
+    if topk == 2:
+        cond_top2 = (arange_block_size_n < num_experts) and (
+            arange_block_size_n != top1[:, None]
+        )
+        top2 = tl.argmax(
+            tl.where(cond_top2, logits_softcapped, float("-inf")),
+            axis=1,
+            keep_dims=True,
+        )
+        top2_v = tl.sum(
+            logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
+        )
+        top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
+
+        # store top2
+        offs_top2 = (
+            pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
+        )
+        top2_mask = offs_top2 < bs * topk
+        tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
+        tl.store(
+            topk_weights_ptr + offs_top2,
+            top2_invsumexp,
+            mask=top2_mask,
+        )
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
     assert num_experts <= BLOCK_SIZE_N
     assert hidden_dim % BLOCK_SIZE_K == 0
-    assert topk == 1
+    assert topk <= 2

     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
     gating_output,
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert not renormalize
     assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
     BLOCK_SIZE_K = 256
     if (
         bs >= 512
-        and topk == 1
+        and topk <= 2
         and num_experts <= BLOCK_SIZE_N
         and hidden_dim % BLOCK_SIZE_K == 0
     ):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
         router_weight=gating_output,
         topk=topk,
         moe_softcapping=moe_softcapping,
+        correction_bias=correction_bias,
     )
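A reference sketch (not part of the diff) of what the small-batch router kernel computes after this change, written in plain PyTorch for readability; the Triton kernel fuses the matmul, tanh softcapping, optional bias add, and the unnormalized top-k softmax weights into one launch:

```python
# Plain-PyTorch mirror of the fused router math shown above (a readability aid,
# not the production path); reference_router is a hypothetical helper name.
import torch

def reference_router(x, router_weight, topk, moe_softcapping, correction_bias=None):
    logits = x.float() @ router_weight.float().T                  # (bs, num_experts)
    logits = torch.tanh(logits / moe_softcapping) * moe_softcapping
    if correction_bias is not None:
        logits = logits + correction_bias                          # bias added after softcapping
    probs = torch.softmax(logits, dim=-1)
    _, topk_ids = torch.topk(logits, k=topk, dim=-1)
    topk_weights = probs.gather(-1, topk_ids)                      # not renormalized
    return topk_weights, topk_ids
```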
python/sglang/srt/managers/configure_logging.py

@@ -28,7 +28,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     parser.add_argument("--log-requests", action="store_true")
-    parser.add_argument("--log-requests-level", type=int, default=2)
+    parser.add_argument("--log-requests-level", type=int, default=3)
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
    )
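A usage sketch (not part of the diff): this client script forwards the flags to a running server, and with the new scheme level 3 logs every input/output. All flags below appear in the diff; how the server applies them follows the script's existing behavior:

```python
# Reconfigure request logging on a running server via the helper script.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.srt.managers.configure_logging",
        "--url", "http://localhost:30000",
        "--log-requests",
        "--log-requests-level", "3",
        "--dump-requests-folder", "/tmp/sglang_request_dump",
    ],
    check=True,
)
```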
python/sglang/srt/managers/io_struct.py

@@ -516,9 +516,6 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False

-    def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
-
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -572,6 +569,9 @@ class EmbeddingReqInput:
             self.rid = uuid.uuid4().hex
         return self.rid

+    def contains_mm_input(self) -> bool:
+        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+
     def __getitem__(self, i):
         if self.is_cross_encoder_request:
             return EmbeddingReqInput(
python/sglang/srt/managers/mm_utils.py

@@ -2,12 +2,15 @@
 Multi-modality utils
 """

+import hashlib
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple

+import numpy as np
 import torch
 from torch import nn

+from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
@@ -678,3 +681,52 @@ def get_multimodal_data_bounds(
     # Convert valid pairs to tensor
     valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
+
+
+def data_hash(data) -> int:
+    hash_bytes = hashlib.sha256(data).digest()[:8]
+    return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+
+def tensor_hash(tensor_list) -> int:
+    """
+    hash a tensor or a tensor list
+    """
+    tensor = tensor_list
+    if isinstance(tensor_list, list):
+        tensor_list = flatten_nested_list(tensor_list)
+        tensor_list = [
+            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
+        ]
+        tensor = torch.concat(tensor_list)
+    if tensor.is_cuda:
+        return gpu_tensor_hash(tensor)
+    tensor = tensor.detach().contiguous()
+
+    if tensor.dtype == torch.bfloat16:
+        # memoryview() doesn't support PyTorch's BFloat16 dtype
+        tensor = tensor.float()
+
+    assert isinstance(tensor, torch.Tensor)
+    if tensor.is_cuda:
+        # TODO: improve this
+        tensor_cpu = tensor.cpu()
+    else:
+        tensor_cpu = tensor
+
+    mv = memoryview(tensor_cpu.numpy())
+    return data_hash(mv.tobytes())
+
+
+def hash_feature(f):
+    if isinstance(f, list):
+        if isinstance(f[0], torch.Tensor):
+            return tensor_hash(f)
+        return data_hash(tuple(flatten_nested_list(f)))
+    elif isinstance(f, np.ndarray):
+        arr = np.ascontiguousarray(f)
+        arr_bytes = arr.tobytes()
+        return data_hash(arr_bytes)
+    elif isinstance(f, torch.Tensor):
+        return tensor_hash([f])
+    return data_hash(f)
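A usage sketch (not part of the diff) for the hashing helpers that now live at module level in mm_utils: hash_feature dispatches on the feature type and returns a 64-bit integer digest:

```python
# Exercise the dispatch paths of hash_feature on CPU data.
import numpy as np
import torch
from sglang.srt.managers.mm_utils import hash_feature

h1 = hash_feature(torch.randn(3, 8))                          # single tensor
h2 = hash_feature([torch.randn(2, 8), torch.randn(4, 8)])     # list of tensors
h3 = hash_feature(np.zeros((2, 2), dtype=np.float32))         # numpy array
assert all(isinstance(h, int) for h in (h1, h2, h3))
```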
python/sglang/srt/managers/multimodal_processor.py

@@ -3,7 +3,6 @@ import importlib
 import inspect
 import logging
 import pkgutil
-from functools import lru_cache

 from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.server_args import ServerArgs
python/sglang/srt/managers/schedule_batch.py

@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
 import copy
 import dataclasses
-import hashlib
 import logging
 import threading
 from enum import Enum, auto
@@ -53,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
@@ -96,8 +94,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
-    "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
+    "speculative_accept_threshold_acc",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
@@ -180,7 +178,9 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or others.
+    One MultimodalDataItem contains all inputs for one modality.
+    For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
+    One for images and one for audio.

     We put the common fields first and the model-specific fields last.
     """
@@ -232,53 +232,7 @@ class MultimodalDataItem:
         """
         Set the pad value after first hashing the data
         """
-        def data_hash(data) -> int:
-            hash_bytes = hashlib.sha256(data).digest()[:8]
-            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
-
-        def tensor_hash(tensor_list) -> int: ...
-
-        def hash_feature(f): ...
+        from sglang.srt.managers.mm_utils import hash_feature

         if self.precomputed_features is not None:
             self.hash = hash_feature(self.precomputed_features)

(The removed nested tensor_hash and hash_feature helpers are identical to the module-level functions added to mm_utils.py above.)
python/sglang/srt/managers/scheduler.py

@@ -418,14 +418,16 @@ class Scheduler(
         self.last_decode_stats_tic = time.perf_counter()
         self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
+        self.num_retracted_reqs: int = 0
+        self.num_paused_reqs: int = 0
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+        self.sessions: Dict[str, Session] = {}
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None

-        # Init session info
-        self.sessions: Dict[str, Session] = {}
-
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         if self.chunked_prefill_size <= 0:  # -1 means disable
@@ -473,26 +475,12 @@ class Scheduler(
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
         self.parent_process = psutil.Process().parent()

+        # Init memory saver, profiler and metric stats
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
-        # Init profiler
-        self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
-        self.profiler_activities: Optional[List[str]] = None
-        self.profile_id: Optional[str] = None
-        self.profiler_target_forward_ct: Optional[int] = None
-        self.profiler_target_prefill_ct: Optional[int] = None
-        self.profiler_target_decode_ct: Optional[int] = None
-        self.profiler_prefill_ct: Optional[int] = None
-        self.profiler_decode_ct: Optional[int] = None
-        self.profile_by_stage: bool = False
-        self.profile_steps: Optional[int] = None
-        self.profile_in_progress: bool = False
-        self.rpd_profiler = None
-
-        # Init metrics stats
+        self.init_profier()
         self.init_metrics()
         self.init_kv_events(server_args.kv_events_config)
@@ -526,6 +514,7 @@ class Scheduler(
             ]
         )

+        # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -624,6 +613,21 @@ class Scheduler(
             )
         )

+    def init_profier(self):
+        self.torch_profiler = None
+        self.torch_profiler_output_dir: Optional[str] = None
+        self.profiler_activities: Optional[List[str]] = None
+        self.profile_id: Optional[str] = None
+        self.profiler_target_forward_ct: Optional[int] = None
+        self.profiler_target_prefill_ct: Optional[int] = None
+        self.profiler_target_decode_ct: Optional[int] = None
+        self.profiler_prefill_ct: Optional[int] = None
+        self.profiler_decode_ct: Optional[int] = None
+        self.profile_by_stage: bool = False
+        self.profile_steps: Optional[int] = None
+        self.profile_in_progress: bool = False
+        self.rpd_profiler = None
+
     def init_metrics(self):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
@@ -2107,6 +2111,18 @@ class Scheduler(
     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
         ret["last_gen_throughput"] = self.last_gen_throughput
+        ret["memory_usage"] = {
+            "weight": round(
+                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
+            ),
+            "kvcache": round(
+                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
+            ),
+            "cuda_graph": round(
+                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
+            ),
+            "token_capacity": int(self.max_total_num_tokens),
+        }
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
                     stream_interval = (
                         req.sampling_params.stream_interval or self.stream_interval
                     )
-                    should_output = len(req.output_ids) % stream_interval == 0
+                    should_output = (
+                        len(req.output_ids) % stream_interval == 1
+                        if not self.model_config.is_multimodal_gen
+                        and stream_interval > 1
+                        else len(req.output_ids) % stream_interval == 0
+                    )
                 else:
                     should_output = (
                         len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
-                        and not self.model_config.is_multimodal_gen
+                        if not self.model_config.is_multimodal_gen
+                        else False
                     )

                 if should_output:
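A small sketch (not part of the diff) of the new streaming cadence: for non-multimodal generation with stream_interval > 1, a chunk is now emitted when `len(output_ids) % stream_interval == 1`, so the first generated token is flushed immediately instead of waiting for a full interval:

```python
# Compare old vs. new flush points for the first 12 generated tokens.
stream_interval = 4

old_points = [n for n in range(1, 13) if n % stream_interval == 0]
new_points = [n for n in range(1, 13) if n % stream_interval == 1]

print(old_points)  # [4, 8, 12] -> first token only appears after 4 tokens
print(new_points)  # [1, 5, 9]  -> first token streams right away
```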
python/sglang/srt/managers/tokenizer_manager.py
View file @
22352d47
...
@@ -111,11 +111,7 @@ from sglang.srt.managers.io_struct import (
...
@@ -111,11 +111,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput
,
UpdateWeightsFromTensorReqInput
,
UpdateWeightsFromTensorReqOutput
,
UpdateWeightsFromTensorReqOutput
,
)
)
from
sglang.srt.managers.multimodal_processor
import
(
from
sglang.srt.managers.multimodal_processor
import
get_mm_processor
,
import_processors
get_dummy_processor
,
get_mm_processor
,
import_processors
,
)
from
sglang.srt.metrics.collector
import
TokenizerMetricsCollector
from
sglang.srt.metrics.collector
import
TokenizerMetricsCollector
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
...
@@ -187,6 +183,8 @@ class TokenizerManager:
...
@@ -187,6 +183,8 @@ class TokenizerManager:
if
server_args
.
preferred_sampling_params
if
server_args
.
preferred_sampling_params
else
None
else
None
)
)
self
.
crash_dump_folder
=
server_args
.
crash_dump_folder
self
.
crash_dump_performed
=
False
# Flag to ensure dump is only called once
# Init inter-process communication
# Init inter-process communication
context
=
zmq
.
asyncio
.
Context
(
2
)
context
=
zmq
.
asyncio
.
Context
(
2
)
...
@@ -251,10 +249,11 @@ class TokenizerManager:
...
@@ -251,10 +249,11 @@ class TokenizerManager:
self
.
dump_requests_folder
=
""
# By default do not dump
self
.
dump_requests_folder
=
""
# By default do not dump
self
.
dump_requests_threshold
=
1000
self
.
dump_requests_threshold
=
1000
self
.
dump_request_list
:
List
[
Tuple
]
=
[]
self
.
dump_request_list
:
List
[
Tuple
]
=
[]
self
.
crash_dump_request_list
:
deque
[
Tuple
]
=
deque
()
self
.
log_request_metadata
=
self
.
get_log_request_metadata
()
self
.
log_request_metadata
=
self
.
get_log_request_metadata
()
self
.
asyncio_tasks
=
set
()
self
.
session_futures
=
{}
# session_id -> asyncio event
self
.
session_futures
=
{}
# session_id -> asyncio event
self
.
max_req_input_len
=
None
self
.
max_req_input_len
=
None
self
.
asyncio_tasks
=
set
()
# The event to notify the weight sync is finished.
# The event to notify the weight sync is finished.
self
.
model_update_lock
=
RWLock
()
self
.
model_update_lock
=
RWLock
()
...
@@ -266,14 +265,14 @@ class TokenizerManager:
...
@@ -266,14 +265,14 @@ class TokenizerManager:
self
.
disaggregation_mode
=
DisaggregationMode
(
self
.
disaggregation_mode
=
DisaggregationMode
(
self
.
server_args
.
disaggregation_mode
self
.
server_args
.
disaggregation_mode
)
)
self
.
transfer_backend
=
TransferBackend
(
self
.
disaggregation_
transfer_backend
=
TransferBackend
(
self
.
server_args
.
disaggregation_transfer_backend
self
.
server_args
.
disaggregation_transfer_backend
)
)
# Start kv boostrap server on prefill
# Start kv boostrap server on prefill
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
# only start bootstrap server on prefill tm
# only start bootstrap server on prefill tm
kv_bootstrap_server_class
=
get_kv_class
(
kv_bootstrap_server_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
BOOTSTRAP_SERVER
self
.
disaggregation_
transfer_backend
,
KVClassType
.
BOOTSTRAP_SERVER
)
)
self
.
bootstrap_server
=
kv_bootstrap_server_class
(
self
.
bootstrap_server
=
kv_bootstrap_server_class
(
self
.
server_args
.
disaggregation_bootstrap_port
self
.
server_args
.
disaggregation_bootstrap_port
...
@@ -324,7 +323,6 @@ class TokenizerManager:
...
@@ -324,7 +323,6 @@ class TokenizerManager:
self
.
profile_communicator
=
_Communicator
(
self
.
profile_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
self
.
send_to_scheduler
,
server_args
.
dp_size
)
)
self
.
health_check_communitcator
=
_Communicator
(
self
.
send_to_scheduler
,
1
)
self
.
get_internal_state_communicator
=
_Communicator
(
self
.
get_internal_state_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
self
.
send_to_scheduler
,
server_args
.
dp_size
)
)
...
@@ -484,7 +482,7 @@ class TokenizerManager:
...
@@ -484,7 +482,7 @@ class TokenizerManager:
token_type_ids
=
encoded
.
get
(
"token_type_ids"
,
[
None
])[
0
]
token_type_ids
=
encoded
.
get
(
"token_type_ids"
,
[
None
])[
0
]
if
self
.
mm_processor
and
obj
.
contains_mm_input
():
if
self
.
mm_processor
and
obj
.
contains_mm_input
():
image_inputs
=
await
self
.
mm_processor
.
process_mm_data_async
(
image_inputs
:
Dict
=
await
self
.
mm_processor
.
process_mm_data_async
(
image_data
=
obj
.
image_data
,
image_data
=
obj
.
image_data
,
input_text
=
input_text
or
input_ids
,
input_text
=
input_text
or
input_ids
,
request_obj
=
obj
,
request_obj
=
obj
,
...
@@ -547,6 +545,14 @@ class TokenizerManager:
...
@@ -547,6 +545,14 @@ class TokenizerManager:
"Please set `--enable-custom-logits-processor` to enable this feature."
"Please set `--enable-custom-logits-processor` to enable this feature."
)
)
def
_validate_input_ids_in_vocab
(
self
,
input_ids
:
List
[
int
],
vocab_size
:
int
)
->
None
:
if
any
(
id
>=
vocab_size
for
id
in
input_ids
):
raise
ValueError
(
f
"The input_ids
{
input_ids
}
contains values greater than the vocab size (
{
vocab_size
}
)."
)
def
_create_tokenized_object
(
def
_create_tokenized_object
(
self
,
self
,
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
obj
:
Union
[
GenerateReqInput
,
EmbeddingReqInput
],
...
@@ -1096,12 +1102,36 @@ class TokenizerManager:
...
@@ -1096,12 +1102,36 @@ class TokenizerManager:
"image_data"
,
"image_data"
,
"audio_data"
,
"audio_data"
,
"lora_path"
,
"lora_path"
,
"sampling_params"
,
]
)
out_skip_names
=
set
(
[
"text"
,
"output_ids"
,
]
]
)
)
out_skip_names
=
set
([
"text"
,
"output_ids"
,
"embedding"
])
elif
self
.
log_requests_level
==
1
:
elif
self
.
log_requests_level
==
1
:
max_length
=
2048
max_length
=
1
<<
30
skip_names
=
set
(
[
"text"
,
"input_ids"
,
"input_embeds"
,
"image_data"
,
"audio_data"
,
"lora_path"
,
]
)
out_skip_names
=
set
(
[
"text"
,
"output_ids"
,
]
)
elif
self
.
log_requests_level
==
2
:
elif
self
.
log_requests_level
==
2
:
max_length
=
2048
elif
self
.
log_requests_level
==
3
:
max_length
=
1
<<
30
max_length
=
1
<<
30
else
:
else
:
raise
ValueError
(
raise
ValueError
(
...
@@ -1118,6 +1148,8 @@ class TokenizerManager:
...
@@ -1118,6 +1148,8 @@ class TokenizerManager:
self
.
dump_requests_folder
=
obj
.
dump_requests_folder
self
.
dump_requests_folder
=
obj
.
dump_requests_folder
if
obj
.
dump_requests_threshold
is
not
None
:
if
obj
.
dump_requests_threshold
is
not
None
:
self
.
dump_requests_threshold
=
obj
.
dump_requests_threshold
self
.
dump_requests_threshold
=
obj
.
dump_requests_threshold
if
obj
.
crash_dump_folder
is
not
None
:
self
.
crash_dump_folder
=
obj
.
crash_dump_folder
logging
.
info
(
f
"Config logging:
{
obj
=
}
"
)
logging
.
info
(
f
"Config logging:
{
obj
=
}
"
)
self
.
log_request_metadata
=
self
.
get_log_request_metadata
()
self
.
log_request_metadata
=
self
.
get_log_request_metadata
()
...
@@ -1166,6 +1198,52 @@ class TokenizerManager:
...
@@ -1166,6 +1198,52 @@ class TokenizerManager:
loop
.
create_task
(
print_exception_wrapper
(
self
.
sigterm_watchdog
))
loop
.
create_task
(
print_exception_wrapper
(
self
.
sigterm_watchdog
))
)
)
def
dump_requests_before_crash
(
self
):
if
self
.
crash_dump_performed
:
logger
.
info
(
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
)
return
logger
.
error
(
f
"Dumping requests before crash.
{
self
.
crash_dump_folder
=
}
"
)
self
.
crash_dump_performed
=
True
if
not
self
.
crash_dump_folder
:
return
data_to_dump
=
[]
if
self
.
crash_dump_request_list
:
data_to_dump
.
extend
(
self
.
crash_dump_request_list
)
# Add unfinished requests from rid_to_state
unfinished_requests
=
[]
for
rid
,
state
in
self
.
rid_to_state
.
items
():
if
not
state
.
finished
:
unfinished_requests
.
append
(
(
state
.
obj
,
{},
state
.
created_time
,
time
.
time
())
)
if
unfinished_requests
:
data_to_dump
.
extend
(
unfinished_requests
)
if
not
data_to_dump
:
return
filename
=
os
.
path
.
join
(
self
.
crash_dump_folder
,
os
.
getenv
(
"HOSTNAME"
,
None
),
f
'crash_dump_
{
datetime
.
now
().
strftime
(
"%Y-%m-%d_%H-%M-%S"
)
}
.pkl'
,
)
os
.
makedirs
(
os
.
path
.
dirname
(
filename
),
exist_ok
=
True
)
# Include server_args in the dump
data_to_dump_with_server_args
=
{
"server_args"
:
self
.
server_args
,
"requests"
:
data_to_dump
,
}
with
open
(
filename
,
"wb"
)
as
f
:
pickle
.
dump
(
data_to_dump_with_server_args
,
f
)
logger
.
error
(
f
"Dumped
{
len
(
self
.
crash_dump_request_list
)
}
finished and
{
len
(
unfinished_requests
)
}
unfinished requests before crash to
{
filename
}
"
)
async
def
sigterm_watchdog
(
self
):
async
def
sigterm_watchdog
(
self
):
while
not
self
.
gracefully_exit
:
while
not
self
.
gracefully_exit
:
await
asyncio
.
sleep
(
5
)
await
asyncio
.
sleep
(
5
)
...
@@ -1175,11 +1253,12 @@ class TokenizerManager:
...
@@ -1175,11 +1253,12 @@ class TokenizerManager:
remain_num_req
=
len
(
self
.
rid_to_state
)
remain_num_req
=
len
(
self
.
rid_to_state
)
if
self
.
health_check_failed
:
if
self
.
health_check_failed
:
# if health check failed, exit immediately
# if health check failed,
we should
exit immediately
logger
.
error
(
logger
.
error
(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d"
,
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d"
,
remain_num_req
,
remain_num_req
,
)
)
self
.
dump_requests_before_crash
()
break
break
elif
get_bool_env_var
(
"SGL_FORCE_SHUTDOWN"
):
elif
get_bool_env_var
(
"SGL_FORCE_SHUTDOWN"
):
...
@@ -1196,6 +1275,7 @@ class TokenizerManager:
...
@@ -1196,6 +1275,7 @@ class TokenizerManager:
if
remain_num_req
>
0
:
if
remain_num_req
>
0
:
await
asyncio
.
sleep
(
5
)
await
asyncio
.
sleep
(
5
)
else
:
else
:
self
.
dump_requests_before_crash
()
break
break
kill_process_tree
(
os
.
getpid
(),
include_parent
=
True
)
kill_process_tree
(
os
.
getpid
(),
include_parent
=
True
)
...
@@ -1273,16 +1353,7 @@ class TokenizerManager:
...
@@ -1273,16 +1353,7 @@ class TokenizerManager:
"meta_info"
:
meta_info
,
"meta_info"
:
meta_info
,
}
}
elif
isinstance
(
recv_obj
,
BatchMultimodalOut
):
elif
isinstance
(
recv_obj
,
BatchMultimodalOut
):
if
isinstance
(
recv_obj
.
outputs
[
i
],
str
):
raise
NotImplementedError
(
"BatchMultimodalOut not implemented"
)
out_dict
=
{
"text"
:
recv_obj
.
outputs
[
i
],
"meta_info"
:
meta_info
,
}
else
:
out_dict
=
{
"outputs"
:
json
.
dumps
(
recv_obj
.
outputs
[
i
]),
"meta_info"
:
meta_info
,
}
else
:
else
:
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
out_dict
=
{
out_dict
=
{
...
@@ -1306,6 +1377,8 @@ class TokenizerManager:
...
@@ -1306,6 +1377,8 @@ class TokenizerManager:
self
.
collect_metrics
(
state
,
recv_obj
,
i
)
self
.
collect_metrics
(
state
,
recv_obj
,
i
)
if
self
.
dump_requests_folder
and
state
.
finished
and
state
.
obj
.
log_metrics
:
if
self
.
dump_requests_folder
and
state
.
finished
and
state
.
obj
.
log_metrics
:
self
.
dump_requests
(
state
,
out_dict
)
self
.
dump_requests
(
state
,
out_dict
)
if
self
.
crash_dump_folder
and
state
.
finished
and
state
.
obj
.
log_metrics
:
self
.
record_request_for_crash_dump
(
state
,
out_dict
)
def
convert_logprob_style
(
def
convert_logprob_style
(
self
,
self
,
...
@@ -1317,6 +1390,9 @@ class TokenizerManager:
...
@@ -1317,6 +1390,9 @@ class TokenizerManager:
recv_obj
:
BatchStrOut
,
recv_obj
:
BatchStrOut
,
recv_obj_index
:
int
,
recv_obj_index
:
int
,
):
):
if
recv_obj
.
input_token_logprobs_val
is
None
:
return
if
len
(
recv_obj
.
input_token_logprobs_val
)
>
0
:
if
len
(
recv_obj
.
input_token_logprobs_val
)
>
0
:
state
.
input_token_logprobs_val
.
extend
(
state
.
input_token_logprobs_val
.
extend
(
recv_obj
.
input_token_logprobs_val
[
recv_obj_index
]
recv_obj
.
input_token_logprobs_val
[
recv_obj_index
]
...
@@ -1436,7 +1512,10 @@ class TokenizerManager:
                 else 0
             )

-            if state.first_token_time == 0.0:
+            if (
+                state.first_token_time == 0.0
+                and self.disaggregation_mode != DisaggregationMode.PREFILL
+            ):
                 state.first_token_time = state.last_time = time.time()
                 state.last_completion_tokens = completion_tokens
                 self.metrics_collector.observe_time_to_first_token(
...
@@ -1484,14 +1563,31 @@ class TokenizerManager:
         to_dump = self.dump_request_list
         self.dump_request_list = []

+        to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": to_dump,
+        }
+
         def background_task():
             os.makedirs(self.dump_requests_folder, exist_ok=True)
             with open(filename, "wb") as f:
-                pickle.dump(to_dump, f)
+                pickle.dump(to_dump_with_server_args, f)

         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
+        current_time = time.time()
+        self.crash_dump_request_list.append(
+            (state.obj, out_dict, state.created_time, current_time)
+        )
+        # Remove requests older than 5 minutes based on finish time
+        while (
+            self.crash_dump_request_list
+            and current_time - self.crash_dump_request_list[0][3] >= 300
+        ):
+            self.crash_dump_request_list.popleft()
+
     def _handle_abort_req(self, recv_obj):
         self.rid_to_state.pop(recv_obj.rid, None)
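The retention logic above keeps only requests that finished within the last five minutes. Below is a minimal, standalone sketch of that sliding-window bookkeeping; the names `RETENTION_SECONDS`, `finished_requests`, and `record` are illustrative and not part of this commit.

```python
import time
from collections import deque

RETENTION_SECONDS = 300  # 5 minutes, matching the constant used above
finished_requests = deque()  # entries: (obj, out_dict, created_time, finished_time)

def record(obj, out_dict, created_time):
    now = time.time()
    finished_requests.append((obj, out_dict, created_time, now))
    # Evict entries whose finish time falls outside the retention window.
    while finished_requests and now - finished_requests[0][3] >= RETENTION_SECONDS:
        finished_requests.popleft()
```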
...
@@ -1614,6 +1710,8 @@ async def print_exception_wrapper(func):
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"TokenizerManager hit an exception: {traceback}")
+        if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
+            func.__self__.dump_requests_before_crash()
         kill_process_tree(os.getpid(), include_parent=True)
         sys.exit(1)
...
@@ -1632,6 +1730,7 @@ class SignalHandler:
         logger.error(
             "Received sigquit from a child process. It usually means the child failed."
         )
+        self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())
...

python/sglang/srt/mem_cache/memory_pool.py  View file @ 22352d47
...
@@ -123,6 +123,7 @@ class KVCache(abc.ABC):
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
+        self.mem_usage = 0

         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
...
@@ -219,6 +220,7 @@ class MHATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )
+        self.mem_usage = (k_size + v_size) / GB

     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
...
@@ -695,6 +697,7 @@ class MLATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
         )
+        self.mem_usage = kv_size / GB

     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
...
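For reference, the GB figures reported by the log lines above follow directly from the KV buffer shapes. Below is a rough, self-contained sizing sketch; the token, layer, and head counts are placeholders rather than values from any specific model, and real numbers come from the model config and server arguments.

```python
import torch

GB = 1024**3

# Placeholder shapes for illustration only.
num_tokens = 500_000
num_layers = 32
num_kv_heads = 8
head_dim = 128
dtype_bytes = torch.tensor([], dtype=torch.bfloat16).element_size()  # 2 bytes

k_size = num_tokens * num_layers * num_kv_heads * head_dim * dtype_bytes
v_size = k_size  # V mirrors K in the MHA layout
print(f"K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB")
```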
python/sglang/srt/model_executor/model_runner.py  View file @ 22352d47
...
@@ -604,12 +604,13 @@ class ModelRunner:
         self.dtype = self.model_config.dtype
         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
+        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={after_avail_memory:.2f} GB, "
-            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
+            f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )

         # Handle the case where some ranks do not finish loading.
...
@@ -1250,6 +1251,7 @@ class ModelRunner:
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_mem_usage = 0

         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
...
@@ -1265,9 +1267,10 @@ class ModelRunner:
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def apply_torch_tp(self):
...
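The new `weight_load_mem_usage` and `cuda_graph_mem_usage` fields both come from the same before/after sampling of free GPU memory. A minimal sketch of that pattern in plain PyTorch follows; the helper name `gpu_free_gb` is illustrative (not an sglang API) and the snippet assumes a CUDA device is present.

```python
import torch

GB = 1024**3

def gpu_free_gb(device: int = 0) -> float:
    # torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the device.
    free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
    return free_bytes / GB

before = gpu_free_gb()
# ... load weights or capture CUDA graphs here ...
after = gpu_free_gb()
mem_usage = before - after
print(f"mem usage={mem_usage:.2f} GB. avail mem={after:.2f} GB.")
```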
python/sglang/srt/model_loader/loader.py  View file @ 22352d47
...
@@ -534,6 +534,12 @@ class DummyModelLoader(BaseModelLoader):
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+        if get_bool_env_var("SGL_CPU_QUANTIZATION"):
+            return load_model_with_cpu_quantization(
+                self, model_config=model_config, device_config=device_config
+            )
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(
...
@@ -1464,6 +1470,38 @@ class RemoteModelLoader(BaseModelLoader):
         return model.eval()


+def load_model_with_cpu_quantization(
+    self,
+    *,
+    model_config: ModelConfig,
+    device_config: DeviceConfig,
+) -> nn.Module:
+    target_device = torch.device(device_config.device)
+    with set_default_torch_dtype(model_config.dtype):
+        model = _initialize_model(
+            model_config,
+            self.load_config,
+        )
+
+        if not isinstance(self, DummyModelLoader):
+            model.load_weights(self._get_all_weights(model_config, model))
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+
+    model.to(target_device)
+    return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
...
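As added above, the loader consults the `SGL_CPU_QUANTIZATION` environment variable before deciding to initialize and quantize weights on the CPU and only then move the model to the target device. A hedged usage sketch; exactly which loaders honor the variable is whatever this commit wires up, so treat this as illustrative.

```python
import os

# Must be set in the environment of the serving process before model loading
# starts, since get_bool_env_var("SGL_CPU_QUANTIZATION") is read at load time.
os.environ["SGL_CPU_QUANTIZATION"] = "1"
```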
python/sglang/srt/models/mistral.py  View file @ 22352d47
...
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Inference-only Mistral model."""
-from typing import List, Union
+from typing import List

 import torch
 from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
...
python/sglang/srt/server_args.py  View file @ 22352d47
...
@@ -99,6 +99,7 @@ class ServerArgs:
     log_level_http: Optional[str] = None
     log_requests: bool = False
     log_requests_level: int = 0
+    crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
     bucket_time_to_first_token: Optional[List[float]] = None
...
@@ -927,8 +928,14 @@ class ServerArgs:
             "--log-requests-level",
             type=int,
             default=0,
-            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
-            choices=[0, 1, 2],
+            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
+            choices=[0, 1, 2, 3],
         )
+        parser.add_argument(
+            "--crash-dump-folder",
+            type=str,
+            default=ServerArgs.crash_dump_folder,
+            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
+        )
         parser.add_argument(
             "--show-time-cost",
...
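Taken together, `--crash-dump-folder` enables the crash-time dump recorded by the TokenizerManager changes above, and the regular request dumps now pickle the server args alongside the requests. A minimal inspection sketch; the path and file name below are placeholders, and the crash-dump payload layout may differ from the regular request dumps shown in this diff.

```python
import pickle

# Placeholder path: real files land under the folder passed via
# --dump-requests-folder (or --crash-dump-folder for crash dumps).
with open("/tmp/sglang_dumps/example_dump.pkl", "rb") as f:
    payload = pickle.load(f)

print(payload["server_args"])                  # ServerArgs captured with the dump
print(len(payload["requests"]), "requests recorded")
```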