Unverified commit 7eca8591, authored by fenypatel99 and committed by GitHub
Browse files

Add PyTorch profiler schedule support with warmup/active iterations (#35240)

parent 636ee223
...@@ -45,8 +45,10 @@ class ProfilerConfig: ...@@ -45,8 +45,10 @@ class ProfilerConfig:
worker's traces (CPU & GPU) will be saved under this directory. Note that worker's traces (CPU & GPU) will be saved under this directory. Note that
it must be an absolute path.""" it must be an absolute path."""
torch_profiler_with_stack: bool = True torch_profiler_with_stack: bool = False
"""If `True`, enables stack tracing in the torch profiler. Enabled by default.""" """If `True`, enables stack tracing in the torch profiler. Disabled by default
to reduce overhead. Can be enabled via VLLM_TORCH_PROFILER_WITH_STACK=1 env var
or --profiler-config.torch_profiler_with_stack=true CLI flag."""
torch_profiler_with_flops: bool = False torch_profiler_with_flops: bool = False
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default.""" """If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
...@@ -81,6 +83,27 @@ class ProfilerConfig: ...@@ -81,6 +83,27 @@ class ProfilerConfig:
Defaults to 0, meaning no limit. Defaults to 0, meaning no limit.
""" """
warmup_iterations: int = Field(default=0, ge=0)
"""Number of warmup iterations for PyTorch profiler schedule.
During warmup, the profiler runs but data is discarded. This helps reduce
noise from JIT compilation and other one-time costs in the profiled trace.
Defaults to 0. Schedule-based profiling is enabled only when this or
`wait_iterations` is positive (e.g., 2); otherwise all iterations are recorded.
"""
active_iterations: int = Field(default=5, ge=1)
"""Number of active iterations for PyTorch profiler schedule.
This is the number of iterations where profiling data is actually collected.
Defaults to 5 active iterations.
"""
wait_iterations: int = Field(default=0, ge=0)
"""Number of wait iterations for PyTorch profiler schedule.
During wait, the profiler is completely off with zero overhead.
This allows skipping initial iterations before warmup begins.
Defaults to 0 (no wait period).
"""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
......
...@@ -96,7 +96,9 @@ class WorkerProfiler(ABC): ...@@ -96,7 +96,9 @@ class WorkerProfiler(ABC):
logger.info_once("Starting profiler after delay...", scope="local") logger.info_once("Starting profiler after delay...", scope="local")
self._call_start() self._call_start()
if self._running: # Call profiler step for schedule-based profiling
# Only count iterations where data is actually recorded (not warmup)
if self._running and self._profiler_step():
self._profiling_for_iters += 1 self._profiling_for_iters += 1
if ( if (
...@@ -113,6 +115,16 @@ class WorkerProfiler(ABC): ...@@ -113,6 +115,16 @@ class WorkerProfiler(ABC):
self._call_stop() self._call_stop()
return return
def _profiler_step(self) -> bool:
    """Per-step hook invoked while the profiler is running.

    The base implementation has no schedule to advance, so it reports
    every step as an active one. Subclasses that implement
    schedule-based profiling override this hook.

    Returns:
        True when the step recorded profiling data and should count
        toward the profiled-iteration total; False when the step was a
        warmup step whose data is discarded.
    """
    # No schedule at this level: every running step counts as active.
    return True
def stop(self) -> None: def stop(self) -> None:
"""Attempt to stop the profiler, accounting for overlapped calls.""" """Attempt to stop the profiler, accounting for overlapped calls."""
if not self._active: if not self._active:
...@@ -187,8 +199,29 @@ class TorchProfilerWrapper(WorkerProfiler): ...@@ -187,8 +199,29 @@ class TorchProfilerWrapper(WorkerProfiler):
) )
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1 self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
# Create profiler schedule if warmup or wait iterations are configured
profiler_schedule = None
if profiler_config.warmup_iterations > 0 or profiler_config.wait_iterations > 0:
profiler_schedule = torch.profiler.schedule(
skip_first=0,
wait=profiler_config.wait_iterations,
warmup=profiler_config.warmup_iterations,
active=profiler_config.active_iterations,
repeat=1,
)
if local_rank in (None, 0):
logger.info_once(
"Profiler schedule configured: wait=%d, warmup=%d, active=%d",
profiler_config.wait_iterations,
profiler_config.warmup_iterations,
profiler_config.active_iterations,
scope="local",
)
self.profiler = torch.profiler.profile( self.profiler = torch.profiler.profile(
activities=[TorchProfilerActivityMap[activity] for activity in activities], activities=[TorchProfilerActivityMap[activity] for activity in activities],
schedule=profiler_schedule,
record_shapes=profiler_config.torch_profiler_record_shapes, record_shapes=profiler_config.torch_profiler_record_shapes,
profile_memory=profiler_config.torch_profiler_with_memory, profile_memory=profiler_config.torch_profiler_with_memory,
with_stack=profiler_config.torch_profiler_with_stack, with_stack=profiler_config.torch_profiler_with_stack,
...@@ -196,6 +229,17 @@ class TorchProfilerWrapper(WorkerProfiler): ...@@ -196,6 +229,17 @@ class TorchProfilerWrapper(WorkerProfiler):
on_trace_ready=trace_handler, on_trace_ready=trace_handler,
) )
# Track if we're using a schedule (need to call step())
self._uses_schedule = profiler_schedule is not None
self._warmup_iterations = profiler_config.warmup_iterations
# Subtract 1 because profiler.start() already consumes step 0
# (WAIT or WARMUP), so only wait + warmup - 1 non-active steps
# remain to be advanced through via profiler.step() calls.
self._warmup_steps_remaining = max(
profiler_config.wait_iterations + profiler_config.warmup_iterations - 1,
0,
)
@override @override
def _start(self) -> None: def _start(self) -> None:
self.profiler.start() self.profiler.start()
...@@ -228,6 +272,22 @@ class TorchProfilerWrapper(WorkerProfiler): ...@@ -228,6 +272,22 @@ class TorchProfilerWrapper(WorkerProfiler):
) )
) )
@override
def _profiler_step(self) -> bool:
    """Advance the torch profiler schedule by one step, if one is in use.

    Without a schedule nothing is advanced and every step is reported
    as active. With a schedule, ``self.profiler.step()`` is called on
    every invocation, and the remaining non-active (wait/warmup) steps
    are counted down so that only active steps are reported to the
    caller.

    Returns:
        True if the step was an active profiling step (data recorded),
        False if the step was a non-active step (data discarded).
    """
    # Schedule-less profiling records everything; nothing to advance.
    if not self._uses_schedule:
        return True

    self.profiler.step()

    # Once the non-active budget is used up, every further step is active.
    if self._warmup_steps_remaining <= 0:
        return True

    # Still in the non-active phase: consume one step and report it as such.
    self._warmup_steps_remaining -= 1
    return False
@override @override
def annotate_context_manager(self, name: str): def annotate_context_manager(self, name: str):
return torch.profiler.record_function(name) return torch.profiler.record_function(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment