# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A Neuron worker class."""
import os
from typing import List, Optional, Set, Tuple

import torch.distributed

from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)

logger = init_logger(__name__)


class NeuronWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes the model on a group of neuron cores.

    The model runner is chosen at construction time from the Neuron
    framework reported by ``current_platform``: transformers-neuronx or
    neuronx-distributed-inference. The worker exposes no vLLM-managed KV
    cache (``kv_cache`` and ``cache_engines`` are both ``None``) and reports
    a fixed budget of ``max_num_seqs + 1`` GPU blocks and zero CPU (swap)
    blocks.
    """

    model_runner: NeuronModelRunner

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False) -> None:
        """Create the worker and select the framework-specific model runner.

        Args:
            vllm_config: Engine configuration; also forwarded to the model
                runner.
            local_rank: Local rank passed to ``init_distributed_environment``.
            rank: Rank passed to ``init_distributed_environment``.
            distributed_init_method: Init method passed to
                ``init_distributed_environment``.
            is_driver_worker: Whether this worker is the driver worker.

        Raises:
            NotImplementedError: If the platform reports a Neuron framework
                other than transformers-neuronx or
                neuronx-distributed-inference.
        """
        WorkerBase.__init__(self, vllm_config=vllm_config)
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        self.lora_config = vllm_config.lora_config

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        neuron_framework = current_platform.get_neuron_framework_to_use()
        if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX:
            self.model_runner = self.get_tnx_model_runner(vllm_config)
        elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE:
            self.model_runner = self.get_neuronx_distributed_model_runner(
                vllm_config)
        else:
            raise NotImplementedError(
                "Specified framework"
                f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}"
                " is either not installed or not supported."
                " Supported frameworks: "
                "[transformers-neuronx, neuronx-distributed-inference]")

    def get_tnx_model_runner(self,
                             vllm_config: VllmConfig) -> NeuronModelRunner:
        """Build the model runner for the transformers-neuronx framework.

        A multi-step runner is returned when speculative decoding is
        configured; otherwise a plain ``NeuronModelRunner``.

        Raises:
            AssertionError: If a LoRA config is present (LoRA is not
                supported on transformers-neuronx).
        """
        assert (self.lora_config
                is None), ("LoRA is not supported for TransformersNeuronX "
                           "framework.")
        # Function-scope import: the multi-step module is only loaded when
        # this framework is actually selected.
        from vllm.worker.multi_step_neuron_model_runner import (
            MultiStepNeuronModelRunner)
        if self.speculative_config is not None:
            return MultiStepNeuronModelRunner(vllm_config=vllm_config)
        return NeuronModelRunner(vllm_config=vllm_config)

    def get_neuronx_distributed_model_runner(
            self, vllm_config: VllmConfig) -> NeuronModelRunner:
        """Build the model runner for neuronx-distributed-inference.

        A multi-step runner is returned when speculative decoding is
        configured; otherwise a plain ``NeuronxDistributedModelRunner``.

        Raises:
            AssertionError: If speculative decoding and LoRA are both
                configured.
        """
        # Function-scope imports: these modules are only loaded when this
        # framework is actually selected.
        from vllm.worker.multi_step_neuronx_distributed_model_runner import (
            MultiStepNeuronxDistributedModelRunner)
        from vllm.worker.neuronx_distributed_model_runner import (
            NeuronxDistributedModelRunner)
        if self.speculative_config is not None:
            assert (self.lora_config
                    is None), "LoRA is not supported for Speculative Decoding"
            return MultiStepNeuronxDistributedModelRunner(
                vllm_config=vllm_config)
        return NeuronxDistributedModelRunner(vllm_config=vllm_config)

    def init_device(self) -> None:
        """Initialize the distributed environment and set the random seed."""
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self) -> None:
        """Load the model through the selected model runner."""
        self.model_runner.load_model()

    def _num_neuron_kv_blocks(self) -> int:
        """Return the fixed number of KV blocks: ``max_num_seqs + 1``.

        Shared by ``determine_num_available_blocks`` and
        ``initialize_cache`` so the two always agree.
        NOTE(review): the purpose of the extra (+1) block is not visible in
        this file — presumably a reserved block required by the block
        manager; confirm.
        """
        return self.scheduler_config.max_num_seqs + 1

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        Swapping is not yet supported, so always return num_cpu_blocks=0.

        num_gpu_blocks is configured as ``max_num_seqs + 1``: one block per
        sequence that can be processed in a single batch (equivalent to
        scheduling without PagedAttention), plus one extra block.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``.
        """
        num_gpu_blocks = self._num_neuron_kv_blocks()

        # Swap not yet supported with Neuron backend.
        num_cpu_blocks = 0

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by recording block counts in cache_config.

        Only the values returned by ``determine_num_available_blocks`` are
        accepted; different values are not tested.

        Raises:
            AssertionError: If the counts differ from those values.
        """
        # Different values are not tested.
        assert num_cpu_blocks == 0
        assert num_gpu_blocks == self._num_neuron_kv_blocks()

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    @property
    def do_metadata_broadcast(self) -> bool:
        """Always False.

        Presumably tells the base class not to broadcast model-input
        metadata between workers — confirm against
        LocalOrDistributedWorkerBase.
        """
        return False

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """Always None: this worker holds no vLLM-managed KV-cache tensors."""
        return None

    @property
    def cache_engines(self) -> Optional[List[CacheEngine]]:
        """Always None: this worker creates no vLLM CacheEngine instances."""
        return None

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Build a WorkerInput carrying only the number of sequence groups."""
        return WorkerInput(
            num_seq_groups=len(execute_model_req.seq_group_metadata_list))

    def execute_worker(self, worker_input: WorkerInput) -> None:
        """No-op: this worker has no cache engines to operate on."""
        pass

    def get_cache_block_size_bytes(self) -> int:
        """Determine the size in bytes of a cache block.

        This is required for speculative decoding; it is not yet implemented.
        """
        raise NotImplementedError

    def init_distributed_environment(self):
        """Initialize vLLM's distributed state as a single-rank group.

        Tensor parallelism is performed by the Neuron framework itself
        (transformers-neuronx; presumably likewise for
        neuronx-distributed-inference), so vLLM's own environment uses
        world_size=1 with tensor- and pipeline-parallel sizes of 1.
        vLLM still needs the environment initialized when TP/PP > 1.
        """
        init_distributed_environment(
            world_size=1,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend=current_platform.dist_backend,
        )

        ensure_model_parallel_initialized(
            1,
            1,
        )

    def _raise_if_lora_unsupported(self) -> None:
        """Raise if the active Neuron framework cannot serve LoRA adapters.

        Raises:
            NotImplementedError: When transformers-neuronx is in use.
        """
        if current_platform.use_transformers_neuronx():
            raise NotImplementedError(
                f"{type(self)} does not support LoRA with Neuron Framework "
                "Transformers NeuronX")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model runner."""
        self._raise_if_lora_unsupported()
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter ``lora_id`` from the model runner."""
        self._raise_if_lora_unsupported()
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter ``lora_id`` in the model runner."""
        self._raise_if_lora_unsupported()
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the ids of LoRA adapters known to the model runner."""
        self._raise_if_lora_unsupported()
        return self.model_runner.list_loras()