profiler.py 7.81 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import Any, Literal

from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash

logger = init_logger(__name__)

ProfilerKind = Literal["torch", "cuda"]


21
22
23
24
25
26
27
28
29
30
31
32
33
34
def _is_uri_path(path: str) -> bool:
    """Check if path is a URI (scheme://...), excluding Windows drive letters.

    Supports custom URI schemes like gs://, s3://, hdfs://, etc.
    These paths should not be converted to absolute paths.
    """
    if "://" in path:
        scheme = path.split("://")[0]
        # Windows drive letters are single characters (e.g., C://)
        # Valid URI schemes have more than one character
        return len(scheme) > 1
    return False


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@config
@dataclass
class ProfilerConfig:
    """Dataclass which contains profiler config for the engine."""

    profiler: ProfilerKind | None = None
    """Which profiler to use. Defaults to None. Options are:

    - 'torch': Use PyTorch profiler.\n
    - 'cuda': Use CUDA profiler."""

    torch_profiler_dir: str = ""
    """Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
    worker's traces (CPU & GPU) will be saved under this directory. Note that
    it must be an absolute path."""

    torch_profiler_with_stack: bool = True
    """If `True`, enables stack tracing in the torch profiler. Enabled by default."""

    torch_profiler_with_flops: bool = False
    """If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""

    torch_profiler_use_gzip: bool = True
    """If `True`, saves torch profiler traces in gzip format. Enabled by default"""

    torch_profiler_dump_cuda_time_total: bool = True
    """If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""

    torch_profiler_record_shapes: bool = False
    """If `True`, records tensor shapes in the torch profiler. Disabled by default."""

    torch_profiler_with_memory: bool = False
    """If `True`, enables memory profiling in the torch profiler.
    Disabled by default."""

    ignore_frontend: bool = False
71
    """If `True`, disables the front-end profiling of AsyncLLM when using the
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
    'torch' profiler. This is needed to reduce overhead when using delay/limit options,
    since the front-end profiling does not track iterations and will capture the
    entire range.
    """

    delay_iterations: int = Field(default=0, ge=0)
    """Number of engine iterations to skip before starting profiling.
    Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
    """

    max_iterations: int = Field(default=0, ge=0)
    """Maximum number of engine iterations to profile after starting profiling.
    Defaults to 0, meaning no limit.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
        """Get field from env var if set, with deprecation warning."""

        if envs.is_set(env_var_name):
            value = getattr(envs, env_var_name)
            logger.warning_once(
                "Using %s environment variable is deprecated and will be removed in "
                "v0.14.0 or v1.0.0, whichever is soonest. Please use "
                "--profiler-config.%s command line argument or "
                "ProfilerConfig(%s=...) config field instead.",
                env_var_name,
                field_name,
                field_name,
            )
            return value
        return None

    def _set_from_env_if_set(
        self,
        field_name: str,
        env_var_name: str,
        to_bool: bool = True,
        to_int: bool = False,
    ) -> None:
        """Set field from env var if set, with deprecation warning."""
        value = self._get_from_env_if_set(field_name, env_var_name)
        if value is not None:
            if to_bool:
                value = value == "1"
            if to_int:
                value = int(value)
            setattr(self, field_name, value)

    @model_validator(mode="after")
    def _validate_profiler_config(self) -> Self:
        maybe_use_cuda_profiler = self._get_from_env_if_set(
            "profiler", "VLLM_TORCH_CUDA_PROFILE"
        )
        if maybe_use_cuda_profiler is not None:
            self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None
        else:
            self._set_from_env_if_set(
                "torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False
            )
            if self.torch_profiler_dir:
                self.profiler = "torch"
                self._set_from_env_if_set(
                    "torch_profiler_record_shapes",
                    "VLLM_TORCH_PROFILER_RECORD_SHAPES",
                )
                self._set_from_env_if_set(
                    "torch_profiler_with_memory",
                    "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY",
                )
                self._set_from_env_if_set(
                    "torch_profiler_with_stack",
                    "VLLM_TORCH_PROFILER_WITH_STACK",
                )
                self._set_from_env_if_set(
                    "torch_profiler_with_flops",
                    "VLLM_TORCH_PROFILER_WITH_FLOPS",
                )
                self._set_from_env_if_set(
                    "ignore_frontend",
                    "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM",
                )
                self._set_from_env_if_set(
                    "torch_profiler_use_gzip",
                    "VLLM_TORCH_PROFILER_USE_GZIP",
                )
                self._set_from_env_if_set(
                    "torch_profiler_dump_cuda_time_total",
                    "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL",
                )

        self._set_from_env_if_set(
            "delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True
        )
        self._set_from_env_if_set(
            "max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True
        )

        has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
        if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
            logger.warning_once(
                "Using 'torch' profiler with delay_iterations or max_iterations "
                "while ignore_frontend is False may result in high overhead."
            )

        profiler_dir = self.torch_profiler_dir
        if profiler_dir and self.profiler != "torch":
            raise ValueError(
                "torch_profiler_dir is only applicable when profiler is set to 'torch'"
            )
        if self.profiler == "torch" and not profiler_dir:
            raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")

202
203
204
205
        # Support any URI scheme (gs://, s3://, hdfs://, etc.)
        # These paths should not be converted to absolute paths
        if profiler_dir and not _is_uri_path(profiler_dir):
            self.torch_profiler_dir = os.path.abspath(os.path.expanduser(profiler_dir))
206
207

        return self