# SPDX-License-Identifier: Apache-2.0

from functools import lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional

import torch

import vllm.envs as envs
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

logger = init_logger(__name__)

try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
    import vllm._rocm_C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM":
    _ROCM_SWA_REASON,
    "MistralForCausalLM":
    _ROCM_SWA_REASON,
    "MixtralForCausalLM":
    _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}


class RocmPlatform(Platform):
    """Platform hooks for AMD GPUs running ROCm.

    ROCm reuses PyTorch's ``cuda`` device type, the ``CUDA`` dispatch key and
    the ``CUDA_VISIBLE_DEVICES`` variable, so device queries in this class go
    through ``torch.cuda``.
    """

    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
        "fbgemm_fp8", "gguf", "quark", "ptpc_fp8"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the dotted import path of the attention backend to use.

        Args:
            selected_backend: Backend requested by the configuration (may be
                ``None``). A ``FLASH_ATTN`` request is mapped to
                ``ROCM_FLASH``.
            head_size, dtype, kv_cache_dtype, block_size: Accepted for
                interface compatibility; not consulted on ROCm.
            use_v1: Whether the caller is building the V1 engine.
            use_mla: Whether MLA attention is requested; checked first and
                overrides every other option.

        Returns:
            Fully-qualified class path of the chosen backend. Outside the
            MLA and V1 cases this is always ROCmFlashAttentionBackend; an
            unsupported selection only produces a log message.
        """
        if use_mla:
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"

        if selected_backend == _Backend.FLASH_ATTN:
            selected_backend = _Backend.ROCM_FLASH

        # Use the caller-supplied flag rather than re-reading the global
        # environment, so the answer matches the engine being constructed.
        if use_v1:
            logger.info("Using ROCm Attention backend on V1 engine.")
            return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"

        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return ``(major, minor)`` from ``torch.cuda`` as a DeviceCapability.

        Results are memoized per ``(cls, device_id)``.
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the fixed string ``"AMD"`` for any ``device_id``."""
        # NOTE: When using V1 this function is called when overriding the
        # engine args. Calling torch.cuda.get_device_name(device_id) here
        # will result in the ROCm context being initialized before other
        # processes can be created.
        return "AMD"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total_memory`` property of device ``device_id``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return False (with a warning) when eager mode is enforced."""
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in ROCm defaults on ``vllm_config`` in place.

        - Sets ``cache_config.block_size`` to 16 when it is unset.
        - Resolves ``parallel_config.worker_cls == "auto"`` to a concrete
          worker class path based on multi-step scheduling, speculative
          decoding, and ``envs.VLLM_USE_V1``.

        Raises:
            NotImplementedError: Multi-step scheduling or speculative
                decoding is requested together with the V1 engine.
        """
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls != "auto":
            # An explicit worker class was configured; leave it alone.
            return

        if scheduler_config.is_multi_step:
            if envs.VLLM_USE_V1:
                raise NotImplementedError(
                    "Multi-step scheduling is not supported (and not "
                    "needed) on VLLM V1. Please launch without "
                    "--num-scheduler-steps.")
            parallel_config.worker_cls = (
                "vllm.worker.multi_step_worker.MultiStepWorker")
        elif vllm_config.speculative_config:
            if envs.VLLM_USE_V1:
                raise NotImplementedError(
                    "Speculative decoding is not yet supported on VLLM V1.")
            parallel_config.worker_cls = (
                "vllm.spec_decode.spec_decode_worker.create_spec_worker")
            parallel_config.sd_worker_cls = "vllm.worker.worker.Worker"
        elif envs.VLLM_USE_V1:
            parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
        else:
            parallel_config.worker_cls = "vllm.worker.worker.Worker"

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported ones.

        Raises:
            ValueError: ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base-class check, then force Triton AWQ for ``"awq"``.

        Side effect: sets ``envs.VLLM_USE_TRITON_AWQ = True`` when AWQ is
        requested and the flag is not already set.
        """
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
            envs.VLLM_USE_TRITON_AWQ = True

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the LoRA Punica wrapper (GPU variant)."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the number of bytes currently in use on ``device``.

        Computed as ``total - free`` from ``torch.cuda.mem_get_info``. Per
        PyTorch's docs that call reports device-wide figures, not only this
        process's allocations. As a side effect, the device's peak-memory
        statistics are reset first.
        """
        # NOTE(review): the reason for resetting peak stats here is not
        # visible in this file; preserved from the original.
        torch.cuda.reset_peak_memory_stats(device)
        # Query once so that free and total come from the same snapshot;
        # two separate calls could straddle an allocation.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes