# worker_base.py
import datetime
2
3
import importlib
import os
4
5
import tempfile
import threading
6
from abc import ABC, abstractmethod
7
from typing import Dict, List, Set, Tuple
8

9
from vllm.logger import enable_trace_function_call, init_logger
10
11
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
12
from vllm.utils import get_vllm_instance_id, update_environment_variables
13
14

logger = init_logger(__name__)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


class WorkerBase(ABC):
    """Worker interface that allows vLLM to cleanly separate implementations for
    different hardware.

    Every method is abstract and raises NotImplementedError; concrete workers
    must override all of them.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.

        Args:
            num_gpu_blocks: Number of blocks for the GPU KV cache.
            num_cpu_blocks: Number of blocks for the CPU KV cache.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> List[SamplerOutput]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided.

        NOTE(review): the three ``blocks_to_*`` mappings are presumably keyed
        by source block number -- confirm against the concrete workers.
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """LoRA hook taking the ``lora_request`` to add.

        NOTE(review): the meaning of the returned bool is not defined here --
        presumably whether the adapter was added; confirm against
        implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """LoRA hook taking the integer ``lora_id`` to remove.

        NOTE(review): the meaning of the returned bool is not defined here --
        presumably whether the adapter was removed; confirm against
        implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """LoRA hook returning a set of ints.

        NOTE(review): presumably the ids of the currently registered LoRAs --
        confirm against implementations.
        """
        raise NotImplementedError


class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.

    Only the three LoRA methods are overridden here; each one unconditionally
    raises ValueError naming the concrete worker type.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Always raise ValueError: this worker does not support LoRA."""
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        """Always raise ValueError: this worker does not support LoRA."""
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        """Always raise ValueError: this worker does not support LoRA."""
        raise ValueError(f"{type(self)} does not support LoRA")
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120


class WorkerWrapperBase:
    """
    The whole point of this class is to lazily initialize the worker.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Then we call `update_environment_variables`, and the
    real initialization happens in `init_worker`.
    """

    def __init__(self,
                 worker_module_name=None,
                 worker_class_name=None) -> None:
        """Remember where the worker class lives; nothing is imported yet.

        Args:
            worker_module_name: Name of the module to import in
                `init_worker`.
            worker_class_name: Name of the attribute to look up on that
                module and instantiate as the worker.
        """
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        # Stays None until `init_worker` has run.
        self.worker = None

    def update_environment_variables(self, envs: Dict[str, str]) -> None:
        """Apply `envs` through the module-level
        `update_environment_variables` helper.

        If `envs` sets CUDA_VISIBLE_DEVICES and the variable is already
        present in the environment, the existing value is removed first.
        """
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """
        Actual initialization of the worker class, and set up
        function tracing if required.

        Arguments are passed to the worker class constructor.
        """
        # Tracing is opt-in: VLLM_TRACE_FUNCTION must parse to a non-zero int.
        if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
            tmp_dir = tempfile.gettempdir()
            # The name combines pid, thread id and timestamp; spaces from the
            # timestamp are replaced with underscores.
            filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                        f"_thread_{threading.get_ident()}_"
                        f"at_{datetime.datetime.now()}.log").replace(" ", "_")
            log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
                                    filename)
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            enable_trace_function_call(log_path)

        # Import is deferred until here, after any environment changes and
        # tracing setup above.
        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        """Call the method named `method` and return its result.

        The wrapper's own attribute is used when it has one with that name;
        otherwise the lookup is done on `self.worker`. Any exception is
        logged and then re-raised unchanged.
        """
        try:
            if hasattr(self, method):
                executor = getattr(self, method)
            else:
                executor = getattr(self.worker, method)
            return executor(*args, **kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare raise re-raises the active exception.
            raise