# SPDX-License-Identifier: Apache-2.0

3
import os
4
from typing import TYPE_CHECKING, Optional
5

6
7
import torch

8
from vllm import envs
9
10
from vllm.logger import init_logger

11
from .interface import Platform, PlatformEnum, _Backend
# ``VllmConfig`` is imported only for static type checkers. At runtime the
# name is bound to ``None`` so that annotations referring to it still
# evaluate when the methods below are defined. (Presumably the runtime import
# is avoided to prevent an import cycle with ``vllm.config`` — confirm.)
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

logger = init_logger(__name__)


class HpuPlatform(Platform):
    """Platform description for HPU (Gaudi) devices.

    Provides the HPU device identifiers, the import paths of the attention
    backend, worker and Punica wrapper classes, and the HPU-specific
    validation/defaults applied to a ``VllmConfig`` in
    :meth:`check_and_update_config`.
    """

    _enum = PlatformEnum.HPU
    device_name: str = "hpu"
    device_type: str = "hpu"
    dispatch_key: str = "HPU"
    ray_device_key: str = "HPU"
    # Environment variable used by this platform for device control
    # (presumably device visibility, analogous to CUDA_VISIBLE_DEVICES —
    # confirm against the Platform base class).
    device_control_env_var: str = "HABANA_VISIBLE_MODULES"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the dotted import path of the attention backend class.

        All arguments are ignored: this platform always selects the HPU
        attention backend and logs that choice at INFO level.
        """
        logger.info("Using HPUAttention backend.")
        return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Report async output processing as supported.

        Always returns ``True`` on HPU, independent of ``enforce_eager``.
        """
        return True

    @staticmethod
    def inference_mode():
        """Return the context manager used to wrap inference on HPU.

        Uses ``torch.no_grad()`` rather than ``torch.inference_mode()``; the
        reason is not visible here (NOTE(review): presumably an HPU
        compatibility constraint — confirm).
        """
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` for HPU and fill in HPU defaults in place.

        - Rejects multi-step scheduling and speculative decoding.
        - Resolves an ``"auto"`` worker class to the HPU worker.
        - Defaults an unset cache block size to 128.
        - With the ``mp`` executor, switches the worker multiprocessing
          method from ``fork`` to ``spawn`` unless ``fork`` was set
          explicitly via the environment.

        Raises:
            NotImplementedError: if multi-step execution or speculative
                decoding is requested.
        """

        scheduler_config = vllm_config.scheduler_config
        if scheduler_config.is_multi_step:
            raise NotImplementedError(
                "Multi-step execution is not implemented for HPU")

        if vllm_config.speculative_config is not None:
            raise NotImplementedError(
                "Speculative decoding is not implemented for HPU")

        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"

        # NOTE(kzawora): default block size for Gaudi should be 128
        # smaller sizes still work, but very inefficiently
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128

        if (parallel_config.distributed_executor_backend == 'mp'
                and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
            # ``envs`` reports 'fork'. If the variable is also present in
            # os.environ, the user chose 'fork' explicitly and we honor it.
            # Otherwise the value presumably came from the envs default, and
            # we switch to 'spawn' by rewriting the environment variable.
            if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD",
                              None) is not None:
                logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork "
                               "might cause application hangs on exit. Using "
                               "VLLM_WORKER_MULTIPROC_METHOD=fork anyway, "
                               "as it was explicitly requested.")
            else:
                logger.warning(
                    "On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork "
                    "might cause application hangs on exit. Setting "
                    "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
                    "To override that behavior, please set "
                    "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    @classmethod
    def is_pin_memory_available(cls):
        """Report that pinned host memory is unavailable on HPU.

        Logs a warning on every call and returns ``False``.
        """
        logger.warning("Pin memory is not supported on HPU.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted import path of the HPU Punica (LoRA) wrapper."""
        return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"