Unverified Commit a77564e0 authored by EduardDurech's avatar EduardDurech Committed by GitHub
Browse files

CUDA Arch Independent (#8813)

parent 4f9e71df
import ctypes import ctypes
import os import os
import platform import platform
import shutil
from pathlib import Path
import torch import torch
SYSTEM_ARCH = platform.machine()
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12" # copy & modify from torch/utils/cpp_extension.py
if os.path.exists(cuda_path): def _find_cuda_home():
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL) """Find the CUDA install path."""
# Guess #1
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home is None:
# Guess #2
nvcc_path = shutil.which("nvcc")
if nvcc_path is not None:
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
else:
# Guess #3
cuda_home = "/usr/local/cuda"
return cuda_home
if torch.version.hip is None:
cuda_home = Path(_find_cuda_home())
if (cuda_home / "lib").is_dir():
cuda_path = cuda_home / "lib"
elif (cuda_home / "lib64").is_dir():
cuda_path = cuda_home / "lib64"
else:
# Search for 'libcudart.so.12' in subdirectories
for path in cuda_home.rglob("libcudart.so.12"):
cuda_path = path.parent
break
else:
raise RuntimeError("Could not find CUDA lib directory.")
cuda_include = (cuda_path / "libcudart.so.12").resolve()
if cuda_include.exists():
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
from sgl_kernel import common_ops from sgl_kernel import common_ops
from sgl_kernel.allreduce import * from sgl_kernel.allreduce import *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment