"vscode:/vscode.git/clone" did not exist on "b10e51989551cd80dd74079429ccf91f0807bd92"
rocm.py 9.26 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import os
from functools import lru_cache, wraps
5
from typing import TYPE_CHECKING, Dict, List, Optional
6
7
8

import torch

9
import vllm.envs as envs
10
11
from vllm.logger import init_logger

12
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
13

14
15
16
17
18
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

19
20
logger = init_logger(__name__)

21
22
23
24
25
26
try:
    from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
                        amdsmi_init, amdsmi_shut_down)
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

27
28
29
30
31
32
33
34
35
36
37
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Architectures that are rejected outright on ROCm (none at the moment).
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Architectures that work on ROCm with a caveat.
# Maps architecture name -> human-readable explanation of the caveat.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")

# The sliding-window architectures all share one explanation.
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = dict.fromkeys(
    (
        "Qwen2ForCausalLM",
        "MistralForCausalLM",
        "MixtralForCausalLM",
    ),
    _ROCM_SWA_REASON,
)
# Architectures with their own, model-specific explanation.
_ROCM_PARTIALLY_SUPPORTED_MODELS.update({
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"),
})

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`.
# When `HIP_VISIBLE_DEVICES` is set it is mirrored into
# `CUDA_VISIBLE_DEVICES`; if both are set (and non-empty) they must agree.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Raise explicitly instead of using `assert`, so the check is not
        # stripped under `python -O` and the failure explains itself.
        if val != cuda_val:
            raise AssertionError(
                "HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are both set "
                f"but differ: {val!r} != {cuda_val!r}")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`:
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate `fn` so that each call runs inside an AMDSMI session.

    `amdsmi_init()` is called before `fn` and `amdsmi_shut_down()` is always
    called afterwards, whether `fn` returns or raises.
    """

    @wraps(fn)
    def _amdsmi_session(*args, **kwargs):
        amdsmi_init()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Runs on both the success and the error path.
            amdsmi_shut_down()
        return result

    return _amdsmi_session


def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device id.

    When `CUDA_VISIBLE_DEVICES` is set, `device_id` indexes into its
    comma-separated list and the entry found there is returned as an int.
    When it is not set, `device_id` is returned unchanged.

    Raises:
        ValueError: if `CUDA_VISIBLE_DEVICES` is set to an empty string, or
            the selected entry is not an integer.
        IndexError: if `device_id` is beyond the end of the list.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    device_ids = visible_devices.split(",")
    if device_ids == [""]:
        # Fail with a clear message instead of `int("")`'s opaque
        # "invalid literal" error.
        raise ValueError(
            "CUDA_VISIBLE_DEVICES is set to an empty string, so no device "
            f"is visible; cannot resolve device_id {device_id}.")

    try:
        physical_device_id = device_ids[device_id]
    except IndexError:
        raise IndexError(
            f"device_id {device_id} is out of range for "
            f"CUDA_VISIBLE_DEVICES={visible_devices!r} "
            f"({len(device_ids)} visible device(s)).") from None
    return int(physical_device_id)

98
99
100

class RocmPlatform(Platform):
    """Platform hooks for AMD ROCm GPUs.

    The methods below drive the device through the `torch.cuda` API, and the
    device type, dispatch key and device-control environment variable are the
    CUDA ones.
    """
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
        "fbgemm_fp8", "gguf", "quark", "ptpc_fp8"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the dotted import path of the attention backend to use.

        MLA requests get the Triton MLA backend, the V1 engine gets the
        Triton attention backend, and everything else gets the
        ROCmFlashAttention backend. A `selected_backend` that is not
        supported is only logged; it does not change the returned path.
        """
        if use_mla:
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"
        # A request for FLASH_ATTN is treated as a request for ROCM_FLASH.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        # NOTE(review): V1 is detected through `envs.VLLM_USE_V1`; the
        # `use_v1` argument is not consulted here.
        if envs.VLLM_USE_V1:
            logger.info("Using Triton Attention backend on V1 engine.")
            return ("vllm.v1.attention.backends."
                    "triton_attn.TritonAttentionBackend")
        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the (major, minor) capability torch reports for a device.

        Results are cached per `(cls, device_id)`.
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    # The cache sits outside the AMDSMI wrapper so that cache hits return
    # without an `amdsmi_init()` / `amdsmi_shut_down()` round trip; only a
    # cache miss opens an AMDSMI session.
    @lru_cache(maxsize=8)
    @with_amdsmi_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the AMDSMI `market_name` of the given logical device.

        `device_id` is first mapped to a physical device id, which is then
        used to index the AMDSMI processor handles.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = amdsmi_get_processor_handles()[physical_device_id]
        return amdsmi_get_gpu_asic_info(handle)["market_name"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return `total_memory` from torch's properties for the device."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return whether async output processing can be used.

        It is disabled (with a warning) when `enforce_eager` is truthy and
        enabled otherwise.
        """
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in ROCm defaults on `vllm_config`, in place.

        Sets the cache block size to 16 when it is unset and, when the
        worker class is "auto", picks one based on multi-step scheduling,
        speculative decoding and whether the V1 engine is enabled.

        Raises:
            NotImplementedError: if multi-step scheduling or speculative
                decoding is requested together with the V1 engine.
        """
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not yet supported on vLLM V1."
                    )
                else:
                    parallel_config.worker_cls = \
                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                    parallel_config.sd_worker_cls = \
                        "vllm.worker.worker.Worker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                            "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures and warn about partial support.

        Raises:
            ValueError: if `model_arch` is listed as unsupported on ROCm.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class check, then turn on Triton AWQ."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
        # NOTE(review): this assignment sits outside the `if`, so the flag is
        # set for every `quant` value, not only "awq" - confirm intended.
        envs.VLLM_USE_TRITON_AWQ = True

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted import path of the Punica (LoRA) wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the memory currently in use on `device`, in bytes.

        Computed as total minus free memory as reported by
        `torch.cuda.mem_get_info`; the peak-memory statistics are reset
        first.
        """
        torch.cuda.reset_peak_memory_stats(device)
        # Query once so `free` and `total` come from the same snapshot.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted import path of the device communicator."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return True if device 0's GCN arch name is in the FP8 list."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True if device 0's GCN arch name contains 'gfx94'."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the FP8 dtype: e4m3fnuz if `is_fp8_fnuz()`, else e4m3fn."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn