Commit b10621d1 authored by flyingdown's avatar flyingdown
Browse files

修改setup.py,修复编译错误,适配dtk-22.10

parent 86dfa18d
......@@ -109,7 +109,11 @@ struct L2NormFunctor
}
};
__global__ void cleanup(
__global__ void
#ifdef __HIP_PLATFORM_HCC__
__launch_bounds__(1024)
#endif
cleanup(
float* output,
float* output_per_tensor,
float* ret,
......
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
from setuptools import setup, find_packages
import subprocess
......@@ -275,6 +275,7 @@ if "--cuda_ext" in sys.argv:
CUDAExtension(name='fused_dense_cuda',
sources=['csrc/fused_dense.cpp',
'csrc/fused_dense_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
nvcc_args_transformer = ['-O3',
......@@ -522,8 +523,8 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
hipcc_args_mha = ['-O3',
'-Iapex/contrib/csrc/multihead_attn/cutlass',
'-I/opt/rocm/include/hiprand',
'-I/opt/rocm/include/rocrand',
'-I' + os.path.join(ROCM_HOME, 'include/hiprand'),
'-I' + os.path.join(ROCM_HOME, 'include/rocrand'),
'-U__HIP_NO_HALF_OPERATORS__',
'-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
if found_Backward_Pass_Guard:
......@@ -559,6 +560,9 @@ if "--transducer" in sys.argv or "--cuda_ext" in sys.argv:
if not IS_ROCM_PYTORCH:
raise_if_cuda_home_none("--transducer")
hipcc_args_mha = ['-O3',
'-I' + os.path.join(ROCM_HOME, 'include/hiprand'),
'-I' + os.path.join(ROCM_HOME, 'include/rocrand'),] + version_dependent_macros + generator_flag
ext_modules.append(
CUDAExtension(
name="transducer_joint_cuda",
......@@ -569,7 +573,7 @@ if "--transducer" in sys.argv or "--cuda_ext" in sys.argv:
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros + generator_flag,
"nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag) if not IS_ROCM_PYTORCH
else ["-O3"] + version_dependent_macros + generator_flag,
else hipcc_args_mha,
},
include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")],
)
......@@ -619,6 +623,7 @@ if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
"apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
"apex/contrib/csrc/peer_memory/peer_memory.cpp",
],
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/nccl_p2p")],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
......@@ -637,6 +642,7 @@ if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
"apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
"apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
],
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/nccl_p2p")],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment