setup.py 1.92 KB
Newer Older
wangkx1's avatar
init  
wangkx1 committed
1
2
import os
import torch
wangkaixiong's avatar
wangkaixiong committed
3
from setuptools import setup, find_packages
wangkx1's avatar
init  
wangkx1 committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

library_name = "test_ops"
current_dir = os.path.dirname(os.path.abspath(__file__))

# 源文件列表
sources = [
    os.path.join(current_dir, "test_torch_library_expand.cpp"),
    os.path.join(current_dir, "test_ops_impl.cpp"),
]

# 检查CUDA是否可用
use_cuda = torch.cuda.is_available()
extension = CUDAExtension if use_cuda else CppExtension

if use_cuda:
    import glob
    cuda_files = glob.glob(os.path.join(current_dir, "*.cu"))
    sources.extend(cuda_files)

extra_compile_args = {
    'cxx': ['-O2', '-std=c++17'],
}

if use_cuda:
    extra_compile_args['nvcc'] = ['-O2']

# Create the package directory and, if it is missing, an __init__.py for it.
# An existing __init__.py is never overwritten, so local edits survive
# re-running this script.
package_dir = os.path.join(current_dir, library_name)
os.makedirs(package_dir, exist_ok=True)

# Contents of the generated __init__.py: re-export the compiled extension.
# NOTE(review): `from test_ops import *` raises AttributeError if `_C` does not
# define every name listed in __all__ — confirm the extension exposes these
# (ops registered only through TORCH_LIBRARY may be reachable solely via
# torch.ops instead).
_INIT_PY_SOURCE = """
from ._C import *

__all__ = ['add_one', 'multiply_by_two']
"""

init_py_path = os.path.join(package_dir, "__init__.py")
if not os.path.exists(init_py_path):
    with open(init_py_path, "w") as init_file:
        init_file.write(_INIT_PY_SOURCE)

import tempfile  # locates the platform temp dir used for egg_info output below

setup(
    name=library_name,
    version='0.1.1',
    description='Test operations for PyTorch',
    author='Your Name',

    # The pure-Python package that wraps the compiled extension.
    packages=[library_name],
    package_dir={library_name: library_name},

    # Extension modules. The name "<package>._C" makes the compiled library
    # importable as a submodule of the package above.
    ext_modules=[
        extension(
            name=f"{library_name}._C",
            sources=sources,
            extra_compile_args=extra_compile_args,
            include_dirs=[current_dir],
        )
    ],

    # Use torch's BuildExtension, which understands the per-compiler
    # ('cxx' / 'nvcc') extra_compile_args dict used above.
    cmdclass={
        'build_ext': BuildExtension
    },

    # Runtime dependencies.
    install_requires=['torch>=1.10.0'],

    # zip_safe only matters for legacy egg installs: False keeps the package,
    # which contains the compiled _C library, from being installed zipped.
    # (It has no effect on .dist-info generation.)
    zip_safe=False,
    # Write the .egg-info metadata into the system temp directory instead of
    # the source tree. tempfile.gettempdir() is /tmp on a default Linux setup,
    # but unlike a hard-coded '/tmp' it honors TMPDIR/TEMP/TMP and is valid on
    # every platform.
    options={
        'egg_info': {
            'egg_base': tempfile.gettempdir(),
        }
    },
)