Unverified commit ebe58d54 authored by Chang Su, committed by GitHub

[Misc] Implement RankZeroFilter for rank-specific logging in model_runner.py (#6333)

parent 066cf445
@@ -103,6 +103,19 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 
 logger = logging.getLogger(__name__)
 
 
+class RankZeroFilter(logging.Filter):
+    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
+
+    def __init__(self, is_rank_zero):
+        super().__init__()
+        self.is_rank_zero = is_rank_zero
+
+    def filter(self, record):
+        if record.levelno == logging.INFO:
+            return self.is_rank_zero
+        return True
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
@@ -126,6 +139,10 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
+
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.pp_rank = pp_rank
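The isinstance guard matters because the module-level logger is shared by every ModelRunner built in the same process (for example when a draft worker is constructed alongside the target worker), so the filter should be installed only once. A standalone sketch of that idempotency check; attach_rank_zero_filter and the logger name are hypothetical names for the demo:

import logging


class RankZeroFilter(logging.Filter):
    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        return self.is_rank_zero if record.levelno == logging.INFO else True


def attach_rank_zero_filter(logger, tp_rank):
    # Same guard as in ModelRunner.__init__: add the filter only if an
    # equivalent one is not already registered on this logger.
    if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
        logger.addFilter(RankZeroFilter(tp_rank == 0))


logger = logging.getLogger("rank_filter_demo")  # hypothetical logger name
attach_rank_zero_filter(logger, tp_rank=0)
attach_rank_zero_filter(logger, tp_rank=0)  # second call is a no-op
assert len(logger.filters) == 1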
@@ -135,7 +152,6 @@
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
-        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -281,10 +297,9 @@
                     server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
-            if self.should_log:
-                logger.info(
-                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-                )
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -294,10 +309,9 @@
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    if self.should_log:
-                        logger.info(
-                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                        )
+                    logger.info(
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                    )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -316,10 +330,9 @@
                 server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-            if self.should_log:
-                logger.info(
-                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-                )
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -330,26 +343,22 @@
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            if self.should_log:
-                logger.info(
-                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                    f"because this is a multimodal model."
-                )
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size for multimodal model."
-                )
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
+            )
             server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
 
         if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            if self.should_log:
-                logger.info("Disable chunked prefix cache when page size > 1.")
+            logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-            if self.should_log:
-                logger.info("Chunked prefix cache is turned on.")
+            logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -446,10 +455,9 @@ class ModelRunner:
             torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-                if self.should_log:
-                    logger.info(
-                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                    )
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -485,11 +493,10 @@
                     self.model.load_kv_cache_scales(
                         self.server_args.quantization_param_path
                     )
-                    if self.should_log:
-                        logger.info(
-                            "Loaded KV cache scaling factors from %s",
-                            self.server_args.quantization_param_path,
-                        )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
                 else:
                     raise RuntimeError(
                         "Using FP8 KV cache and scaling factors provided but "
@@ -1027,8 +1034,7 @@
         )
 
     def apply_torch_tp(self):
-        if self.should_log:
-            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
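One property of logger-level filters worth noting: a filter registered with logger.addFilter is consulted only for records created through that exact logger, not for records from other loggers (including children whose output merely propagates through it). So this change rank-filters only the INFO lines emitted by model_runner.py's own logger; other modules are unaffected unless they add a filter of their own (or one is placed on a handler). A small standalone check of that stdlib behavior, with hypothetical logger names:

import logging


class DropInfo(logging.Filter):
    def filter(self, record):
        # Drop INFO records on this logger only.
        return record.levelno != logging.INFO


logging.basicConfig(level=logging.INFO, format="%(name)s %(levelname)s %(message)s")

parent = logging.getLogger("pkg")        # hypothetical parent logger
child = logging.getLogger("pkg.child")   # hypothetical child logger
parent.addFilter(DropInfo())

parent.info("filtered out")   # dropped: the filter sits on the logger that created the record
child.info("still printed")   # propagates to the root handler; parent's filter is never consulted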