Commit 32b1ccaf authored by maxiao1
Browse files

修改sgl-kernel下的setup_hip.py

parent 251235c2
...@@ -52,30 +52,10 @@ sources = [ ...@@ -52,30 +52,10 @@ sources = [
"csrc/kvcacheio/transfer.cu", "csrc/kvcacheio/transfer.cu",
] ]
cxx_flags = ["-O3", "-w"] cxx_flags = ["-O3"]
libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"] libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"] extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]
default_target = "gfx942"
amdgpu_target = os.environ.get("AMDGPU_TARGET", default_target)
if torch.cuda.is_available():
try:
amdgpu_target = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
except Exception as e:
print(f"Warning: Failed to detect GPU properties: {e}")
else:
print(f"Warning: torch.cuda not available. Using default target: {amdgpu_target}")
if amdgpu_target not in ["gfx942", "gfx950", "gfx936"]:
print(
f"Warning: Unsupported GPU architecture detected '{amdgpu_target}'. Expected 'gfx942' or 'gfx950'."
)
sys.exit(1)
fp8_macro = (
"-DHIP_FP8_TYPE_FNUZ" if amdgpu_target == "gfx942" else "-DHIP_FP8_TYPE_E4M3"
)
hipcc_flags = [ hipcc_flags = [
"-DNDEBUG", "-DNDEBUG",
...@@ -84,10 +64,8 @@ hipcc_flags = [ ...@@ -84,10 +64,8 @@ hipcc_flags = [
"-Xcompiler", "-Xcompiler",
"-fPIC", "-fPIC",
"-std=c++17", "-std=c++17",
f"--amdgpu-target={amdgpu_target}",
"-DENABLE_BF16", "-DENABLE_BF16",
"-DENABLE_FP8", "-DENABLE_FP8",
fp8_macro,
] ]
ext_modules = [ ext_modules = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment