setup.py 4.17 KB
Newer Older
wangkx1's avatar
init  
wangkx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# ------------------------------------------------------------------------------------------------
# Deformable Convolution v4
# Copyright (c) 2024 OpenGVLab
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

import os
import glob
import torch

# 导入打包相关库
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

# 定义获取扩展的函数(保持原样,供非打包模式使用)
def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

    sources = main_file + source_cpu
    extension = CppExtension
    extra_compile_args = {"cxx": []}
    define_macros = []

    if torch.cuda.is_available() and CUDA_HOME is not None:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
            "-O3",
        ]
    else:
        raise NotImplementedError('Cuda is not available')

    sources = [os.path.join(extensions_dir, s) for s in sources]
    include_dirs = [extensions_dir]
    ext_modules = [
        extension(
            "DCNv4.ext",  # 注意:这里保持原模块名,方便后面替换
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    return ext_modules

# --- 核心修改逻辑 ---
# 检查是否是构建 Wheel 的模式
# 如果是构建 Wheel,我们不编译,而是将现有的 .so 作为包数据处理
# 注意:setuptools 打包扩展模块和打包数据文件的逻辑是冲突的,所以我们需要在构建 Wheel 时禁用 ext_modules

if __name__ == "__main__":
    # 检查环境变量,决定是否跳过编译
    # 你也可以直接写一个布尔值,或者检查某个文件是否存在
    build_so = int(os.getenv("DCNv4_BUILD_SO", "0"))
    
    # 准备参数
    kwargs = {
        "name": "DCNv4",
        "version": "1.0.0.post2",
        "author": "Yuwen Xiong, Feng Wang",
        "url": "",
        "description": "PyTorch Wrapper for CUDA Functions of DCNv4",
        "packages": ['DCNv4', 'DCNv4/functions', 'DCNv4/modules'],
        "package_data": {
            "DCNv4": ["ext.so"], # 假设 ext.so 生成在 DCNv4 目录下
            # "DCNv4": ["ext.cpython-310-x86_64-linux-gnu.so"], # 假设 ext.so 生成在 DCNv4 目录下
        },
        "cmdclass": {"build_ext": torch.utils.cpp_extension.BuildExtension},
        # 确保生成正确的 .dist-info
        "zip_safe": False,
        # 添加以下参数来避免生成 .egg-info 在当前目录
        "options": {
            'egg_info': {
                'egg_base': '/tmp'  # 将 egg-info 生成到临时目录
            }
        },
    }

    if build_so:
        # 正常开发模式,进行编译
        kwargs["ext_modules"] = get_extensions()
    else:
        print("=== BUILD WHEEL MODE: Skipping compilation, using existing ext.so ===")
        # 在构建 Wheel 时,不要传入 ext_modules
        # 我们依赖 MANIFEST.in 或 package_data 将 .so 文件包含进去
        # 但是 setuptools 的 bdist_wheel 默认会忽略 .so,所以我们需要确保 .so 在包目录里
        # 这里我们不传入 ext_modules,而是依靠外部脚本或 MANIFEST.in
        # 更简单的方法:直接在 setup 里不写 ext_modules,确保 .so 已经在 DCNv4/ 目录下
        kwargs["ext_modules"] = [] # 强制不编译

    setup(**kwargs)