"""Build script for the ``test_ops`` PyTorch C++ (optionally GPU) extension.

The two C++ sources listed in ``sources`` are always compiled. When a GPU is
visible *and* the matching GPU toolkit is installed, every ``*.cu`` file in
this directory is added and the extension is built with ``CUDAExtension``;
otherwise a CPU-only ``CppExtension`` is built.
"""

import glob
import os
import tempfile

import torch
from setuptools import setup
from torch.utils.cpp_extension import (
    CUDA_HOME,
    ROCM_HOME,
    BuildExtension,
    CppExtension,
    CUDAExtension,
)

library_name = "test_ops"

# Directory of this script; all paths are resolved against it so the build
# does not depend on the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Source files that are always compiled.
sources = [
    os.path.join(current_dir, "test_torch_library_expand.cpp"),
    os.path.join(current_dir, "test_ops_impl.cpp"),
]

# Decide whether to build the GPU variant. torch.cuda.is_available() only says
# a device is visible; compiling .cu files additionally needs the toolkit that
# matches this torch build: CUDA (nvcc, located via CUDA_HOME) or, for ROCm
# builds of PyTorch (torch.version.hip is set), ROCm (located via ROCM_HOME).
# Without it CUDAExtension cannot compile, so fall back to the CPU-only
# extension instead of failing the whole build.
if torch.version.hip is not None:
    gpu_toolkit_home = ROCM_HOME
else:
    gpu_toolkit_home = CUDA_HOME
use_cuda = torch.cuda.is_available() and gpu_toolkit_home is not None
if torch.cuda.is_available() and not use_cuda:
    print(
        "GPU detected but no CUDA/ROCm toolkit found; "
        "building the CPU-only extension."
    )
extension = CUDAExtension if use_cuda else CppExtension

if use_cuda:
    # Pick up any CUDA sources next to this script. glob.escape guards against
    # glob metacharacters (e.g. "[") in the directory path, and sorting keeps
    # the compile/link order deterministic across filesystems.
    cuda_files = sorted(glob.glob(os.path.join(glob.escape(current_dir), "*.cu")))
    sources.extend(cuda_files)
    print(f"CUDA files found: {cuda_files}")

# Compiler flags keyed by compiler: "cxx" for the host C++ compiler, "nvcc"
# for device code (only added when building the GPU variant).
extra_compile_args = {
    "cxx": ["-O2", "-std=c++17"],
}
if use_cuda:
    extra_compile_args["nvcc"] = ["-O2"]

setup(
    name=library_name,
    version="0.1.0",
    ext_modules=[
        extension(
            name=library_name,
            sources=sources,
            extra_compile_args=extra_compile_args,
            include_dirs=[current_dir],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch>=1.10.0"],
    options={
        "egg_info": {
            # Write the .egg-info metadata to the system temp directory rather
            # than the source tree. tempfile.gettempdir() is "/tmp" on typical
            # Linux setups but also works on platforms without a /tmp.
            "egg_base": tempfile.gettempdir(),
        }
    },
)