rocm.py 13.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
from functools import cache, lru_cache, wraps
5
from typing import TYPE_CHECKING, Dict, List, Optional
6
7
8

import torch

9
import vllm.envs as envs
10
11
from vllm.logger import init_logger

12
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
13

14
if TYPE_CHECKING:
15
    from vllm.config import ModelConfig, VllmConfig
16
else:
17
    ModelConfig = None
18
    VllmConfig = None
19

20
21
logger = init_logger(__name__)

22
try:
23
24
25
    from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down, amdsmi_topo_get_link_type)
26
27
28
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

29
30
31
32
33
34
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
# (the vllm._rocm_C import below is currently disabled)
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)
39

40

41
42
43
44
45
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
zhuwenwen's avatar
zhuwenwen committed
46
47
48
49
# _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
#                     "Triton flash attention. For half-precision SWA support, "
#                     "please use CK flash attention by setting "
#                     "`VLLM_USE_TRITON_FLASH_ATTN=0`")
50
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
zhuwenwen's avatar
zhuwenwen committed
51
52
53
54
55
56
    # "Qwen2ForCausalLM":
    # _ROCM_SWA_REASON,
    # "MistralForCausalLM":
    # _ROCM_SWA_REASON,
    # "MixtralForCausalLM":
    # _ROCM_SWA_REASON,
57
58
59
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
zhuwenwen's avatar
zhuwenwen committed
60
61
62
63
    # "Phi3VForCausalLM":
    # ("ROCm Triton flash attention may run into compilation errors due to "
    #  "excessive use of shared memory. If this happens, disable Triton FA "
    #  "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
64
65
}

66
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES` (check currently disabled).
67
68
69
70
71
72
# if "HIP_VISIBLE_DEVICES" in os.environ:
#     val = os.environ["HIP_VISIBLE_DEVICES"]
#     if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
#         assert val == cuda_val
#     else:
#         os.environ["CUDA_VISIBLE_DEVICES"] = val
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it does not initialize CUDA.


def with_amdsmi_context(fn):
    """Wrap ``fn`` so each call runs inside an AMDSMI session.

    ``amdsmi_init()`` is called before ``fn`` and ``amdsmi_shut_down()`` is
    always called afterwards, even if ``fn`` raises. The return value of
    ``fn`` is passed through unchanged.
    """

    @wraps(fn)
    def _run_inside_amdsmi(*call_args, **call_kwargs):
        amdsmi_init()
        try:
            outcome = fn(*call_args, **call_kwargs)
        finally:
            amdsmi_shut_down()
        return outcome

    return _run_inside_amdsmi


def device_id_to_physical_device_id(device_id: int) -> int:
    """Map a logical device index to a physical GPU index.

    If ``CUDA_VISIBLE_DEVICES`` is set (ROCm shares this variable with CUDA),
    ``device_id`` is used as an index into its comma-separated list.
    Otherwise ``device_id`` is returned unchanged.

    Raises:
        RuntimeError: if ``CUDA_VISIBLE_DEVICES`` is set to an empty string,
            i.e. no GPU is visible.
        IndexError: if ``device_id`` is out of range for the visible list.
    """
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        return device_id
    device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    if device_ids == [""]:
        # An empty value hides every GPU. Without this check the failure
        # surfaces as an opaque ``int('')`` ValueError.
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES is set to empty string, which means GPU "
            "support is disabled. If you are using ray, please unset the "
            "environment variable `CUDA_VISIBLE_DEVICES` inside the "
            "worker/actor.")
    return int(device_ids[device_id])

101

102
103
104
105
106
107
108
109
110
111
112
@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                    block_size: int, gqa_ratio: int,
                                    max_seq_len: int,
                                    sliding_window: int) -> bool:
    """Return whether the ROCm custom paged-attention kernel can be used.

    Every one of these conditions must hold:

    - the device arch contains gfx90a or gfx942 (MI250/MI300) and is not
      Navi (gfx1*);
    - sliding window is disabled, i.e. ``0`` or ``(-1, -1)`` (despite the
      ``int`` annotation, the tuple form is accepted);
    - ``qtype`` is fp16 or bf16;
    - head size is 64 or 128;
    - block size is 16 or 32;
    - GQA ratio is in [1, 16];
    - max sequence length is at most 32768;
    - ``envs.VLLM_ROCM_CUSTOM_PAGED_ATTN`` is truthy.

    The result is memoized per argument tuple by ``functools.cache``. The
    device arch is therefore read only on the first call for a given
    argument combination.
    """
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    on_navi = "gfx1" in gpu_arch
    on_mi250_mi300 = any(arch in gpu_arch for arch in ("gfx90a", "gfx942"))

    # The ROCm custom paged-attention kernel is not supported on Navi (gfx1*).
    return (on_mi250_mi300 and not on_navi
            and (sliding_window == 0 or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and 1 <= gqa_ratio <= 16 and max_seq_len <= 32768
            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)


122
123
class RocmPlatform(Platform):
    """Platform implementation for AMD GPUs running ROCm.

    The platform reports ``device_type = "cuda"`` and the ``CUDA`` dispatch
    key, so device queries below go through ``torch.cuda``.
    """
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
        "fbgemm_fp8", "gguf", "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the dotted import path of the attention backend class.

        MLA models use FlashMLA when Triton MLA was not explicitly selected,
        ``block_size`` is 64 and ``is_flashmla_supported()`` reports support.
        Otherwise they use Triton MLA.

        Non-MLA models use the Triton attention backend when
        ``envs.VLLM_USE_V1`` is set, and ROCmFlashAttention otherwise.
        """
        if use_mla:
            wants_flashmla = (selected_backend != _Backend.TRITON_MLA
                              and block_size == 64)
            if wants_flashmla:
                from vllm.attention.backends.flashmla import (
                    is_flashmla_supported)
                flashmla_support = is_flashmla_supported()
                if flashmla_support[0]:
                    if use_v1:
                        logger.info_once(
                            "Using FlashMLA backend on V1 engine.")
                        return ("vllm.v1.attention.backends.mla."
                                "flashmla.FlashMLABackend")
                    logger.info("Using FlashMLA backend.")
                    return ("vllm.attention.backends."
                            "flashmla.FlashMLABackend")
                # Fall back to Triton MLA. Previously this case fell through
                # to the non-MLA backend selection below.
                logger.warning("FlashMLA backend is not supported due to %s",
                               flashmla_support[1])
            if use_v1:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"

        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if envs.VLLM_USE_V1:
            logger.info("Using Triton Attention backend on V1 engine.")
            return ("vllm.v1.attention.backends."
                    "triton_attn.TritonAttentionBackend")
        # The branches below only log. Every non-V1 path ends in
        # ROCmFlashAttention.
        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the ``(major, minor)`` capability from torch.cuda."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @staticmethod
    @with_amdsmi_context
    def is_fully_connected(physical_device_ids: List[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)

        Any AMDSMI error while querying a link is logged and treated as
        "not fully connected".
        """
        # Fetch the processor handle list once, not once per device id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Each unordered pair is checked once.
            for peer_handle in handles[i + 1:]:
                try:
                    link_type = amdsmi_topo_get_link_type(
                        handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.",
                                 exc_info=error)
                    return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda.get_device_name``.

        The AMDSMI ``market_name`` lookup is disabled in this build, so no
        AMDSMI session is opened. The previous version fetched a processor
        handle it never read. It also failed with ``NameError`` whenever the
        optional ``amdsmi`` import had failed.
        """
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the device's ``total_memory`` as reported by torch.cuda."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing requires graphs, i.e. not enforce-eager."""
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in ROCm defaults on ``vllm_config`` in place.

        Sets the KV-cache block size to 16 when it is unset. When
        ``worker_cls`` is ``"auto"``, picks a worker class based on
        multi-step scheduling, speculative decoding and ``envs.VLLM_USE_V1``.

        Raises:
            NotImplementedError: for multi-step scheduling or speculative
                decoding on the V1 engine.
        """
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not yet supported on vLLM V1."
                    )
                else:
                    parallel_config.worker_cls = \
                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                    parallel_config.sd_worker_cls = \
                        "vllm.worker.worker.Worker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                            "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported ones.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base check.

        For AWQ without ``VLLM_USE_TRITON_AWQ``, also log a warning and force
        the flag to ``False``.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, disabling VLLM_USE_TRITON_AWQ.")
            envs.VLLM_USE_TRITON_AWQ = False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the LoRA Punica wrapper (GPU variant)."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return bytes currently allocated by torch on ``device``.

        The peak counters are reset first, so ``max_memory_allocated``
        reflects the allocation at the moment of the call rather than a
        historical peak.
        """
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 is supported on gfx94x, gfx95x and gfx12xx (device 0 only)."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """Return True when device 0 is gfx94x, which uses the FNUZ FP8 format."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the torch FP8 dtype matching the device's FP8 format."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Report that the V1 engine may be used for any model."""
        # V1 support on AMD gpus is experimental
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """Enable custom allreduce only on gfx94x (MI300) devices."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ['gfx94']
        return any(gfx in gcn_arch for gfx in supported_archs)