# gpu_profiler.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
6
7
8
9
10
import os
from abc import ABC, abstractmethod
from contextlib import nullcontext

import torch
from typing_extensions import override

import vllm.envs as envs
11
12
13
14
15
from vllm.logger import init_logger

logger = init_logger(__name__)


16
class WorkerProfiler(ABC):
    """Base class for worker-side profilers.

    Implements the bookkeeping shared by all profiler backends:

    * ignoring duplicate ``start()`` / ``stop()`` requests,
    * optionally delaying the real start by ``VLLM_PROFILER_DELAY_ITERS``
      calls to ``step()``,
    * optionally stopping the backend from ``step()`` once it has been
      running for more than ``VLLM_PROFILER_MAX_ITERS`` steps.

    Two flags are tracked separately:

    * ``_active``: ``start()`` was called and not yet matched by ``stop()``.
    * ``_running``: the backend was successfully started via ``_start()`` and
      has not been stopped since.

    Subclasses implement ``_start`` and ``_stop`` and may override
    ``annotate_context_manager``.
    """

    def __init__(self) -> None:
        # Number of step() calls to wait after start() before starting the
        # backend. Non-positive values mean "start immediately" (see start()).
        self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS
        if self._delay_iters > 0:
            logger.info_once(
                "GPU profiling will start "
                f"{self._delay_iters} steps after start_profile."
            )

        # Step budget for a running profiler; non-positive values disable the
        # automatic stop in step().
        self._max_iters = envs.VLLM_PROFILER_MAX_ITERS
        if self._max_iters > 0:
            logger.info_once(
                "GPU profiling will stop "
                f"after {self._max_iters} worker steps, "
                "or when stop_profile is received."
            )

        # Track when the profiler gets triggered by start_profile
        self._active_iteration_count = 0
        self._active = False

        # Track when the profiler is actually running
        self._profiling_for_iters = 0
        self._running = False

    @abstractmethod
    def _start(self) -> None:
        """Start the backend profiler. May raise; ``_call_start`` logs it."""

    @abstractmethod
    def _stop(self) -> None:
        """Stop the backend profiler. May raise; ``_call_stop`` logs it."""

    def _call_start(self) -> None:
        """Call ``_start`` with error handling but no safeguards.

        Failures are logged as warnings, not raised. ``_running`` is set only
        when ``_start`` returns normally.
        """
        try:
            self._start()
        except Exception as e:
            logger.warning("Failed to start profiler: %s", e)
        else:
            self._running = True

    def _call_stop(self) -> None:
        """Call ``_stop`` with error handling but no safeguards.

        Failures are logged as warnings, not raised. ``_running`` is cleared
        unconditionally, on the assumption that the backend is stopped either
        way.
        """
        try:
            self._stop()
            logger.info("Profiler stopped successfully.")
        except Exception as e:
            logger.warning("Failed to stop profiler: %s", e)
        self._running = False  # Always mark as not running, assume stop worked

    def start(self) -> None:
        """Handle a start_profile request, accounting for delayed starts.

        Marks the profiler active. With no configured delay the backend is
        started right away; otherwise ``step()`` starts it later.
        """
        if self._active:
            logger.debug(
                "start_profile received when profiler is already active. "
                "Ignoring request."
            )
            return
        self._active = True
        # step() only performs delayed starts for positive delays, so any
        # non-positive delay must start here or the backend would never run.
        if self._delay_iters <= 0:
            self._call_start()

    def step(self) -> None:
        """Update the profiler state at each worker step.

        Handles delayed starts and the max-iteration limit. Does nothing
        unless ``start()`` has been called.
        """
        if not self._active:
            return

        self._active_iteration_count += 1

        if (
            not self._running
            and self._delay_iters > 0
            and self._active_iteration_count == self._delay_iters
        ):
            logger.info("Starting profiler after delay...")
            self._call_start()

        if self._running:
            self._profiling_for_iters += 1

        if (
            self._max_iters > 0
            and self._running
            and self._profiling_for_iters > self._max_iters
        ):
            # Automatically stop the backend. _active is intentionally left
            # True so a later stop() still resets the counters and state.
            logger.info("Max profiling iterations reached. Stopping profiler...")
            self._call_stop()

    def stop(self) -> None:
        """Handle a stop_profile request, accounting for overlapped calls.

        Resets the activity and iteration counters. The backend is stopped
        only if it is still running; it may already have been stopped by the
        max-iteration check in ``step()``.
        """
        if not self._active:
            logger.debug(
                "stop_profile received when profiler is not active. Ignoring request."
            )
            return
        self._active = False
        self._active_iteration_count = 0
        self._profiling_for_iters = 0

        if self._running:
            self._call_stop()

    def shutdown(self) -> None:
        """Ensure profiler is stopped when shutting down."""
        logger.info_once("Shutting down profiler")
        if self._running:
            self.stop()

    def annotate_context_manager(self, name: str):
        """Return a context manager to annotate profiler traces.

        The base implementation ignores ``name`` and returns a no-op
        ``nullcontext()``; backends override this.
        """
        return nullcontext()


class TorchProfilerWrapper(WorkerProfiler):
    """Worker profiler backed by ``torch.profiler`` (CPU and CUDA activities).

    Traces are handed to ``torch.profiler.tensorboard_trace_handler`` with
    ``use_gzip=True`` and saved under ``VLLM_TORCH_PROFILER_DIR``. When the
    profiler stops, a ``key_averages()`` summary table is also written to
    ``<dir>/profiler_out_<local_rank>.txt``. On local rank 0 the table is
    also printed to stdout.
    """

    def __init__(self, worker_name: str, local_rank: int) -> None:
        """Create the underlying ``torch.profiler.profile`` object.

        Args:
            worker_name: Passed as ``worker_name`` to the TensorBoard trace
                handler.
            local_rank: Rank used in the summary file name. Rank 0 also
                prints the summary table.
        """
        super().__init__()

        self.local_rank = local_rank
        torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        # Keep the directory the traces go to, so _stop() writes the text
        # summary next to them even if the env var is read differently later.
        self._trace_dir = torch_profiler_trace_dir
        logger.info(
            "Torch profiling enabled. Traces will be saved to: %s",
            torch_profiler_trace_dir,
        )
        logger.debug(
            "Profiler config: record_shapes=%s,"
            "profile_memory=%s,with_stack=%s,with_flops=%s",
            envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
            envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
            envs.VLLM_TORCH_PROFILER_WITH_STACK,
            envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
        )
        self.profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
            profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
            with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
            with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
            ),
        )

    @override
    def _start(self) -> None:
        """Start the wrapped ``torch.profiler.profile``."""
        self.profiler.start()

    @override
    def _stop(self) -> None:
        """Stop profiling and write a per-rank summary table.

        The table is sorted by ``self_cuda_time_total``. It is written to
        ``<trace dir>/profiler_out_<local_rank>.txt`` and printed on rank 0.
        """
        self.profiler.stop()

        rank = self.local_rank
        sort_key = "self_cuda_time_total"
        table = self.profiler.key_averages().table(sort_by=sort_key)

        profiler_out_file = os.path.join(
            self._trace_dir, f"profiler_out_{rank}.txt"
        )
        # Explicit encoding: names from record_function() annotations can be
        # arbitrary text, so the output must not depend on the process locale.
        with open(profiler_out_file, "w", encoding="utf-8") as f:
            print(table, file=f)

        # only print profiler results on rank 0
        if rank == 0:
            print(table)

    @override
    def annotate_context_manager(self, name: str):
        """Return a ``torch.profiler.record_function`` range named ``name``."""
        return torch.profiler.record_function(name)


class CudaProfilerWrapper(WorkerProfiler):
    """Worker profiler that drives ``torch.cuda.profiler``.

    ``_start`` / ``_stop`` call ``torch.cuda.profiler.start()`` / ``stop()``,
    and trace annotations are emitted as ``torch.cuda.nvtx`` ranges.
    NOTE(review): presumably used under an external CUDA profiler
    (e.g. Nsight Systems); confirm with the launch tooling.
    """

    def __init__(self) -> None:
        super().__init__()
        # Imported lazily to avoid dependency issues when CUDA is not
        # available.
        import torch.cuda.profiler as profiler_module

        self._cuda_profiler = profiler_module

    @override
    def _start(self) -> None:
        """Begin capture via ``torch.cuda.profiler.start()``."""
        backend = self._cuda_profiler
        backend.start()

    @override
    def _stop(self) -> None:
        """End capture via ``torch.cuda.profiler.stop()``."""
        backend = self._cuda_profiler
        backend.stop()

    @override
    def annotate_context_manager(self, name: str):
        """Return an NVTX range context manager labelled ``name``."""
        nvtx = torch.cuda.nvtx
        return nvtx.range(name)