"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "c2ed66f0465c0889b06d14b87cf5858958f8f131"
Commit bc84829b authored by zhuwenwen's avatar zhuwenwen
Browse files

update platforms init

parent ad58e9b3
...@@ -203,8 +203,7 @@ def which_attn_to_use( ...@@ -203,8 +203,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH: if selected_backend == _Backend.ROCM_FLASH:
# if current_platform.get_device_capability()[0] != 9: if current_platform.get_device_capability()[0] != 9:
if torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs. # not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.") logger.info("flash_attn is not supported on NAVI GPUs.")
else: else:
......
import torch
from .interface import Platform, PlatformEnum, UnspecifiedPlatform from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform current_platform: Platform
...@@ -32,13 +33,15 @@ except Exception: ...@@ -32,13 +33,15 @@ except Exception:
is_rocm = False is_rocm = False
try: try:
import amdsmi if torch.version.hip is not None:
amdsmi.amdsmi_init() is_rocm = True
try: # import amdsmi
if len(amdsmi.amdsmi_get_processor_handles()) > 0: # amdsmi.amdsmi_init()
is_rocm = True # try:
finally: # if len(amdsmi.amdsmi_get_processor_handles()) > 0:
amdsmi.amdsmi_shut_down() # is_rocm = True
# finally:
# amdsmi.amdsmi_shut_down()
except Exception: except Exception:
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment