wrapper.py 8.55 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from abc import ABC, abstractmethod
5
from collections.abc import Callable
6
from contextlib import nullcontext
7
from typing import Literal
8
9
10
11

import torch
from typing_extensions import override

12
from vllm.config import ProfilerConfig
13
from vllm.config.profiler import _is_uri_path
14
15
16
17
18
from vllm.logger import init_logger

# Shared logger for every profiler wrapper defined in this module.
logger = init_logger(__name__)


19
class WorkerProfiler(ABC):
    """Base class for worker-side profilers.

    Implements the start/stop bookkeeping shared by every backend: an optional
    delay (``delay_iterations``) between ``start()`` and the real backend
    start, and an optional cap (``max_iterations``) on how many steps are
    profiled before the backend is stopped automatically. Subclasses provide
    ``_start`` / ``_stop`` and may override ``annotate_context_manager``.
    """

    def __init__(self, profiler_config: ProfilerConfig) -> None:
        # Number of step() calls to wait after start() before profiling;
        # 0 means the backend is started directly by start().
        delay = profiler_config.delay_iterations
        self._delay_iters = delay
        if delay > 0:
            logger.info_once(
                "GPU profiling will start "
                f"{delay} steps after start_profile."
            )

        # Upper bound on profiled steps; a non-positive value disables the cap.
        cap = profiler_config.max_iterations
        self._max_iters = cap
        if cap > 0:
            logger.info_once(
                "GPU profiling will stop "
                f"after {cap} worker steps, "
                "or when stop_profile is received."
            )

        # "Active": start() was requested and stop() has not been called since.
        self._active_iteration_count = 0
        self._active = False

        # "Running": the backend profiler has actually been started.
        self._profiling_for_iters = 0
        self._running = False

    @abstractmethod
    def _start(self) -> None:
        """Start the backend profiler."""

    @abstractmethod
    def _stop(self) -> None:
        """Stop the backend profiler."""

    def _call_start(self) -> None:
        """Invoke ``_start``; mark as running only if it did not raise.

        Exceptions are logged as warnings rather than propagated. No state
        checks are performed here; callers are responsible for those.
        """
        try:
            self._start()
        except Exception as err:
            logger.warning("Failed to start profiler: %s", err)
        else:
            self._running = True

    def _call_stop(self) -> None:
        """Invoke ``_stop``, logging (not raising) any exception.

        No state checks are performed here; callers are responsible for those.
        """
        try:
            self._stop()
            logger.info_once("Profiler stopped successfully.", scope="local")
        except Exception as err:
            logger.warning("Failed to stop profiler: %s", err)
        # Deliberately unconditional: a failed stop is still treated as stopped.
        self._running = False

    def start(self) -> None:
        """Handle a start_profile request.

        Arms the profiler and starts the backend right away when no delay is
        configured; otherwise ``step()`` starts it once the delay has elapsed.
        Requests received while already active are ignored.
        """
        if self._active:
            logger.debug(
                "start_profile received when profiler is already active. "
                "Ignoring request."
            )
            return

        self._active = True
        if self._delay_iters == 0:
            self._call_start()

    def step(self) -> None:
        """Advance per-step bookkeeping; meant to be called at each worker step.

        Starts the backend once the configured delay has elapsed and stops it
        once more than ``max_iterations`` steps have been counted while
        running. After such an automatic stop the profiler stays active, so a
        later ``stop()`` still resets the counters.
        """
        if not self._active:
            return

        self._active_iteration_count += 1
        delay_elapsed = (
            self._delay_iters > 0
            and self._active_iteration_count == self._delay_iters
        )
        if delay_elapsed and not self._running:
            logger.info_once("Starting profiler after delay...", scope="local")
            self._call_start()

        if not self._running:
            return

        self._profiling_for_iters += 1
        if 0 < self._max_iters < self._profiling_for_iters:
            logger.info_once(
                "Max profiling iterations reached. Stopping profiler...", scope="local"
            )
            self._call_stop()

    def stop(self) -> None:
        """Handle a stop_profile request.

        Disarms the profiler, clears both step counters, and stops the backend
        if it is running. Requests received while inactive are ignored.
        """
        if not self._active:
            logger.debug(
                "stop_profile received when profiler is not active. Ignoring request."
            )
            return

        self._active = False
        self._active_iteration_count = self._profiling_for_iters = 0

        if self._running:
            self._call_stop()

    def shutdown(self) -> None:
        """Stop a still-running backend profiler as part of shutdown."""
        logger.info_once("Shutting down profiler", scope="local")
        if not self._running:
            return
        self.stop()

    def annotate_context_manager(self, name: str):
        """Return a context manager labelling a trace region; no-op by default."""
        return nullcontext()


141
142
143
144
145
146
147
148
TorchProfilerActivity = Literal["CPU", "CUDA", "XPU"]
TorchProfilerActivityMap = {
    "CPU": torch.profiler.ProfilerActivity.CPU,
    "CUDA": torch.profiler.ProfilerActivity.CUDA,
    "XPU": torch.profiler.ProfilerActivity.XPU,
}


149
class TorchProfilerWrapper(WorkerProfiler):
    """Worker profiler backed by ``torch.profiler.profile``.

    Traces go to ``on_trace_ready`` when supplied, otherwise to a TensorBoard
    trace handler writing into ``profiler_config.torch_profiler_dir``. When the
    backend is stopped, optional summary tables (sorted by self CUDA time or
    self CPU time) are written to disk and/or emitted on the primary rank.
    """

    def __init__(
        self,
        profiler_config: ProfilerConfig,
        worker_name: str,
        local_rank: int,
        activities: list[TorchProfilerActivity],
        on_trace_ready: Callable[[torch.profiler.profile], None] | None = None,
    ) -> None:
        """Build the underlying torch profiler (it is not started here).

        Args:
            profiler_config: Source of the trace directory and torch profiler
                flags.
            worker_name: Worker name passed to the default TensorBoard trace
                handler.
            local_rank: Rank of this worker; ``None`` or ``0`` is treated as
                the primary rank for logging and summary printing.
            activities: Activity names to record; each must be a key of
                ``TorchProfilerActivityMap``.
            on_trace_ready: Optional trace callback that replaces the default
                TensorBoard handler.
        """
        super().__init__(profiler_config)

        self.local_rank = local_rank
        self.profiler_config = profiler_config
        torch_profiler_trace_dir = profiler_config.torch_profiler_dir
        if local_rank in (None, 0):
            logger.info_once(
                "Torch profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
                scope="local",
            )
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                profiler_config.torch_profiler_record_shapes,
                profiler_config.torch_profiler_with_memory,
                profiler_config.torch_profiler_with_stack,
                profiler_config.torch_profiler_with_flops,
            )

        # Determine trace handler: use custom handler if provided,
        # otherwise default to tensorboard trace handler
        if on_trace_ready is not None:
            trace_handler = on_trace_ready
        else:
            trace_handler = torch.profiler.tensorboard_trace_handler(
                torch_profiler_trace_dir,
                worker_name=worker_name,
                use_gzip=profiler_config.torch_profiler_use_gzip,
            )

        # A CPU-only run has no CUDA timings, so it gets a CPU-time table.
        self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
        self.profiler = torch.profiler.profile(
            activities=[TorchProfilerActivityMap[activity] for activity in activities],
            record_shapes=profiler_config.torch_profiler_record_shapes,
            profile_memory=profiler_config.torch_profiler_with_memory,
            with_stack=profiler_config.torch_profiler_with_stack,
            with_flops=profiler_config.torch_profiler_with_flops,
            on_trace_ready=trace_handler,
        )

    @override
    def _start(self) -> None:
        self.profiler.start()

    @override
    def _stop(self) -> None:
        """Stop the torch profiler and emit the configured summary tables.

        Writing the per-rank summary file is best-effort: a filesystem error
        is logged as a warning instead of propagating. It is therefore not
        reported as a failure to stop the profiler, and it does not suppress
        the console output that follows.
        """
        self.profiler.stop()

        profiler_config = self.profiler_config
        rank = self.local_rank
        # Same primary-rank rule as __init__, which also accepts a None rank.
        is_primary_rank = rank in (None, 0)

        if profiler_config.torch_profiler_dump_cuda_time_total:
            profiler_dir = profiler_config.torch_profiler_dir
            table = self.profiler.key_averages().table(sort_by="self_cuda_time_total")

            # Skip file write for URI paths (gs://, s3://, etc.)
            # as standard file I/O doesn't work with URI schemes
            if not _is_uri_path(profiler_dir):
                profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
                try:
                    with open(profiler_out_file, "w", encoding="utf-8") as f:
                        print(table, file=f)
                except OSError as e:
                    logger.warning(
                        "Failed to write profiler summary to %s: %s",
                        profiler_out_file,
                        e,
                    )

            # only print profiler results on the primary rank
            if is_primary_rank:
                print(table)

        if self.dump_cpu_time_total and is_primary_rank:
            # Pass the table as an argument, never as the format string itself.
            logger.info(
                "%s",
                self.profiler.key_averages().table(
                    sort_by="self_cpu_time_total", row_limit=50
                ),
            )

    @override
    def annotate_context_manager(self, name: str):
        """Label a trace region via ``torch.profiler.record_function``."""
        return torch.profiler.record_function(name)


class CudaProfilerWrapper(WorkerProfiler):
    """Worker profiler that toggles ``torch.cuda.profiler`` on and off.

    Trace regions are annotated with NVTX ranges.
    """

    def __init__(self, profiler_config: ProfilerConfig) -> None:
        super().__init__(profiler_config)
        # Imported here rather than at module level to avoid dependency
        # issues when CUDA is not available.
        import torch.cuda.profiler as cuda_profiler_module

        self._cuda_profiler = cuda_profiler_module

    @override
    def _start(self) -> None:
        """Begin CUDA profiler capture."""
        self._cuda_profiler.start()

    @override
    def _stop(self) -> None:
        """End CUDA profiler capture."""
        self._cuda_profiler.stop()

    @override
    def annotate_context_manager(self, name: str):
        """Label a trace region with an NVTX range called *name*."""
        return torch.cuda.nvtx.range(name)