"vscode:/vscode.git/clone" did not exist on "3bb4e4311c6da31257e6c8e5b1027ef516e025c8"
Commit bc84829b authored by zhuwenwen's avatar zhuwenwen
Browse files

update platforms init

parent ad58e9b3
...@@ -203,8 +203,7 @@ def which_attn_to_use( ...@@ -203,8 +203,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH: if selected_backend == _Backend.ROCM_FLASH:
# if current_platform.get_device_capability()[0] != 9: if current_platform.get_device_capability()[0] != 9:
if torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs. # not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.") logger.info("flash_attn is not supported on NAVI GPUs.")
else: else:
......
import torch
from .interface import Platform, PlatformEnum, UnspecifiedPlatform from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform current_platform: Platform
...@@ -32,13 +33,15 @@ except Exception: ...@@ -32,13 +33,15 @@ except Exception:
is_rocm = False is_rocm = False
try: try:
import amdsmi if torch.version.hip is not None:
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True is_rocm = True
finally: # import amdsmi
amdsmi.amdsmi_shut_down() # amdsmi.amdsmi_init()
# try:
# if len(amdsmi.amdsmi_get_processor_handles()) > 0:
# is_rocm = True
# finally:
# amdsmi.amdsmi_shut_down()
except Exception: except Exception:
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment