"tests/kernels/attention/test_attention_selector.py" did not exist on "57f09a419c04ecec4718ea9d5be1e6f4a8cc336e"
neuron.py 5.49 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
import enum
import os
from functools import lru_cache
5
from typing import TYPE_CHECKING, Optional
6

7
from vllm import envs
8
from vllm.logger import init_logger
9
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
10

11
12
from .interface import Platform, PlatformEnum

13
14
15
16
17
# VllmConfig is imported only while a static type checker is running. At
# runtime the name is bound to None, so annotations that mention it (see
# check_and_update_config) still evaluate without importing vllm.config.
# NOTE(review): presumably this avoids an import cycle or import cost --
# confirm before replacing it with an unconditional import.
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

18
19
# Module-level logger, created through vLLM's init_logger with this module's
# name; used below for the MLA notice and the pin-memory warning.
logger = init_logger(__name__)

20

21
22
23
24
25
class NeuronFramework(enum.Enum):
    """Model frameworks that the Neuron platform can select between.

    Each member's value is the string that is compared against the
    ``VLLM_NEURON_FRAMEWORK`` environment variable in
    ``NeuronPlatform.get_neuron_framework_to_use``.
    """
    # Selected when VLLM_NEURON_FRAMEWORK == "transformers-neuronx".
    TRANSFORMERS_NEURONX = "transformers-neuronx"
    # Selected when VLLM_NEURON_FRAMEWORK == "neuronx-distributed-inference".
    NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference"


26
27
class NeuronPlatform(Platform):
    """Platform description for the ``neuron`` device type.

    Provides the static platform attributes, the config adjustments applied
    in ``check_and_update_config``, and helpers that decide which Neuron
    model framework (see ``NeuronFramework``) should be used.
    """

    _enum = PlatformEnum.NEURON
    device_name: str = "neuron"
    device_type: str = "neuron"
    # NOTE(review): presumably the Ray resource key for Neuron cores --
    # confirm against the executor that reads it.
    ray_device_key: str = "neuron_cores"
    supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
    # NOTE(review): presumably the env var that restricts which cores are
    # visible to the runtime -- confirm against the code that reads it.
    device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"neuron"``; ``device_id`` is accepted but not used."""
        return "neuron"

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return ``False`` unconditionally; ``enforce_eager`` is ignored."""
        return False

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` and adjust it in place for this platform.

        Changes applied:

        * ``parallel_config.worker_cls`` set to the Neuron worker when it is
          ``"auto"``.
        * ``parallel_config.distributed_executor_backend`` forced to
          ``"uni"`` when ``world_size > 1``.
        * ``cache_config.block_size`` set to ``model_config.max_model_len``
          when both configs are present.
        * When the model uses MLA, chunked prefill is turned off and
          ``max_num_batched_tokens`` is raised to at least
          ``DEFAULT_MAX_NUM_BATCHED_TOKENS``.

        Raises:
            AssertionError: if ``vllm_config.lora_config`` is not ``None``.
        """
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = (
                "vllm.worker.neuron_worker.NeuronWorker")

        if parallel_config.world_size > 1:
            parallel_config.distributed_executor_backend = "uni"

        # Raise explicitly instead of using a bare ``assert`` so the check
        # is not stripped under ``python -O``; the exception type and
        # message are unchanged.
        if vllm_config.lora_config is not None:
            raise AssertionError("LoRA is not supported for Neuron backend.")

        if vllm_config.cache_config and vllm_config.model_config:
            # neuron needs block_size = max_model_len
            vllm_config.cache_config.block_size = \
                vllm_config.model_config.max_model_len  # type: ignore

        if vllm_config.model_config and vllm_config.model_config.use_mla:
            # NOTE(review): the message also mentions prefix caching, but no
            # prefix-caching setting is modified here -- confirm intent.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            scheduler_config = vllm_config.scheduler_config
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Log a warning and return ``False`` (no pinned memory here)."""
        logger.warning("Pin memory is not supported on Neuron.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class.

        The Neuron communicator is used when ``envs.VLLM_USE_V1`` is truthy;
        otherwise the base ``Platform`` default is returned.
        """
        if envs.VLLM_USE_V1:
            return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator"  # noqa
        else:
            return Platform.get_device_communicator_cls()

    @classmethod
    def use_all_gather(cls) -> bool:
        """Return ``True`` unconditionally."""
        return True

    @classmethod
    @lru_cache
    def is_neuronx_distributed_inference(cls) -> bool:
        """Return whether ``neuronx_distributed_inference`` can be imported.

        The result is cached, so the import is attempted once per class.
        """
        try:
            import neuronx_distributed_inference  # noqa: F401
        except ImportError:
            return False
        return True

    @classmethod
    @lru_cache
    def is_transformers_neuronx(cls) -> bool:
        """Return whether ``transformers_neuronx`` can be imported.

        The result is cached, so the import is attempted once per class.
        """
        try:
            import transformers_neuronx  # noqa: F401
        except ImportError:
            return False
        return True

    def get_neuron_framework_to_use(self) -> Optional[NeuronFramework]:
        """Return the specified framework if corresponding installations are
        available.

        If no framework is specified, use neuronx-distributed-inference by
        default.
        If that's unavailable, check and switch to transformers-neuronx.

        The framework is specified through the ``VLLM_NEURON_FRAMEWORK``
        environment variable, compared against the ``NeuronFramework``
        values.

        Returns:
            The selected ``NeuronFramework`` member, or ``None`` when the
            specified framework is unrecognized or not installed, or when
            nothing is specified and neither framework is installed.

        Raises:
            AssertionError: if ``self.is_neuron()`` is false.
        """
        if not self.is_neuron():
            raise AssertionError(
                f"Neuron Framework unavailable for platform: {self}")

        tnx_installed = self.is_transformers_neuronx()
        nxd_installed = self.is_neuronx_distributed_inference()

        specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK")
        tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value
        nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value
        if specified_framework == tnx_framework and tnx_installed:
            # Return the enum member. The previous ``self.TRANSFORMERS_NEURONX``
            # is not an attribute of this class and raised AttributeError.
            return NeuronFramework.TRANSFORMERS_NEURONX

        if ((specified_framework == nxd_framework and nxd_installed)
                or (specified_framework is None and nxd_installed)):
            return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE

        if specified_framework is None and tnx_installed:
            return NeuronFramework.TRANSFORMERS_NEURONX

        return None

    def use_neuronx_distributed(self) -> bool:
        """
        Return True if the framework determined in get_neuron_framework_to_use()
        is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This
        is used to select the Neuron model framework and framework-specific
        configuration to apply during model compilation.
        """
        nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
        return self.get_neuron_framework_to_use() == nxd_framework

    def use_transformers_neuronx(self) -> bool:
        """
        Return True if the framework determined in get_neuron_framework_to_use()
        is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used
        to select the Neuron model framework and framework-specific
        configuration to apply during model compilation.
        """
        tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX
        return self.get_neuron_framework_to_use() == tnx_framework