Unverified Commit ccd3e55e authored by Hank_'s avatar Hank_ Committed by GitHub
Browse files

[Bugfix][plugin] fla crash on plugin (#27322)

parent 01baefe6
...@@ -17,6 +17,7 @@ from typing import Any, Literal ...@@ -17,6 +17,7 @@ from typing import Any, Literal
import torch import torch
from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -137,8 +138,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: ...@@ -137,8 +138,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor. # Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != "hip" else "cuda" device = "cuda" if current_platform.is_cuda_alike() else get_available_device()
device_torch_lib = getattr(torch, device) device_torch_lib = getattr(torch, device, None)
device_platform = _check_platform() device_platform = _check_platform()
is_amd = device_platform == "amd" is_amd = device_platform == "amd"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment