Commit bc84829b authored by zhuwenwen's avatar zhuwenwen
Browse files

update platforms init

parent ad58e9b3
......@@ -203,8 +203,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
# if current_platform.get_device_capability()[0] != 9:
if torch.cuda.get_device_capability()[0] != 9:
if current_platform.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
......
import torch
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform
......@@ -32,13 +33,15 @@ except Exception:
is_rocm = False
try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
if torch.version.hip is not None:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
# import amdsmi
# amdsmi.amdsmi_init()
# try:
# if len(amdsmi.amdsmi_get_processor_handles()) > 0:
# is_rocm = True
# finally:
# amdsmi.amdsmi_shut_down()
except Exception:
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment