Unverified Commit ebe58d54 authored by Chang Su, committed by GitHub

[Misc] Implement RankZeroFilter for rank-specific logging in model_runner.py (#6333)

parent 066cf445
@@ -103,6 +103,19 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)
 
 
+class RankZeroFilter(logging.Filter):
+    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
+
+    def __init__(self, is_rank_zero):
+        super().__init__()
+        self.is_rank_zero = is_rank_zero
+
+    def filter(self, record):
+        if record.levelno == logging.INFO:
+            return self.is_rank_zero
+        return True
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
@@ -126,6 +139,10 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
+
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.pp_rank = pp_rank
@@ -135,7 +152,6 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
-        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -281,10 +297,9 @@ class ModelRunner:
                 server_args.attention_backend = "fa3"
             else:
                 server_args.attention_backend = "triton"
-            if self.should_log:
-                logger.info(
-                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-                )
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -294,10 +309,9 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    if self.should_log:
-                        logger.info(
-                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                        )
+                    logger.info(
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                    )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -316,10 +330,9 @@ class ModelRunner:
                 server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-            if self.should_log:
-                logger.info(
-                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-                )
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -330,26 +343,22 @@ class ModelRunner:
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            if self.should_log:
-                logger.info(
-                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                    f"because this is a multimodal model."
-                )
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size for multimodal model."
-                )
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
+            )
             server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            if self.should_log:
-                logger.info("Disable chunked prefix cache when page size > 1.")
+            logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-            if self.should_log:
-                logger.info("Chunked prefix cache is turned on.")
+            logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -446,10 +455,9 @@ class ModelRunner:
         torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-                if self.should_log:
-                    logger.info(
-                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                    )
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -485,11 +493,10 @@ class ModelRunner:
                 self.model.load_kv_cache_scales(
                     self.server_args.quantization_param_path
                 )
-                if self.should_log:
-                    logger.info(
-                        "Loaded KV cache scaling factors from %s",
-                        self.server_args.quantization_param_path,
-                    )
+                logger.info(
+                    "Loaded KV cache scaling factors from %s",
+                    self.server_args.quantization_param_path,
+                )
             else:
                 raise RuntimeError(
                     "Using FP8 KV cache and scaling factors provided but "
@@ -1027,8 +1034,7 @@ class ModelRunner:
         )
 
     def apply_torch_tp(self):
-        if self.should_log:
-            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
...
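For reference, the snippet below is a minimal, self-contained sketch (not part of this commit) that reproduces the RankZeroFilter added above and attaches it with the same idempotent addFilter guard, so you can see which records each rank would emit. The demo helper and the per-rank logger names are hypothetical and exist only for illustration; in the commit itself the filter is attached to the module-level logger inside ModelRunner.__init__ using tp_rank == 0.

# Standalone sketch (illustration only, not the committed code).
import logging


class RankZeroFilter(logging.Filter):
    """Allow INFO records only on rank 0; pass every other level on any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


def demo(tp_rank):
    # Hypothetical helper: one logger per simulated rank, each with its own handler.
    logger = logging.getLogger(f"demo.rank{tp_rank}")
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    # Same idempotent guard as in ModelRunner.__init__: add the filter only once.
    if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
        logger.addFilter(RankZeroFilter(tp_rank == 0))
    logger.info("rank %d: INFO is emitted only on rank 0", tp_rank)
    logger.warning("rank %d: WARNING is emitted on every rank", tp_rank)


for rank in range(2):
    demo(rank)

Running it prints the INFO line once (from rank 0) while the WARNING line appears for both ranks, which is the per-rank behavior the committed filter enforces for the model_runner logger.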