# Scraped page header (not part of the module):
# "vllm/vscode:/vscode.git/clone" did not exist on "2daf23ab0cf00da157b1255faddcf0a269283d36"
# neuron_worker.py (4.33 KB)
"""A Neuron worker class."""
2
from typing import List, Optional, Tuple
3
4
5
6

import torch
import torch.distributed

7
8
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
9
10
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
11
from vllm.model_executor import set_random_seed
12
from vllm.sequence import ExecuteModelRequest
13
from vllm.worker.neuron_model_runner import NeuronModelRunner
14
15
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                     LoraNotSupportedWorkerBase, WorkerInput)
16
17


18
class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
    """A worker class that executes the model on a group of neuron cores.

    vLLM's own distributed environment is initialized with ``world_size=1``
    and tensor/pipeline parallel sizes of 1; tensor parallelism across
    neuron cores is handled by transformers-neuronx instead. This worker
    always marks itself as the driver worker.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
    ) -> None:
        """Store the configs and build the Neuron model runner.

        Args:
            model_config: Model configuration. ``trust_remote_code`` is read
                here and ``seed`` is read by :meth:`init_device`.
            parallel_config: Parallelism configuration, forwarded to the
                model runner.
            scheduler_config: Scheduler configuration. Its ``max_num_seqs``
                is used as the number of KV cache blocks.
            device_config: Device configuration, forwarded to the model
                runner.
            cache_config: Cache configuration. It is updated in place by
                :meth:`initialize_cache`.
            local_rank: Local rank passed to the distributed initializer.
            rank: Global rank passed to the distributed initializer.
            distributed_init_method: Init method passed through to
                ``init_distributed_environment``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.model_runner: NeuronModelRunner = NeuronModelRunner(
            model_config, parallel_config, scheduler_config, device_config)
        # This is the only worker, so it is always the driver.
        self.is_driver_worker = True

    def init_device(self) -> None:
        """Initialize the distributed environment and seed the RNGs."""
        self.init_distributed_environment()

        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self) -> None:
        """Load the model weights via the Neuron model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        Swapping is not yet supported, so always return num_cpu_blocks=0.

        We configure num_gpu_blocks to be equal to max_num_seqs.

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple equal to
            ``(scheduler_config.max_num_seqs, 0)``.
        """
        # Set the number of GPU blocks to be the same as the maximum number of
        # sequences that can be processed in a single batch. This is equivalent
        # to schedule without PagedAttention.
        num_gpu_blocks = self.scheduler_config.max_num_seqs

        # Swap not yet supported with Neuron backend.
        num_cpu_blocks = 0

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache.

        This only records the block counts on ``cache_config``. No memory is
        allocated here.

        Args:
            num_gpu_blocks: Must equal ``scheduler_config.max_num_seqs``.
            num_cpu_blocks: Must be 0 (swapping is not supported).

        Raises:
            AssertionError: If either count differs from the only tested
                configuration.
        """
        # Different values are not tested.
        assert num_cpu_blocks == 0, (
            f"Neuron backend does not support swapping; "
            f"expected num_cpu_blocks=0, got {num_cpu_blocks}")
        assert num_gpu_blocks == self.scheduler_config.max_num_seqs, (
            f"expected num_gpu_blocks == max_num_seqs "
            f"({self.scheduler_config.max_num_seqs}), got {num_gpu_blocks}")

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    @property
    def do_metadata_broadcast(self) -> bool:
        """Always False: metadata broadcasting is disabled for this worker.

        NOTE(review): presumably because it is the only (driver) worker;
        confirm against LocalOrDistributedWorkerBase.
        """
        return False

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """Always None: this worker exposes no vLLM-managed KV cache tensors.

        NOTE(review): presumably the Neuron runtime owns the cache; confirm.
        """
        return None

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Build the WorkerInput for a request.

        Only ``num_seq_groups`` is set, to the number of sequence groups in
        the request.
        """
        return WorkerInput(num_seq_groups=len(
            execute_model_req.seq_group_metadata_list), )

    def execute_worker(self, worker_input: WorkerInput) -> None:
        """No-op.

        Swapping is not supported (num_cpu_blocks is always 0), so there is
        no per-step cache work to perform here.
        """
        pass

    def get_cache_block_size_bytes(self) -> int:
        """Determine the size in bytes of a cache block.

        This is required for speculative decoding; it is not yet implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def init_distributed_environment(self) -> None:
        """Neuron uses transformers-neuronx for tensor parallelism.

        vLLM still needs the environment inited when TP/PP > 1. It is set up
        as a single-process "gloo" group with tensor- and pipeline-parallel
        sizes of 1.
        """
        init_distributed_environment(
            world_size=1,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            1,
            1,
        )