Unverified Commit ccd3e55e authored by Hank_'s avatar Hank_ Committed by GitHub
Browse files

[Bugfix][plugin] fla crash on plugin (#27322)

parent 01baefe6
......@@ -17,6 +17,7 @@ from typing import Any, Literal
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import triton
logger = logging.getLogger(__name__)
......@@ -137,8 +138,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device = "cuda" if current_platform.is_cuda_alike() else get_available_device()
device_torch_lib = getattr(torch, device, None)
device_platform = _check_platform()
is_amd = device_platform == "amd"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment