Unverified commit ebe58d54 authored by Chang Su, committed by GitHub

[Misc] Implement RankZeroFilter for rank-specific logging in model_runner.py (#6333)

parent 066cf445
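
For context, a minimal standalone sketch (not part of this commit) of how the RankZeroFilter added below behaves: INFO records are dropped on every rank except rank 0, while warnings and errors pass through on all ranks. The demo harness (basicConfig, the hard-coded tp_rank) is illustrative only.

import logging


class RankZeroFilter(logging.Filter):
    """Allow INFO records only on rank 0; let all other levels through on any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    logger = logging.getLogger("demo")

    tp_rank = 1  # pretend this process is tensor-parallel rank 1
    logger.addFilter(RankZeroFilter(tp_rank == 0))

    logger.info("suppressed on every rank except rank 0")  # filtered out here
    logger.warning("still emitted on every rank")  # printed
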
@@ -103,6 +103,19 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)
 
 
+class RankZeroFilter(logging.Filter):
+    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
+
+    def __init__(self, is_rank_zero):
+        super().__init__()
+        self.is_rank_zero = is_rank_zero
+
+    def filter(self, record):
+        if record.levelno == logging.INFO:
+            return self.is_rank_zero
+        return True
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
@@ -126,6 +139,10 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
+
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.pp_rank = pp_rank
@@ -135,7 +152,6 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
-        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -281,7 +297,6 @@
                 server_args.attention_backend = "fa3"
             else:
                 server_args.attention_backend = "triton"
-            if self.should_log:
             logger.info(
                 f"Attention backend not set. Use {server_args.attention_backend} backend by default."
             )
@@ -294,7 +309,6 @@
                 "flashmla",
                 "cutlass_mla",
             ]:
-                if self.should_log:
                 logger.info(
                     f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                 )
@@ -316,7 +330,6 @@
             server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-            if self.should_log:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
             )
@@ -330,25 +343,21 @@
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            if self.should_log:
             logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
             )
+            server_args.chunked_prefill_size = -1
             logger.info(
                 "Automatically turn off --chunked-prefill-size for multimodal model."
             )
-            server_args.chunked_prefill_size = -1
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            if self.should_log:
             logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-            if self.should_log:
             logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
@@ -446,7 +455,6 @@
         torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-                if self.should_log:
                 logger.info(
                     "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                 )
@@ -485,7 +493,6 @@
                 self.model.load_kv_cache_scales(
                     self.server_args.quantization_param_path
                 )
-                if self.should_log:
                 logger.info(
                     "Loaded KV cache scaling factors from %s",
                     self.server_args.quantization_param_path,
@@ -1027,7 +1034,6 @@
             )
 
     def apply_torch_tp(self):
-        if self.should_log:
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
 
         from sglang.srt.model_parallel import tensor_parallel
......
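
A hedged usage sketch of the guarded attachment done in ModelRunner.__init__ above: the isinstance() check keeps the module-level logger from accumulating duplicate filters when several ModelRunner instances are created in one process. The import path and logger name below are assumptions inferred from the file this commit touches, not confirmed by the diff.

import logging

# Assumed import path for the class added by this commit; adjust if the actual
# package layout differs.
from sglang.srt.model_executor.model_runner import RankZeroFilter

# Same module-level logger the diff attaches the filter to (name assumed to be
# that module's __name__).
logger = logging.getLogger("sglang.srt.model_executor.model_runner")


def attach_rank_zero_filter(tp_rank):
    # Mirror of the guard in ModelRunner.__init__: add the filter at most once.
    if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
        logger.addFilter(RankZeroFilter(tp_rank == 0))


attach_rank_zero_filter(tp_rank=1)
attach_rank_zero_filter(tp_rank=1)  # no-op: a RankZeroFilter is already attached
assert sum(isinstance(f, RankZeroFilter) for f in logger.filters) == 1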