"tests/entrypoints/pooling/embed/test_online.py" did not exist on "9edca6bf8fa81e2dc678be68e9cdcede572947c1"
Unverified commit 3a63be0f, authored by David Ramon Prados and committed by GitHub
Browse files

Support custom URI schemes and trace handlers for profiler (#32393)

parent 803e3f3f
......@@ -3,6 +3,7 @@
import pytest
from vllm.config import ProfilerConfig
from vllm.config.profiler import _is_uri_path
from vllm.profiler.wrapper import WorkerProfiler
......@@ -202,3 +203,36 @@ def test_mixed_delay_and_stop(default_profiler_config):
profiler.step()
assert profiler.start_call_count == 0
class TestIsUriPath:
    """Checks for the ``_is_uri_path`` helper.

    Each case pairs an input string with the boolean that
    ``_is_uri_path`` is expected to return for it.
    """

    @pytest.mark.parametrize(
        ("path", "expected"),
        # Inputs with a multi-character scheme followed by "://" -> True
        [
            (uri, True)
            for uri in (
                "gs://bucket/path",
                "s3://bucket/path",
                "hdfs://cluster/path",
                "abfs://container/path",
                "http://example.com/path",
                "https://example.com/path",
            )
        ]
        # Everything below is expected to be reported as "not a URI" -> False
        + [
            (non_uri, False)
            for non_uri in (
                # Plain local paths
                "/tmp/local/path",
                "./relative/path",
                "relative/path",
                "/absolute/path",
                # Windows drive letters (single-character "scheme")
                "C://windows/path",
                "D://drive/path",
                # Edge cases
                "",
                "no-scheme",
                "scheme-no-slashes:",
                "://no-scheme",
            )
        ],
    )
    def test_is_uri_path(self, path, expected):
        """Verify that ``_is_uri_path`` separates URI inputs from local paths."""
        actual = _is_uri_path(path)
        assert actual == expected
......@@ -18,6 +18,20 @@ logger = init_logger(__name__)
ProfilerKind = Literal["torch", "cuda"]
def _is_uri_path(path: str) -> bool:
"""Check if path is a URI (scheme://...), excluding Windows drive letters.
Supports custom URI schemes like gs://, s3://, hdfs://, etc.
These paths should not be converted to absolute paths.
"""
if "://" in path:
scheme = path.split("://")[0]
# Windows drive letters are single characters (e.g., C://)
# Valid URI schemes have more than one character
return len(scheme) > 1
return False
@config
@dataclass
class ProfilerConfig:
......@@ -185,15 +199,9 @@ class ProfilerConfig:
if self.profiler == "torch" and not profiler_dir:
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
if profiler_dir:
is_gs_path = (
profiler_dir.startswith("gs://")
and profiler_dir[5:]
and profiler_dir[5] != "/"
)
if not is_gs_path:
self.torch_profiler_dir = os.path.abspath(
os.path.expanduser(profiler_dir)
)
# Support any URI scheme (gs://, s3://, hdfs://, etc.)
# These paths should not be converted to absolute paths
if profiler_dir and not _is_uri_path(profiler_dir):
self.torch_profiler_dir = os.path.abspath(os.path.expanduser(profiler_dir))
return self
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import nullcontext
from typing import Literal
......@@ -9,6 +10,7 @@ import torch
from typing_extensions import override
from vllm.config import ProfilerConfig
from vllm.config.profiler import _is_uri_path
from vllm.logger import init_logger
logger = init_logger(__name__)
......@@ -151,6 +153,7 @@ class TorchProfilerWrapper(WorkerProfiler):
worker_name: str,
local_rank: int,
activities: list[TorchProfilerActivity],
on_trace_ready: Callable[[torch.profiler.profile], None] | None = None,
) -> None:
super().__init__(profiler_config)
......@@ -172,6 +175,17 @@ class TorchProfilerWrapper(WorkerProfiler):
profiler_config.torch_profiler_with_flops,
)
# Determine trace handler: use custom handler if provided,
# otherwise default to tensorboard trace handler
if on_trace_ready is not None:
trace_handler = on_trace_ready
else:
trace_handler = torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir,
worker_name=worker_name,
use_gzip=profiler_config.torch_profiler_use_gzip,
)
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
self.profiler = torch.profiler.profile(
activities=[TorchProfilerActivityMap[activity] for activity in activities],
......@@ -179,11 +193,7 @@ class TorchProfilerWrapper(WorkerProfiler):
profile_memory=profiler_config.torch_profiler_with_memory,
with_stack=profiler_config.torch_profiler_with_stack,
with_flops=profiler_config.torch_profiler_with_flops,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir,
worker_name=worker_name,
use_gzip=profiler_config.torch_profiler_use_gzip,
),
on_trace_ready=trace_handler,
)
@override
......@@ -198,10 +208,13 @@ class TorchProfilerWrapper(WorkerProfiler):
rank = self.local_rank
if profiler_config.torch_profiler_dump_cuda_time_total:
profiler_dir = profiler_config.torch_profiler_dir
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
sort_key = "self_cuda_time_total"
table = self.profiler.key_averages().table(sort_by=sort_key)
# Skip file write for URI paths (gs://, s3://, etc.)
# as standard file I/O doesn't work with URI schemes
if not _is_uri_path(profiler_dir):
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
with open(profiler_out_file, "w") as f:
print(table, file=f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment