# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from abc import ABC, abstractmethod
from contextlib import nullcontext
6
from typing import Literal
7
8
9
10

import torch
from typing_extensions import override

11
from vllm.config import ProfilerConfig
12
13
14
15
16
from vllm.logger import init_logger

logger = init_logger(__name__)


17
class WorkerProfiler(ABC):
    """Base class for per-worker profilers with delayed start and an optional
    cap on the number of profiled steps.

    State machine:
      * ``start()`` marks the profiler *active*. With no positive delay the
        backend is started immediately; otherwise ``step()`` starts it on the
        step where the active-step count reaches ``delay_iterations``.
      * ``step()`` is expected once per worker step. While the backend is
        running it counts profiled steps and, if ``max_iterations`` > 0,
        stops the backend once that count exceeds ``max_iterations``.
      * ``stop()`` clears the active state, resets counters, and stops the
        backend if it is still running.

    Subclasses provide the backend via ``_start``/``_stop`` and may override
    ``annotate_context_manager``.
    """

    def __init__(self, profiler_config: ProfilerConfig) -> None:
        # Active steps to wait before starting the backend; values <= 0 mean
        # "start immediately" (see start()).
        self._delay_iters = profiler_config.delay_iterations
        if self._delay_iters > 0:
            logger.info_once(
                "GPU profiling will start "
                f"{self._delay_iters} steps after start_profile."
            )

        # Maximum number of profiled steps; values <= 0 disable the cap.
        self._max_iters = profiler_config.max_iterations
        if self._max_iters > 0:
            logger.info_once(
                "GPU profiling will stop "
                f"after {self._max_iters} worker steps, "
                "or when stop_profile is received."
            )

        # Track when the profiler gets triggered by start_profile
        self._active_iteration_count = 0
        self._active = False

        # Track when the profiler is actually running
        self._profiling_for_iters = 0
        self._running = False

    @abstractmethod
    def _start(self) -> None:
        """Start the profiler backend."""

    @abstractmethod
    def _stop(self) -> None:
        """Stop the profiler backend."""

    def _call_start(self) -> None:
        """Call ``_start`` with error handling but no state safeguards.

        ``_running`` is set only when ``_start`` returns without raising.
        Failures are logged as a warning and swallowed (best-effort).
        """
        try:
            self._start()
            self._running = True  # Only mark as running if start succeeds
        except Exception as e:
            logger.warning("Failed to start profiler: %s", e)

    def _call_stop(self) -> None:
        """Call ``_stop`` with error handling but no state safeguards.

        Failures are logged as a warning and swallowed. ``_running`` is
        cleared unconditionally afterwards.
        """
        try:
            self._stop()
            logger.info_once("Profiler stopped successfully.", scope="local")
        except Exception as e:
            logger.warning("Failed to stop profiler: %s", e)
        self._running = False  # Always mark as not running, assume stop worked

    def start(self) -> None:
        """Activate the profiler, starting the backend now unless delayed.

        Repeated calls while already active are ignored (debug-logged).
        """
        if self._active:
            logger.debug(
                "start_profile received when profiler is already active. "
                "Ignoring request."
            )
            return
        self._active = True
        # Treat any non-positive delay as "no delay", matching the `> 0`
        # checks in __init__ and step(); otherwise a negative delay would
        # never start the backend at all.
        if self._delay_iters <= 0:
            self._call_start()

    def step(self) -> None:
        """Advance the profiler state by one worker step.

        No-op while inactive. Handles the delayed start and the
        ``max_iterations`` auto-stop.
        """
        if not self._active:
            return

        self._active_iteration_count += 1

        if (
            not self._running
            and self._delay_iters > 0
            and self._active_iteration_count == self._delay_iters
        ):
            logger.info_once("Starting profiler after delay...", scope="local")
            self._call_start()

        if self._running:
            self._profiling_for_iters += 1

        if (
            self._max_iters > 0
            and self._running
            and self._profiling_for_iters > self._max_iters
        ):
            # Automatically stop the profiler after max iters. It is marked
            # as not running but left active, so a later stop() still resets
            # the counters and active flag.
            logger.info_once(
                "Max profiling iterations reached. Stopping profiler...", scope="local"
            )
            self._call_stop()

    def stop(self) -> None:
        """Deactivate the profiler, reset counters, and stop a running backend.

        Ignored (debug-logged) when the profiler is not active.
        """
        if not self._active:
            logger.debug(
                "stop_profile received when profiler is not active. Ignoring request."
            )
            return
        self._active = False
        self._active_iteration_count = 0
        self._profiling_for_iters = 0

        if self._running:
            self._call_stop()

    def shutdown(self) -> None:
        """Ensure the profiler backend is stopped when shutting down."""
        logger.info_once("Shutting down profiler", scope="local")
        if self._running:
            self.stop()

    def annotate_context_manager(self, name: str):
        """Return a context manager to annotate profiler traces.

        The base implementation is a no-op ``nullcontext``.
        """
        return nullcontext()


139
140
141
142
143
144
145
146
TorchProfilerActivity = Literal["CPU", "CUDA", "XPU"]
TorchProfilerActivityMap = {
    "CPU": torch.profiler.ProfilerActivity.CPU,
    "CUDA": torch.profiler.ProfilerActivity.CUDA,
    "XPU": torch.profiler.ProfilerActivity.XPU,
}


147
class TorchProfilerWrapper(WorkerProfiler):
    """WorkerProfiler backed by ``torch.profiler.profile``.

    Traces are handed to ``torch.profiler.tensorboard_trace_handler`` for
    ``profiler_config.torch_profiler_dir``. On stop, optional summary tables
    are written/printed (CUDA-time table) or logged (CPU-time table).
    """

    def __init__(
        self,
        profiler_config: ProfilerConfig,
        worker_name: str,
        local_rank: int,
        activities: list[TorchProfilerActivity],
    ) -> None:
        """Build the underlying torch profiler.

        Args:
            profiler_config: Profiler settings (trace dir, record_shapes,
                memory/stack/flops flags, gzip, summary dumping).
            worker_name: Passed to ``tensorboard_trace_handler`` as the
                trace's worker name.
            local_rank: Gates rank-0-only logging/printing and names the
                CUDA-time summary file.
            activities: Activity names, mapped via
                ``TorchProfilerActivityMap``.
        """
        super().__init__(profiler_config)

        self.local_rank = local_rank
        self.profiler_config = profiler_config
        torch_profiler_trace_dir = profiler_config.torch_profiler_dir
        if local_rank in (None, 0):
            logger.info_once(
                "Torch profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
                scope="local",
            )
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                profiler_config.torch_profiler_record_shapes,
                profiler_config.torch_profiler_with_memory,
                profiler_config.torch_profiler_with_stack,
                profiler_config.torch_profiler_with_flops,
            )

        # Log a CPU-time summary on stop only when CPU is the sole activity.
        self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
        self.profiler = torch.profiler.profile(
            activities=[TorchProfilerActivityMap[activity] for activity in activities],
            record_shapes=profiler_config.torch_profiler_record_shapes,
            profile_memory=profiler_config.torch_profiler_with_memory,
            with_stack=profiler_config.torch_profiler_with_stack,
            with_flops=profiler_config.torch_profiler_with_flops,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                torch_profiler_trace_dir,
                worker_name=worker_name,
                use_gzip=profiler_config.torch_profiler_use_gzip,
            ),
        )

    @override
    def _start(self) -> None:
        self.profiler.start()

    @override
    def _stop(self) -> None:
        """Stop the torch profiler, then emit any requested summaries."""
        self.profiler.stop()

        profiler_config = self.profiler_config
        rank = self.local_rank
        dump_cuda = profiler_config.torch_profiler_dump_cuda_time_total
        dump_cpu = self.dump_cpu_time_total and rank == 0
        if not (dump_cuda or dump_cpu):
            return

        # key_averages() aggregates all recorded events; compute it once and
        # reuse it for both summaries.
        averages = self.profiler.key_averages()

        if dump_cuda:
            profiler_dir = profiler_config.torch_profiler_dir
            profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
            table = averages.table(sort_by="self_cuda_time_total")
            with open(profiler_out_file, "w") as f:
                print(table, file=f)

            # only print profiler results on rank 0
            if rank == 0:
                print(table)

        if dump_cpu:
            # Pass the table as an argument rather than as the format string
            # so '%' characters in it are never treated as directives.
            logger.info(
                "%s",
                averages.table(sort_by="self_cpu_time_total", row_limit=50),
            )

    @override
    def annotate_context_manager(self, name: str):
        """Return a ``torch.profiler.record_function`` range named *name*."""
        return torch.profiler.record_function(name)


class CudaProfilerWrapper(WorkerProfiler):
    """WorkerProfiler that toggles ``torch.cuda.profiler`` and marks regions
    with NVTX ranges.

    NOTE(review): nothing is recorded by this class itself; presumably an
    external tool consumes the start/stop signals — confirm with deployment.
    """

    def __init__(self, profiler_config: ProfilerConfig) -> None:
        super().__init__(profiler_config)
        # Deferred import: the CUDA profiler module is only loaded when this
        # wrapper is actually constructed.
        from torch.cuda import profiler as cuda_profiler_module

        self._cuda_profiler = cuda_profiler_module

    @override
    def _start(self) -> None:
        profiler_module = self._cuda_profiler
        profiler_module.start()

    @override
    def _stop(self) -> None:
        profiler_module = self._cuda_profiler
        profiler_module.stop()

    @override
    def annotate_context_manager(self, name: str):
        """Return an NVTX range context manager labelled *name*."""
        nvtx = torch.cuda.nvtx
        return nvtx.range(name)