# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py import sys import warnings import os from packaging.version import parse, Version from typing import Optional import torch from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from setuptools import setup, find_packages import subprocess from get_version import get_version def _find_rocm_home() -> Optional[str]: rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') if rocm_home is None: try: pipe_hipcc = subprocess.Popen( ["which hipcc | xargs readlink -f"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) hipcc, _ = pipe_hipcc.communicate() rocm_home = os.path.dirname(os.path.dirname(hipcc.decode(*()).rstrip('\r\n'))) if os.path.basename(rocm_home) == 'hip': rocm_home = os.path.dirname(rocm_home) except Exception: rocm_home = '/opt/rocm' if not os.path.exists(rocm_home): rocm_home = None if rocm_home and torch.version.hip is None: print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'") return rocm_home ROCM_HOME = _find_rocm_home() cc_flag = ["--offload-arch=gfx906","--offload-arch=gfx926","--offload-arch=gfx928","--offload-arch=gfx936"] ext_modules=[] ext_modules.append( CUDAExtension( 'rotary_emb', [ 'rotary.cpp', 'rotary_kernel.cu', ], extra_compile_args={'cxx': ['-O3'],'nvcc': ['-O3'] + cc_flag} ) ) def _get_pytorch_version(): if "PYTORCH_VERSION" in os.environ: return f"{os.environ['PYTORCH_VERSION']}" return torch.__version__ setup( name="rotary_emb", version=get_version(ROCM_HOME), ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension} if ext_modules else {}, install_requires=[ f"torch=={_get_pytorch_version()}", ], )