uniproc_executor.py 3.01 KB
Newer Older
1
import os
2
from typing import Optional
3

Robert Shaw's avatar
Robert Shaw committed
4
from vllm.config import VllmConfig
5
6
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
7
from vllm.v1.executor.abstract import Executor
8
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
9
10
11
12
13
14
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker

logger = init_logger(__name__)


15
class UniprocExecutor(Executor):
    """Executor that drives a single ``Worker`` owned by this object.

    Every public method delegates to ``self.worker``, which is created,
    device-initialized and loaded with the model in ``__init__``.
    """

    def __init__(self, vllm_config: VllmConfig) -> None:
        """Store the config, then create and prepare the worker.

        Args:
            vllm_config: Top-level configuration; each of its sub-configs is
                also exposed as an attribute of the executor.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        # Order matters: the device is initialized before the model is loaded.
        self.worker: Worker = self._create_worker()
        self.worker.init_device()
        self.worker.load_model()

    def _create_worker(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Worker:
        """Construct the worker for the given rank.

        Side effect: sets ``NCCL_CUMEM_ENABLE`` to ``'0'`` in the process
        environment, replacing any value that was already there.

        Args:
            local_rank: Local rank passed through to the worker.
            rank: Rank passed through to the worker.
            distributed_init_method: Init method passed through to the
                worker. When ``None``, one is built from this host's IP and
                a currently open port.

        Returns:
            The newly constructed ``Worker`` (not yet device-initialized).
        """
        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return Worker(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

    def determine_available_memory(self) -> int:
        """Determine the available memory (in bytes) for KV cache by invoking
        the underlying worker.
        """
        return self.worker.determine_available_memory()

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Get the KV cache spec needed by the model by invoking the
        underlying worker.
        """
        return self.worker.get_kv_cache_spec()

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """Initialize the KV cache by invoking the underlying worker, then
        compile or warm up the model.
        """
        self.worker.initialize_cache(kv_cache_config)
        self.worker.compile_or_warm_up_model()

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        """Run the model on the worker for ``scheduler_output``.

        Raises:
            AssertionError: If the worker returns ``None``.
        """
        output = self.worker.execute_model(scheduler_output)
        # Invariant check: the single worker is expected to always produce
        # an output.
        assert output is not None
        return output

    def profile(self, is_start: bool = True) -> None:
        """Forward a profiling request to the worker.

        Args:
            is_start: Passed unchanged to ``Worker.profile``; presumably
                ``True`` starts and ``False`` stops profiling -- confirm
                against the worker implementation.
        """
        self.worker.profile(is_start)

    def shutdown(self) -> None:
        """Do nothing; this executor performs no work on shutdown."""

    def check_health(self) -> None:
        """Health check that never fails."""
        # UniprocExecutor will always be healthy as long as
        # it's running.
        return