Unverified Commit 0cdd2136 authored by 22quinn's avatar 22quinn Committed by GitHub
Browse files

[Misc] Improve Worker process title and logging prefix (#22205)


Signed-off-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
parent 948dd344
......@@ -3359,7 +3359,7 @@ def has_triton_kernels() -> bool:
def set_process_title(name: str,
suffix: str = "",
append: bool = False) -> None:
prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None:
"""
Set the current process title to a specific name with an
optional suffix.
......@@ -3367,15 +3367,11 @@ def set_process_title(name: str,
Args:
name: The title to assign to the current process.
suffix: An optional suffix to append to the base name.
append: Whether to append to the existing process title.
prefix: A prefix to prepend to the front separated by `::`.
"""
if suffix:
name = f"{name}_{suffix}"
if append:
name = f"{setproctitle.getproctitle()}_{name}"
else:
name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}"
setproctitle.setproctitle(name)
setproctitle.setproctitle(f"{prefix}::{name}")
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
......
......@@ -697,7 +697,7 @@ class EngineCoreProc(EngineCore):
parallel_config: ParallelConfig = kwargs[
"vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
set_process_title("DPEngineCore", str(dp_rank))
set_process_title("EngineCore", f"DP{dp_rank}")
decorate_logs()
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
......
......@@ -116,7 +116,7 @@ class CoreEngineProcManager:
local_dp_ranks.append(local_index)
self.processes.append(
context.Process(target=target_fn,
name=f"EngineCore_{global_index}",
name=f"EngineCore_DP{global_index}",
kwargs=common_kwargs | {
"dp_rank": global_index,
"local_dp_rank": local_index,
......
......@@ -26,6 +26,8 @@ from vllm.distributed import (destroy_distributed_environment,
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_pp_group, get_tp_group)
from vllm.executor.multiproc_worker_utils import (
set_multiprocessing_worker_envs)
from vllm.logger import init_logger
......@@ -397,17 +399,6 @@ class WorkerProc:
wrapper.init_worker(all_kwargs)
self.worker = wrapper
pp_size = vllm_config.parallel_config.pipeline_parallel_size
tp_size = vllm_config.parallel_config.tensor_parallel_size
pp_str = f"PP{rank // tp_size}" if pp_size > 1 else ""
tp_str = f"TP{rank % tp_size}" if tp_size > 1 else ""
suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}"
process_name = "VllmWorker"
if suffix:
set_process_title(suffix, append=True)
process_name = f"{process_name} {suffix}"
decorate_logs(process_name)
# Initialize MessageQueue for receiving SchedulerOutput
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank)
......@@ -425,8 +416,14 @@ class WorkerProc:
name="WorkerAsyncOutputCopy")
self.async_output_copy_thread.start()
# Initialize device and loads weights
# Initialize device
self.worker.init_device()
# Set process title and log prefix
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel)
# Load model
self.worker.load_model()
@staticmethod
......@@ -663,3 +660,24 @@ class WorkerProc:
if output_rank is None or self.rank == output_rank:
self.handle_output(output)
@staticmethod
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
dp_size = get_dp_group().world_size
dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size
pp_rank = get_pp_group().rank_in_group
tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group
process_name = "Worker"
if dp_size > 1:
process_name += f"_DP{dp_rank}"
if pp_size > 1:
process_name += f"_PP{pp_rank}"
if tp_size > 1:
process_name += f"_TP{tp_rank}"
if enable_ep:
ep_rank = get_ep_group().rank_in_group
process_name += f"_EP{ep_rank}"
set_process_title(name=process_name)
decorate_logs(process_name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment