Unverified commit 4b9971e4, authored by sogalin, committed by GitHub
Browse files

Add gfx950 support for sgl-kernel. (#7092)


Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent dcc79d32
...@@ -17,6 +17,7 @@ import platform ...@@ -17,6 +17,7 @@ import platform
import sys import sys
from pathlib import Path from pathlib import Path
import torch
from setuptools import find_packages, setup from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
...@@ -49,6 +50,13 @@ cxx_flags = ["-O3"] ...@@ -49,6 +50,13 @@ cxx_flags = ["-O3"]
libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"] libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"] extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]
# Detect the GPU architecture of the build machine so the kernels are compiled
# for it. Anything after the first ":" in gcnArchName is dropped; only the base
# architecture name is kept for the --amdgpu-target compiler flag.
try:
    amdgpu_target = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0]
except (RuntimeError, AssertionError, AttributeError) as err:
    # No usable GPU (or a torch build without gcnArchName): fail with a clear
    # build error instead of a raw traceback. Exit status stays non-zero.
    sys.exit(f"Error: Failed to detect the GPU architecture for this build: {err}")

if amdgpu_target not in ("gfx942", "gfx950"):
    # The build is aborted here, so this is an error, not a warning.
    # sys.exit(<str>) writes the message to stderr and exits with status 1.
    sys.exit(
        f"Error: Unsupported GPU architecture detected '{amdgpu_target}'. "
        "Expected 'gfx942' or 'gfx950'."
    )
hipcc_flags = [ hipcc_flags = [
"-DNDEBUG", "-DNDEBUG",
f"-DOPERATOR_NAMESPACE={operator_namespace}", f"-DOPERATOR_NAMESPACE={operator_namespace}",
...@@ -57,7 +65,7 @@ hipcc_flags = [ ...@@ -57,7 +65,7 @@ hipcc_flags = [
"-fPIC", "-fPIC",
"-std=c++17", "-std=c++17",
"-D__HIP_PLATFORM_AMD__=1", "-D__HIP_PLATFORM_AMD__=1",
"--amdgpu-target=gfx942", f"--amdgpu-target={amdgpu_target}",
"-DENABLE_BF16", "-DENABLE_BF16",
"-DENABLE_FP8", "-DENABLE_FP8",
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment