# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Logging configuration for vLLM."""

import datetime
import json
import logging
import os
import sys
from collections.abc import Hashable
from functools import lru_cache, partial
from logging import Logger
from logging.config import dictConfig
from os import path
from types import MethodType
from typing import Any, Literal, cast

import vllm.envs as envs

# Log line layout: optional env-configured prefix, level, timestamp, then the
# source location and the message. ``%(fileinfo)s`` is not a standard
# ``LogRecord`` attribute; presumably it is filled in by the
# ``vllm.logging_utils`` formatters referenced below (TODO confirm).
_FORMAT = (
    f"{envs.VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
    "[%(fileinfo)s:%(lineno)d] %(message)s"
)
# Timestamp format used by both vLLM formatters (month-day hour:min:sec).
_DATE_FORMAT = "%m-%d %H:%M:%S"


def _use_color() -> bool:
    """Decide whether log output should use the colored formatter.

    Explicit settings take precedence: ``NO_COLOR`` or
    ``VLLM_LOGGING_COLOR == "0"`` disables color, and
    ``VLLM_LOGGING_COLOR == "1"`` forces it. Otherwise color is used only
    when the configured stream is stdout or stderr and that stream reports
    itself as a TTY.
    """
    if envs.NO_COLOR or envs.VLLM_LOGGING_COLOR == "0":
        return False
    if envs.VLLM_LOGGING_COLOR == "1":
        return True

    configured_stream = envs.VLLM_LOGGING_STREAM
    if configured_stream == "ext://sys.stdout":
        target = sys.stdout
    elif configured_stream == "ext://sys.stderr":
        target = sys.stderr
    else:
        # Custom streams never get auto-detected color.
        return False
    return hasattr(target, "isatty") and target.isatty()


# Default ``logging.config.dictConfig`` schema for the ``vllm`` logger tree.
# Handler values read from ``envs`` are evaluated at import time;
# ``_configure_vllm_root_logger`` overwrites them in place before use.
DEFAULT_LOGGING_CONFIG = {
    "formatters": {
        "vllm": {
            "class": "vllm.logging_utils.NewLineFormatter",
            "datefmt": _DATE_FORMAT,
            "format": _FORMAT,
        },
        "vllm_color": {
            "class": "vllm.logging_utils.ColoredFormatter",
            "datefmt": _DATE_FORMAT,
            "format": _FORMAT,
        },
    },
    "handlers": {
        "vllm": {
            "class": "logging.StreamHandler",
            # Choose formatter based on color setting.
            "formatter": "vllm_color" if _use_color() else "vllm",
            "level": envs.VLLM_LOGGING_LEVEL,
            "stream": envs.VLLM_LOGGING_STREAM,
        },
    },
    "loggers": {
        "vllm": {
            "handlers": ["vllm"],
            # The logger passes everything through; filtering is done by the
            # handler's level (VLLM_LOGGING_LEVEL).
            "level": "DEBUG",
            "propagate": False,
        },
    },
    "version": 1,
    "disable_existing_loggers": False,
}


@lru_cache
def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
    """Log ``msg % args`` at DEBUG level once per (logger, msg, args) key.

    Deduplication relies on ``lru_cache`` with its default ``maxsize`` of 128,
    so it is best-effort: once evicted, an entry can be logged again. All
    ``args`` must be hashable because they are part of the cache key.
    """
    # Set the stacklevel to 3 to print the original caller's line info
    logger.debug(msg, *args, stacklevel=3)


@lru_cache
def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
    """Log ``msg % args`` at INFO level once per (logger, msg, args) key.

    Deduplication relies on ``lru_cache`` with its default ``maxsize`` of 128,
    so it is best-effort: once evicted, an entry can be logged again. All
    ``args`` must be hashable because they are part of the cache key.
    """
    # Set the stacklevel to 3 to print the original caller's line info
    logger.info(msg, *args, stacklevel=3)


@lru_cache
def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
    """Log ``msg % args`` at WARNING level once per (logger, msg, args) key.

    Deduplication relies on ``lru_cache`` with its default ``maxsize`` of 128,
    so it is best-effort: once evicted, an entry can be logged again. All
    ``args`` must be hashable because they are part of the cache key.
    """
    # Set the stacklevel to 3 to print the original caller's line info
    logger.warning(msg, *args, stacklevel=3)


LogScope = Literal["process", "global", "local"]


def _should_log_with_scope(scope: LogScope) -> bool:
    """Decide whether to log based on scope"""
    if scope == "global":
        from vllm.distributed.parallel_state import is_global_first_rank

        return is_global_first_rank()
    if scope == "local":
        from vllm.distributed.parallel_state import is_local_first_rank

        return is_local_first_rank()
    # default "process" scope: always log
    return True
106
107
108
109
110
111


class _VllmLogger(Logger):
    """
    Note:
        This class is just to provide type information.
        We actually patch the methods directly on the [`logging.Logger`][]
        instance to avoid conflicting with other libraries such as
        `intel_extension_for_pytorch.utils._logger`.
    """

    def debug_once(
        self, msg: str, *args: Hashable, scope: LogScope = "process"
    ) -> None:
        """
        As [`debug`][logging.Logger.debug], but subsequent calls with
        the same message and arguments on this logger are silently dropped.

        Args:
            msg: %-style format string.
            *args: Format arguments; must be hashable (part of the dedup key).
            scope: Which processes may log; checked before deduplication.
        """
        if not _should_log_with_scope(scope):
            return
        _print_debug_once(self, msg, *args)

    def info_once(self, msg: str, *args: Hashable, scope: LogScope = "process") -> None:
        """
        As [`info`][logging.Logger.info], but subsequent calls with
        the same message and arguments on this logger are silently dropped.

        Args:
            msg: %-style format string.
            *args: Format arguments; must be hashable (part of the dedup key).
            scope: Which processes may log; checked before deduplication.
        """
        if not _should_log_with_scope(scope):
            return
        _print_info_once(self, msg, *args)

    def warning_once(
        self, msg: str, *args: Hashable, scope: LogScope = "process"
    ) -> None:
        """
        As [`warning`][logging.Logger.warning], but subsequent calls with
        the same message and arguments on this logger are silently dropped.

        Args:
            msg: %-style format string.
            *args: Format arguments; must be hashable (part of the dedup key).
            scope: Which processes may log; checked before deduplication.
        """
        if not _should_log_with_scope(scope):
            return
        _print_warning_once(self, msg, *args)


# Pre-defined methods mapping to avoid repeated dictionary creation.
# init_logger binds each of these onto every logger it returns.
_METHODS_TO_PATCH = {
    "debug_once": _VllmLogger.debug_once,
    "info_once": _VllmLogger.info_once,
    "warning_once": _VllmLogger.warning_once,
}


def _configure_vllm_root_logger() -> None:
    """Configure the ``vllm`` logger tree from environment variables.

    - If ``VLLM_CONFIGURE_LOGGING`` is set, ``DEFAULT_LOGGING_CONFIG`` is used.
      Its ``vllm`` handler's level, stream and formatter are first refreshed
      in place from the current environment.
    - If ``VLLM_LOGGING_CONFIG_PATH`` is set, the JSON object in that file
      replaces the default config entirely.
    - If neither applies, ``dictConfig`` is not called.

    Raises:
        RuntimeError: If a config path is given while
            ``VLLM_CONFIGURE_LOGGING`` is disabled, or the file does not exist.
        ValueError: If the JSON file's top-level value is not an object.
            Malformed JSON surfaces as ``json.JSONDecodeError``, which is a
            ``ValueError`` subclass.
    """
    logging_config: dict[str, Any] = {}

    if not envs.VLLM_CONFIGURE_LOGGING and envs.VLLM_LOGGING_CONFIG_PATH:
        raise RuntimeError(
            "VLLM_CONFIGURE_LOGGING evaluated to false, but "
            "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
            "implies VLLM_CONFIGURE_LOGGING. Please enable "
            "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH."
        )

    if envs.VLLM_CONFIGURE_LOGGING:
        # NOTE: this mutates the module-level DEFAULT_LOGGING_CONFIG in place.
        logging_config = DEFAULT_LOGGING_CONFIG
        vllm_handler = logging_config["handlers"]["vllm"]
        # Refresh these values in case env vars have changed.
        vllm_handler["level"] = envs.VLLM_LOGGING_LEVEL
        vllm_handler["stream"] = envs.VLLM_LOGGING_STREAM
        vllm_handler["formatter"] = "vllm_color" if _use_color() else "vllm"

    config_path = envs.VLLM_LOGGING_CONFIG_PATH
    if config_path:
        if not path.exists(config_path):
            # Exceptions do not apply %-formatting; interpolate explicitly.
            raise RuntimeError(
                f"Could not load logging config. File does not exist: {config_path}"
            )
        with open(config_path, encoding="utf-8") as file:
            custom_config = json.load(file)

        if not isinstance(custom_config, dict):
            raise ValueError(
                "Invalid logging config. Expected dict, got "
                f"{type(custom_config).__name__}."
            )
        logging_config = custom_config

    for formatter in logging_config.get("formatters", {}).values():
        # This provides backwards compatibility after #10134.
        if formatter.get("class") == "vllm.logging.NewLineFormatter":
            formatter["class"] = "vllm.logging_utils.NewLineFormatter"

    if logging_config:
        dictConfig(logging_config)


def init_logger(name: str) -> _VllmLogger:
    """Return ``logging.getLogger(name)`` with vLLM's ``*_once`` helpers bound.

    The main purpose of this function is to ensure that loggers are
    retrieved in such a way that we can be sure the root vllm logger has
    already been configured. That configuration happens when this module is
    imported, which must have happened before this function can be called.

    ``debug_once``, ``info_once`` and ``warning_once`` are attached to the
    logger instance rather than to its class, so other libraries' loggers are
    left untouched. The ``cast`` is for type checkers only.
    """
    logger = logging.getLogger(name)

    for method_name, method in _METHODS_TO_PATCH.items():
        setattr(logger, method_name, MethodType(method, logger))

    return cast(_VllmLogger, logger)


# The root logger is initialized when the module is imported.
# This is thread-safe because Python's per-module import lock guarantees
# the module body executes only once, even under concurrent imports.
_configure_vllm_root_logger()

logger = init_logger(__name__)


def _trace_calls(log_path, root_dir, frame, event, arg=None):
    """Trace function for ``sys.settrace`` that records calls and returns.

    For each ``"call"`` / ``"return"`` event in a file under ``root_dir``,
    append one timestamped line to ``log_path``. The line names the traced
    function and the calling frame.

    Returns:
        ``None`` for call/return events in files outside ``root_dir``. Per
        ``sys.settrace`` semantics this turns off tracing for that scope.
        Otherwise returns this tracer, re-bound to ``log_path`` and
        ``root_dir``, so tracing continues.
    """
    if event in ("call", "return"):
        # Extract the filename, line number, function name, and the code object
        code = frame.f_code
        filename = code.co_filename
        lineno = frame.f_lineno
        func_name = code.co_name
        # Plain string-prefix match: sibling directories that share the
        # prefix (e.g. "<root_dir>_other") also match.
        if not filename.startswith(root_dir):
            # only log the functions in the vllm root_dir
            return
        # Log every function call or return
        try:
            last_frame = frame.f_back
            if last_frame is not None:
                last_filename = last_frame.f_code.co_filename
                last_lineno = last_frame.f_lineno
                last_func_name = last_frame.f_code.co_name
            else:
                # initial frame
                last_filename = ""
                last_lineno = 0
                last_func_name = ""
            # The file is reopened per event so nothing is buffered if the
            # process hangs or crashes, which is the debugging use case.
            with open(log_path, "a") as f:
                ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
                if event == "call":
                    f.write(
                        f"{ts} Call to"
                        f" {func_name} in {filename}:{lineno}"
                        f" from {last_func_name} in {last_filename}:"
                        f"{last_lineno}\n"
                    )
                else:
                    f.write(
                        f"{ts} Return from"
                        f" {func_name} in {filename}:{lineno}"
                        f" to {last_func_name} in {last_filename}:"
                        f"{last_lineno}\n"
                    )
        except NameError:
            # modules are deleted during shutdown
            pass
    return partial(_trace_calls, log_path, root_dir)


def enable_trace_function_call(log_file_path: str, root_dir: str | None = None):
    """
    Enable tracing of every function call in code under `root_dir`.
    This is useful for debugging hangs or crashes.
    `log_file_path` is the path to the log file; one line is appended per
    call/return event.
    `root_dir` is the root directory of the code to trace. If None, it is the
    parent of the directory containing this file (presumably the directory
    that holds the `vllm` package).

    Note that tracing is per-thread: `sys.settrace` installs the tracer only
    for the thread that calls this function. Other threads are not affected.
    """
    logger.warning(
        "VLLM_TRACE_FUNCTION is enabled. It will record every"
        " function executed by Python. This will slow down the code. It "
        "is suggested to be used for debugging hang or crashes only."
    )
    logger.info("Trace frame log is saved to %s", log_file_path)
    if root_dir is None:
        # Two levels up from this file, i.e. the parent of its directory.
        root_dir = os.path.dirname(os.path.dirname(__file__))
    sys.settrace(partial(_trace_calls, log_file_path, root_dir))