# SPDX-License-Identifier: Apache-2.0

from functools import lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional

import torch

import vllm.envs as envs
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

# VllmConfig is only needed for annotations; importing vllm.config at runtime
# is avoided here, so a None placeholder keeps the annotation name bound.
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

logger = init_logger(__name__)

# Best-effort import of the compiled extension; a failure is logged rather
# than raised so this module can still be imported.
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Model support tables consulted by RocmPlatform.verify_model_arch():
# an architecture in the unsupported list raises ValueError, one in the
# partially-supported map only logs the associated reason as a warning.

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM":
    _ROCM_SWA_REASON,
    "MistralForCausalLM":
    _ROCM_SWA_REASON,
    "MixtralForCausalLM":
    _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}

class RocmPlatform(Platform):
    """Platform hooks for AMD GPUs under ROCm.

    Devices are driven through the ``torch.cuda`` API here (``device_type``
    is ``"cuda"`` and the helpers below call ``torch.cuda`` functions).
    """

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    # Quantization method names accepted on ROCm; presumably enforced by the
    # base-class verify_quantization() — confirm in .interface.
    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
        "fbgemm_fp8", "gguf", "quark", "ptpc_fp8"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the dotted import path of the attention backend to use.

        MLA models always get the Triton MLA backend. A ``FLASH_ATTN``
        request is mapped to ``ROCM_FLASH``. When ``envs.VLLM_USE_V1`` is
        set, the V1 ROCm attention backend is returned. Otherwise the
        ROCmFlashAttention backend is returned regardless of the request;
        unsupported requests are only logged.

        ``head_size``, ``dtype``, ``kv_cache_dtype``, ``block_size`` and
        ``use_v1`` are accepted for interface compatibility but are not
        consulted here; the V1 decision is read from ``envs.VLLM_USE_V1``.
        """
        if use_mla:
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"

        # A FLASH_ATTN request is served by the ROCm flash-attention backend.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)

        if envs.VLLM_USE_V1:
            logger.info("Using ROCm Attention backend on V1 engine.")
            return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"

        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                # NOTE(review): relies on capability < 9.0 identifying
                # non-Instinct parts — confirm has_device_capability encoding.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        elif selected_backend is not None:
            # Only report an explicitly requested, unsupported backend. With
            # no request (None) the message used to read "None is not
            # supported in AMD GPUs.", which was misleading.
            logger.info("%s is not supported in AMD GPUs.", selected_backend)

        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the (major, minor) capability torch reports for a device.

        Results are memoized by ``lru_cache`` on ``(cls, device_id)``.
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a fixed device name ("AMD") without touching the device."""
        # NOTE: When using V1 this function is called when overriding the
        # engine args. Calling torch.cuda.get_device_name(device_id) here
        # will result in the ROCm context being initialized before other
        # processes can be created.
        return "AMD"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``torch.cuda.get_device_properties(device_id).total_memory``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is supported unless eager mode is forced.

        Logs a warning and returns False when ``enforce_eager`` is truthy.
        """
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in ROCm defaults on ``vllm_config`` in place.

        * Sets ``cache_config.block_size`` to 16 when it is unset (None).
        * When ``parallel_config.worker_cls`` is ``"auto"``, resolves it to
          a concrete worker class based on multi-step scheduling,
          speculative decoding and ``envs.VLLM_USE_V1``.

        Raises:
            NotImplementedError: with ``worker_cls == "auto"`` and the V1
                engine enabled, if multi-step scheduling or speculative
                decoding is requested.
        """
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on VLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not yet supported on VLLM V1."
                    )
                else:
                    parallel_config.worker_cls = \
                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                    parallel_config.sd_worker_cls = \
                        "vllm.worker.worker.Worker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                        "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures and warn about partial support.

        Raises:
            ValueError: if ``model_arch`` is in the ROCm-unsupported list.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Validate ``quant`` and enable Triton AWQ kernels for AWQ.

        The supported-method check is delegated to the base class. For
        ``"awq"``, ``envs.VLLM_USE_TRITON_AWQ`` is switched on (with a
        warning) if it is not already set. Other methods leave it untouched.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
            # Previously this assignment sat outside the branch and flipped
            # the global for every quantization method.
            envs.VLLM_USE_TRITON_AWQ = True

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the LoRA Punica wrapper (GPU variant)."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return device memory in use, computed as total minus free.

        Side effect: resets torch's peak-memory statistics for ``device``.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # Query once so free and total come from the same snapshot.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes