base.py 1.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod

from vllm.logger import init_logger

logger = init_logger(__name__)


class ProfilerBase(ABC):
    """
    Abstract base class for all diffusion profilers.
    Defines the common interface used by GPUWorker and DiffusionEngine.
    """

    @abstractmethod
    def start(self, trace_path_template: str) -> str:
        """
        Start profiling.

        Args:
            trace_path_template: Base path (without rank or extension).
                                 e.g. "/tmp/profiles/sdxl_run"

        Returns:
            Full path of the trace file this rank will write.
        """
        pass

    @abstractmethod
    def stop(self) -> str | None:
        """
        Stop profiling and finalize/output the trace.

        Returns:
            Path to the saved trace file, or None if not active.
        """
        pass

    @abstractmethod
    def get_step_context(self):
        """
        Returns a context manager that advances one profiling step.
        Should be a no-op (nullcontext) when profiler is not active.
        """
        pass

    @abstractmethod
    def is_active(self) -> bool:
        """Return True if profiling is currently running."""
        pass

    @classmethod
    def _get_rank(cls) -> int:
        import os

        return int(os.getenv("RANK", "0"))