"""Build script for the ``test_ops`` PyTorch extension package.

Running this script:

1. collects the C++ sources (plus every ``*.cu`` file next to this script
   when ``torch.cuda.is_available()`` is true),
2. makes sure the ``test_ops/`` package directory and its ``__init__.py``
   exist, and
3. calls ``setuptools.setup`` to build the compiled module ``test_ops._C``.
"""
import glob
import os
import tempfile

import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

library_name = "test_ops"
current_dir = os.path.dirname(os.path.abspath(__file__))

# C++ sources that are always compiled.
sources = [
    os.path.join(current_dir, "test_torch_library_expand.cpp"),
    os.path.join(current_dir, "test_ops_impl.cpp"),
]

# Pick the CUDA extension type only when torch reports CUDA as available.
use_cuda = torch.cuda.is_available()
extension = CUDAExtension if use_cuda else CppExtension

if use_cuda:
    # sorted() keeps the source order independent of directory listing order.
    sources.extend(sorted(glob.glob(os.path.join(current_dir, "*.cu"))))

extra_compile_args = {
    "cxx": ["-O2", "-std=c++17"],
}
if use_cuda:
    extra_compile_args["nvcc"] = ["-O2"]

# Content of the generated package __init__.py.  ``torch`` is imported first
# because the compiled module links against torch's shared libraries; they
# must already be loaded when ``_C`` is imported.
_INIT_PY_SOURCE = """
import torch  # noqa: F401  (must be loaded before the compiled extension)

from ._C import *

__all__ = ['add_one', 'multiply_by_two']
"""

# Create the package directory and its __init__.py.  An existing __init__.py
# is left untouched so manual edits are never overwritten.
package_dir = os.path.join(current_dir, library_name)
os.makedirs(package_dir, exist_ok=True)

init_py_path = os.path.join(package_dir, "__init__.py")
if not os.path.exists(init_py_path):
    with open(init_py_path, "w", encoding="utf-8") as f:
        f.write(_INIT_PY_SOURCE)

setup(
    name=library_name,
    version="0.1.1",
    description="Test operations for PyTorch",
    author="Your Name",
    # Key point: declare the package explicitly.
    packages=[library_name],
    package_dir={library_name: library_name},
    # Extension module - the name must be "<package>._C" so that the
    # ``from ._C import *`` in the package __init__.py resolves.
    ext_modules=[
        extension(
            name=f"{library_name}._C",
            sources=sources,
            extra_compile_args=extra_compile_args,
            include_dirs=[current_dir],
        )
    ],
    # Command classes.
    cmdclass={
        "build_ext": BuildExtension,
    },
    # Dependencies.
    install_requires=["torch>=1.10.0"],
    # Install unpacked rather than as a zipped egg.
    zip_safe=False,
    # Keep the generated .egg-info out of the source directory by writing it
    # to the platform's temporary directory (portable replacement for a
    # hard-coded "/tmp").
    options={
        "egg_info": {
            "egg_base": tempfile.gettempdir(),
        }
    },
)