Commit e3378b20 authored by zhuwenwen
Browse files

add gfx

parent 318e2b5a
......@@ -67,20 +67,20 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
# def get_amdgpu_offload_arch():
# command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
# try:
# output = subprocess.check_output([command])
# return output.decode('utf-8').strip()
# except subprocess.CalledProcessError as e:
# error_message = f"Error: {e}"
# raise RuntimeError(error_message) from e
# except FileNotFoundError as e:
# # If the command is not found, print an error message
# error_message = f"The command {command} was not found."
# raise RuntimeError(error_message) from e
# return None
def get_amdgpu_offload_arch():
    """Return the GPU architecture string reported by ``amdgpu-offload-arch``.

    Runs the ``amdgpu-offload-arch`` binary at its hard-coded DTK install
    path and returns its standard output, decoded as UTF-8 with leading and
    trailing whitespace stripped.

    Returns:
        str: The stripped output of the command.

    Raises:
        RuntimeError: If the command exits with a non-zero status, or if the
            binary does not exist at the expected path. The underlying
            exception is chained as the cause.
    """
    command = "/opt/dtk-23.10/llvm/bin/amdgpu-offload-arch"
    try:
        output = subprocess.check_output([command])
    except subprocess.CalledProcessError as e:
        # The tool ran but reported failure (non-zero exit status).
        error_message = f"Error: {e}"
        raise RuntimeError(error_message) from e
    except FileNotFoundError as e:
        # The tool is not installed at the hard-coded path; report it as a
        # RuntimeError so callers only need to handle one exception type.
        error_message = f"The command {command} was not found."
        raise RuntimeError(error_message) from e
    # Every failure path above raises, so reaching here means success; the
    # former trailing ``return None`` could never execute and was removed.
    return output.decode('utf-8').strip()
def get_hipcc_rocm_version():
......@@ -290,16 +290,17 @@ if _is_cuda():
"nvcc": NVCC_FLAGS_PUNICA,
},
))
# elif _is_hip():
# amd_archs = os.getenv("GPU_ARCHS")
# if amd_archs is None:
# amd_archs = get_amdgpu_offload_arch()
# for arch in amd_archs.split(";"):
# if arch not in ROCM_SUPPORTED_ARCHS:
# raise RuntimeError(
# f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
# f"amdgpu_arch_found: {arch}")
# NVCC_FLAGS += [f"--offload-arch={arch}"]
elif _is_hip():
amd_archs = os.getenv("GPU_ARCHS")
if amd_archs is None:
# amd_archs = get_amdgpu_offload_arch()
amd_archs = "gfx906;gfx926"
for arch in amd_archs.split(";"):
if arch not in ROCM_SUPPORTED_ARCHS:
raise RuntimeError(
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
f"amdgpu_arch_found: {arch}")
NVCC_FLAGS += [f"--offload-arch={arch}"]
elif _is_neuron():
neuronxcc_version = get_neuronxcc_version()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment