# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from abc import ABC, abstractmethod
5
from collections.abc import Callable
6
from contextlib import nullcontext
7
from typing import Literal
8
9
10
11

import torch
from typing_extensions import override

12
from vllm.config import ProfilerConfig
13
from vllm.config.profiler import _is_uri_path
14
15
16
17
18
from vllm.logger import init_logger

logger = init_logger(__name__)


19
class WorkerProfiler(ABC):
    """Base class that drives a profiler's lifecycle from worker events.

    Two flags are tracked independently:

    * ``_active``  -- ``start()`` was received and ``stop()`` has not been
      received yet.
    * ``_running`` -- ``_start()`` succeeded and ``_call_stop()`` has not
      run since.

    The profiler can be active without running: while a delayed start is
    still pending, after ``_start()`` raised, or after ``step()`` stopped it
    automatically because the maximum iteration count was exceeded.

    Subclasses must implement ``_start`` and ``_stop``; they may override
    ``_profiler_step`` and ``annotate_context_manager``.
    """

    def __init__(self, profiler_config: ProfilerConfig) -> None:
        """Read the iteration limits from the config and reset all state.

        Args:
            profiler_config: Supplies ``delay_iterations`` (steps to wait
                after ``start()`` before starting) and ``max_iterations``
                (limit after which ``step()`` stops the profiler; values
                ``<= 0`` disable the limit).
        """
        self._delay_iters = profiler_config.delay_iterations
        if self._delay_iters > 0:
            logger.info_once(
                "GPU profiling will start "
                f"{self._delay_iters} steps after start_profile."
            )

        self._max_iters = profiler_config.max_iterations
        if self._max_iters > 0:
            logger.info_once(
                "GPU profiling will stop "
                f"after {self._max_iters} worker steps, "
                "or when stop_profile is received."
            )

        # Track when the profiler gets triggered by start_profile
        self._active_iteration_count = 0
        self._active = False

        # Track when the profiler is actually running
        self._profiling_for_iters = 0
        self._running = False

    @abstractmethod
    def _start(self) -> None:
        """Start the profiler."""

    @abstractmethod
    def _stop(self) -> None:
        """Stop the profiler."""

    def _call_start(self) -> None:
        """Call _start with error handling but no safeguards.

        Any exception raised by ``_start`` is logged as a warning and
        swallowed; in that case ``_running`` is left unchanged.
        """
        try:
            self._start()
            self._running = True  # Only mark as running if start succeeds
        except Exception as e:
            logger.warning("Failed to start profiler: %s", e)

    def _call_stop(self) -> None:
        """Call _stop with error handling but no safeguards.

        Any exception raised by ``_stop`` is logged as a warning and
        swallowed. ``_running`` is cleared in every case.
        """
        try:
            self._stop()
            logger.info_once("Profiler stopped successfully.")
        except Exception as e:
            logger.warning("Failed to stop profiler: %s", e)
        self._running = False  # Always mark as not running, assume stop worked

    def start(self) -> None:
        """Attempt to start the profiler, accounting for delayed starts.

        A request received while already active is ignored. With no delay
        configured the profiler is started immediately; otherwise the start
        is deferred to ``step()``.
        """
        if self._active:
            logger.debug(
                "start_profile received when profiler is already active. "
                "Ignoring request."
            )
            return
        self._active = True
        if self._delay_iters == 0:
            self._call_start()

    def step(self) -> None:
        """Update the profiler state at each worker step,
        to handle delayed starts and max iteration limits."""
        if not self._active:
            return

        self._active_iteration_count += 1

        # Delayed start: fires only on the step whose count equals the delay.
        if (
            not self._running
            and self._delay_iters > 0
            and self._active_iteration_count == self._delay_iters
        ):
            logger.info_once("Starting profiler after delay...")
            self._call_start()

        # Call profiler step for schedule-based profiling
        # Only count iterations where data is actually recorded (not warmup)
        if self._running and self._profiler_step():
            self._profiling_for_iters += 1

        # The comparison is strict, so the automatic stop happens on the
        # first counted step after ``_max_iters`` has been reached.
        if (
            self._max_iters > 0
            and self._running
            and self._profiling_for_iters > self._max_iters
        ):
            # Automatically stop the profiler after max iters
            # will be marked as not running, but leave as active so that stop
            # can clean up properly
            logger.info_once("Max profiling iterations reached. Stopping profiler...")
            self._call_stop()

    def _profiler_step(self) -> bool:
        """Called each step when profiler is running.
        Override in subclasses to handle schedule-based profiling.

        Returns:
            True if the step was an active profiling step (data recorded),
            False if the step was a warmup step (data discarded).
        """
        return True

    def stop(self) -> None:
        """Attempt to stop the profiler, accounting for overlapped calls.

        A request received while not active is ignored. Otherwise the
        iteration counters are reset and, if the profiler is still running,
        it is stopped.
        """
        if not self._active:
            logger.debug(
                "stop_profile received when profiler is not active. Ignoring request."
            )
            return
        self._active = False
        self._active_iteration_count = 0
        self._profiling_for_iters = 0

        if self._running:
            self._call_stop()

    def shutdown(self) -> None:
        """Ensure profiler is stopped when shutting down."""
        logger.info_once("Shutting down profiler")
        if self._running:
            self.stop()

    def annotate_context_manager(self, name: str):
        """Return a context manager to annotate profiler traces.

        The base implementation ignores ``name`` and returns a no-op
        ``nullcontext``.
        """
        return nullcontext()


151
152
153
154
155
156
157
158
TorchProfilerActivity = Literal["CPU", "CUDA", "XPU"]
TorchProfilerActivityMap = {
    "CPU": torch.profiler.ProfilerActivity.CPU,
    "CUDA": torch.profiler.ProfilerActivity.CUDA,
    "XPU": torch.profiler.ProfilerActivity.XPU,
}


159
class TorchProfilerWrapper(WorkerProfiler):
    """``WorkerProfiler`` implementation backed by ``torch.profiler.profile``.

    On stop it can additionally dump ``key_averages()`` tables to a
    per-rank text file in the configured profiler directory and print them
    on rank 0.
    """

    def __init__(
        self,
        profiler_config: ProfilerConfig,
        worker_name: str,
        local_rank: int,
        activities: list[TorchProfilerActivity],
        on_trace_ready: Callable[[torch.profiler.profile], None] | None = None,
    ) -> None:
        """Build the underlying ``torch.profiler.profile`` object.

        Args:
            profiler_config: Profiler settings (trace directory, recording
                options, wait/warmup/active iteration counts, ...).
            worker_name: Passed to the default tensorboard trace handler.
            local_rank: Rank used for the output file name and to decide
                which process logs and prints.
            activities: Activity names, mapped through
                ``TorchProfilerActivityMap``.
            on_trace_ready: Optional custom trace handler. When ``None``,
                ``torch.profiler.tensorboard_trace_handler`` writing to the
                configured directory is used.
        """
        super().__init__(profiler_config)

        self.local_rank = local_rank
        self.profiler_config = profiler_config
        torch_profiler_trace_dir = profiler_config.torch_profiler_dir
        if local_rank in (None, 0):
            logger.info_once(
                "Torch profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                profiler_config.torch_profiler_record_shapes,
                profiler_config.torch_profiler_with_memory,
                profiler_config.torch_profiler_with_stack,
                profiler_config.torch_profiler_with_flops,
            )

        # Determine trace handler: use custom handler if provided,
        # otherwise default to tensorboard trace handler
        if on_trace_ready is not None:
            trace_handler = on_trace_ready
        else:
            trace_handler = torch.profiler.tensorboard_trace_handler(
                torch_profiler_trace_dir,
                worker_name=worker_name,
                use_gzip=profiler_config.torch_profiler_use_gzip,
            )

        # The CPU-time table is only dumped for CPU-only profiling.
        self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1

        # Create profiler schedule if warmup or wait iterations are configured
        profiler_schedule = None
        if profiler_config.warmup_iterations > 0 or profiler_config.wait_iterations > 0:
            profiler_schedule = torch.profiler.schedule(
                skip_first=0,
                wait=profiler_config.wait_iterations,
                warmup=profiler_config.warmup_iterations,
                active=profiler_config.active_iterations,
                repeat=1,
            )
            if local_rank in (None, 0):
                logger.info_once(
                    "Profiler schedule configured: wait=%d, warmup=%d, active=%d",
                    profiler_config.wait_iterations,
                    profiler_config.warmup_iterations,
                    profiler_config.active_iterations,
                )

        self.profiler = torch.profiler.profile(
            activities=[TorchProfilerActivityMap[activity] for activity in activities],
            schedule=profiler_schedule,
            record_shapes=profiler_config.torch_profiler_record_shapes,
            profile_memory=profiler_config.torch_profiler_with_memory,
            with_stack=profiler_config.torch_profiler_with_stack,
            with_flops=profiler_config.torch_profiler_with_flops,
            on_trace_ready=trace_handler,
        )

        # Track if we're using a schedule (need to call step())
        self._uses_schedule = profiler_schedule is not None
        self._warmup_iterations = profiler_config.warmup_iterations
        # Subtract 1 because profiler.start() already consumes step 0
        # (WAIT or WARMUP), so only wait + warmup - 1 non-active steps
        # remain to be advanced through via profiler.step() calls.
        self._warmup_steps_remaining = max(
            profiler_config.wait_iterations + profiler_config.warmup_iterations - 1,
            0,
        )

    def _build_profiler_table(
        self,
        sort_key: str,
        row_limit: int | None = None,
    ) -> str:
        """Render the profiler's ``key_averages()`` as a text table.

        Args:
            sort_key: Column passed as ``sort_by``.
            row_limit: Maximum number of rows; when ``None`` the argument is
                omitted so the profiler's own default applies.
        """
        if row_limit is None:  # use profiler default row limit of 100
            return self.profiler.key_averages().table(sort_by=sort_key)
        return self.profiler.key_averages().table(
            sort_by=sort_key,
            row_limit=row_limit,
        )

    def _write_profiler_table(self, rank: int, table: str) -> None:
        """Write ``table`` to ``profiler_out_<rank>.txt`` in the profiler dir.

        Nothing is written when the directory is a URI path.
        """
        profiler_dir = self.profiler_config.torch_profiler_dir

        # Skip file write for URI paths (gs://, s3://, etc.)
        # as standard file I/O doesn't work with URI schemes
        if not _is_uri_path(profiler_dir):
            profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
            # Explicit UTF-8 so the output does not depend on the locale's
            # default encoding.
            with open(profiler_out_file, "w", encoding="utf-8") as f:
                print(table, file=f)

    def _dump_profiler_table(self, table: str) -> None:
        """Write ``table`` to this rank's file, then print it on rank 0."""
        rank = self.local_rank
        self._write_profiler_table(rank, table)

        # only print profiler results on rank 0
        if rank == 0:
            print(table)

    @override
    def _start(self) -> None:
        """Start the underlying torch profiler."""
        self.profiler.start()

    @override
    def _stop(self) -> None:
        """Stop the torch profiler and dump the configured summary tables.

        The CUDA-time table is dumped when
        ``torch_profiler_dump_cuda_time_total`` is set; the CPU-time table
        (limited to 50 rows) is dumped when ``dump_cpu_time_total`` is set.
        """
        self.profiler.stop()

        if self.profiler_config.torch_profiler_dump_cuda_time_total:
            self._dump_profiler_table(
                self._build_profiler_table(sort_key="self_cuda_time_total")
            )

        if self.dump_cpu_time_total:
            self._dump_profiler_table(
                self._build_profiler_table(
                    sort_key="self_cpu_time_total", row_limit=50
                )
            )

    @override
    def _profiler_step(self) -> bool:
        """Call profiler.step() when using schedule-based profiling.

        Returns:
            True if the step was an active profiling step (data recorded),
            False if the step was a warmup step (data discarded).
        """
        if self._uses_schedule:
            self.profiler.step()
            # Track warmup steps - only count active steps toward max_iterations
            if self._warmup_steps_remaining > 0:
                self._warmup_steps_remaining -= 1
                return False
        return True

    @override
    def annotate_context_manager(self, name: str):
        """Return a ``torch.profiler.record_function`` context for ``name``."""
        return torch.profiler.record_function(name)


class CudaProfilerWrapper(WorkerProfiler):
    """``WorkerProfiler`` implementation that delegates to
    ``torch.cuda.profiler`` start/stop and annotates with NVTX ranges."""

    def __init__(self, profiler_config: ProfilerConfig) -> None:
        """Initialise base state and bind the CUDA profiler module.

        Args:
            profiler_config: Forwarded unchanged to ``WorkerProfiler``.
        """
        super().__init__(profiler_config)
        # Note: lazy import to avoid dependency issues if CUDA is not available.
        import torch.cuda.profiler as cuda_profiler

        self._cuda_profiler = cuda_profiler

    @override
    def _start(self) -> None:
        """Start the CUDA profiler."""
        self._cuda_profiler.start()

    @override
    def _stop(self) -> None:
        """Stop the CUDA profiler."""
        self._cuda_profiler.stop()

    @override
    def annotate_context_manager(self, name: str):
        """Return a ``torch.cuda.nvtx.range`` context manager for ``name``."""
        return torch.cuda.nvtx.range(name)