sglang commit ebe58d54 (Unverified)
Authored by Chang Su on May 18, 2025; committed by GitHub on May 18, 2025
[Misc] Implement RankZeroFilter for rank-specific logging in model_runner.py (#6333)
Parent: 066cf445

Showing 1 changed file with 42 additions and 36 deletions.

python/sglang/srt/model_executor/model_runner.py (+42, -36)
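The change wires a standard logging.Filter into model_runner.py so that INFO-level messages are emitted only by tensor-parallel rank 0, while warnings and errors still come from every rank. Below is a minimal, self-contained sketch of the same mechanism; the logger name, the hard-coded is_rank_zero value, and the sample messages are illustrative only, not taken from the commit:

    import logging

    class RankZeroFilter(logging.Filter):
        """Let INFO records through only on rank 0; pass all other levels on every rank."""

        def __init__(self, is_rank_zero):
            super().__init__()
            self.is_rank_zero = is_rank_zero

        def filter(self, record):
            if record.levelno == logging.INFO:
                return self.is_rank_zero
            return True

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("demo")
    log.addFilter(RankZeroFilter(is_rank_zero=False))  # pretend this process is tp_rank 1

    log.info("Load weight begin.")          # dropped: INFO on a non-zero rank
    log.warning("KV cache pool is small.")  # emitted: non-INFO levels always pass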
@@ -103,6 +103,19 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)
 
 
+class RankZeroFilter(logging.Filter):
+    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
+
+    def __init__(self, is_rank_zero):
+        super().__init__()
+        self.is_rank_zero = is_rank_zero
+
+    def filter(self, record):
+        if record.levelno == logging.INFO:
+            return self.is_rank_zero
+        return True
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
@@ -126,6 +139,10 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
+
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.pp_rank = pp_rank
@@ -135,7 +152,6 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
-        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -281,10 +297,9 @@ class ModelRunner:
                     server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
-            if self.should_log:
-                logger.info(
-                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-                )
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -294,10 +309,9 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    if self.should_log:
-                        logger.info(
-                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                        )
+                    logger.info(
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                    )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -316,10 +330,9 @@ class ModelRunner:
                 server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-            if self.should_log:
-                logger.info(
-                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-                )
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -330,26 +343,22 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            if self.should_log:
-                logger.info(
-                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                    f"because this is a multimodal model."
-                )
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size for multimodal model."
-                )
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
+            )
             server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            if self.should_log:
-                logger.info("Disable chunked prefix cache when page size > 1.")
+            logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-            if self.should_log:
-                logger.info("Chunked prefix cache is turned on.")
+            logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -446,10 +455,9 @@ class ModelRunner:
         torch.set_num_threads(1)
         if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
-                if self.should_log:
-                    logger.info(
-                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                    )
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -485,11 +493,10 @@ class ModelRunner:
                 self.model.load_kv_cache_scales(
                     self.server_args.quantization_param_path
                 )
-                if self.should_log:
-                    logger.info(
-                        "Loaded KV cache scaling factors from %s",
-                        self.server_args.quantization_param_path,
-                    )
+                logger.info(
+                    "Loaded KV cache scaling factors from %s",
+                    self.server_args.quantization_param_path,
+                )
             else:
                 raise RuntimeError(
                     "Using FP8 KV cache and scaling factors provided but "
@@ -1027,8 +1034,7 @@ class ModelRunner:
         )
 
     def apply_torch_tp(self):
-        if self.should_log:
-            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
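In the constructor change (second hunk), the filter is attached at most once: the isinstance scan over logger.filters keeps repeated ModelRunner construction in the same process (for example, a draft worker alongside the target worker) from stacking duplicate filters on the module logger. A small sketch of that guard in isolation, assuming the RankZeroFilter definition shown earlier is in scope; attach_rank_zero_filter is a hypothetical helper name, not part of the commit:

    import logging

    logger = logging.getLogger(__name__)

    def attach_rank_zero_filter(tp_rank):
        # Mirrors the guard added in the commit: only add the filter if none is present yet.
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))

    attach_rank_zero_filter(tp_rank=0)
    attach_rank_zero_filter(tp_rank=0)  # no-op: a RankZeroFilter is already attached
    assert sum(isinstance(f, RankZeroFilter) for f in logger.filters) == 1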