worker_base.py 5.48 KB
Newer Older
1
2
import importlib
import os
3
from abc import ABC, abstractmethod
4
from typing import Dict, List, Optional, Set, Tuple
5

6
from vllm.logger import init_logger
7
from vllm.lora.request import LoRARequest
8
from vllm.sequence import ExecuteModelRequest, SamplerOutput
9
10
from vllm.utils import (enable_trace_function_call_for_thread,
                        update_environment_variables)
11
12

# Module-level logger shared by all worker classes in this file.
logger = init_logger(__name__)


class WorkerBase(ABC):
    """Worker interface that allows vLLM to cleanly separate implementations for
    different hardware.

    Every method here is abstract and only raises NotImplementedError;
    concrete workers must override all of them.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.

        Args:
            num_gpu_blocks: Number of blocks to allocate for the device cache.
            num_cpu_blocks: Number of blocks to allocate for the CPU cache.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load the LoRA adapter described by ``lora_request``.

        Returns a bool whose exact meaning (presumably "adapter was added")
        is defined by the concrete worker — confirm against implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter identified by ``lora_id``.

        Returns a bool defined by the concrete worker (presumably whether an
        adapter was actually removed — confirm against implementations).
        """
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter identified by ``lora_id``.

        Returns a bool defined by the concrete worker.
        """
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """Return the set of LoRA ids currently known to this worker."""
        raise NotImplementedError


class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.

    Every LoRA method raises ``ValueError``; none of them returns normally.
    """

    def _lora_unsupported(self) -> ValueError:
        """Build the error raised by every LoRA method of this class."""
        return ValueError(f"{type(self)} does not support LoRA")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise self._lora_unsupported()

    def remove_lora(self, lora_id: int) -> bool:
        raise self._lora_unsupported()

    def pin_lora(self, lora_id: int) -> bool:
        # Previously this *returned* the ValueError instance; a truthy return
        # value could be mistaken for success. Raise like the siblings do.
        raise self._lora_unsupported()

    def list_loras(self) -> Set[int]:
        raise self._lora_unsupported()


class WorkerWrapperBase:
    """
    The whole point of this class is to lazily initialize the worker.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Then, after `update_environment_variables` has been
    called, the real initialization happens in `init_worker`.
    """

    def __init__(self,
                 worker_module_name: str,
                 worker_class_name: str,
                 trust_remote_code: bool = False) -> None:
        """Record which worker class to construct later.

        Args:
            worker_module_name: Module path later passed to
                ``importlib.import_module`` by ``init_worker``.
            worker_class_name: Name of the worker class inside that module.
            trust_remote_code: When True, ``init_cached_hf_modules`` is
                called immediately.
        """
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        # Remains None until `init_worker` constructs the real worker.
        self.worker = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        """Apply ``envs`` to the process environment.

        If ``envs`` sets CUDA_VISIBLE_DEVICES and it is already present in
        ``os.environ``, the existing value is deleted first.
        """
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        # This name resolves to the module-level helper imported from
        # vllm.utils, not to this staticmethod: class attributes are not part
        # of a method's name-lookup scope.
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.
        """
        enable_trace_function_call_for_thread()

        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        """Call ``method`` by name on the wrapped worker.

        If no worker has been created yet, the method is looked up on this
        wrapper instead. Any exception (including a missing attribute) is
        logged and then re-raised unchanged.
        """
        try:
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare `raise` re-raises the active exception with its original
            # traceback untouched.
            raise