Commit b10621d1 authored by flyingdown
Browse files

修改setup.py,修复编译错误,适配dtk-22.10

parent 86dfa18d
...@@ -109,7 +109,11 @@ struct L2NormFunctor ...@@ -109,7 +109,11 @@ struct L2NormFunctor
} }
}; };
__global__ void cleanup( __global__ void
#ifdef __HIP_PLATFORM_HCC__
__launch_bounds__(1024)
#endif
cleanup(
float* output, float* output,
float* output_per_tensor, float* output_per_tensor,
float* ret, float* ret,
......
import torch import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
from setuptools import setup, find_packages from setuptools import setup, find_packages
import subprocess import subprocess
...@@ -275,6 +275,7 @@ if "--cuda_ext" in sys.argv: ...@@ -275,6 +275,7 @@ if "--cuda_ext" in sys.argv:
CUDAExtension(name='fused_dense_cuda', CUDAExtension(name='fused_dense_cuda',
sources=['csrc/fused_dense.cpp', sources=['csrc/fused_dense.cpp',
'csrc/fused_dense_cuda.cu'], 'csrc/fused_dense_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
nvcc_args_transformer = ['-O3', nvcc_args_transformer = ['-O3',
...@@ -522,8 +523,8 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -522,8 +523,8 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
hipcc_args_mha = ['-O3', hipcc_args_mha = ['-O3',
'-Iapex/contrib/csrc/multihead_attn/cutlass', '-Iapex/contrib/csrc/multihead_attn/cutlass',
'-I/opt/rocm/include/hiprand', '-I' + os.path.join(ROCM_HOME, 'include/hiprand'),
'-I/opt/rocm/include/rocrand', '-I' + os.path.join(ROCM_HOME, 'include/rocrand'),
'-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_OPERATORS__',
'-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
if found_Backward_Pass_Guard: if found_Backward_Pass_Guard:
...@@ -559,6 +560,9 @@ if "--transducer" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -559,6 +560,9 @@ if "--transducer" in sys.argv or "--cuda_ext" in sys.argv:
if not IS_ROCM_PYTORCH: if not IS_ROCM_PYTORCH:
raise_if_cuda_home_none("--transducer") raise_if_cuda_home_none("--transducer")
hipcc_args_mha = ['-O3',
'-I' + os.path.join(ROCM_HOME, 'include/hiprand'),
'-I' + os.path.join(ROCM_HOME, 'include/rocrand'),] + version_dependent_macros + generator_flag
ext_modules.append( ext_modules.append(
CUDAExtension( CUDAExtension(
name="transducer_joint_cuda", name="transducer_joint_cuda",
...@@ -569,7 +573,7 @@ if "--transducer" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -569,7 +573,7 @@ if "--transducer" in sys.argv or "--cuda_ext" in sys.argv:
extra_compile_args={ extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros + generator_flag, "cxx": ["-O3"] + version_dependent_macros + generator_flag,
"nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag) if not IS_ROCM_PYTORCH "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag) if not IS_ROCM_PYTORCH
else ["-O3"] + version_dependent_macros + generator_flag, else hipcc_args_mha,
}, },
include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")], include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")],
) )
...@@ -619,6 +623,7 @@ if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -619,6 +623,7 @@ if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
"apex/contrib/csrc/peer_memory/peer_memory_cuda.cu", "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
"apex/contrib/csrc/peer_memory/peer_memory.cpp", "apex/contrib/csrc/peer_memory/peer_memory.cpp",
], ],
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/nccl_p2p")],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
) )
) )
...@@ -637,6 +642,7 @@ if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -637,6 +642,7 @@ if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
"apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu", "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
"apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp", "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
], ],
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/nccl_p2p")],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
) )
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment