"mmdet/models/dense_heads/atss_head.py" did not exist on "fdfe3c4f8ba935ae428a8a496ce57755d5b2ea98"
detail.md 10.4 KB
Newer Older
liuys's avatar
liuys committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
# PyTorch 自定义算子开发与接入指南

本文档基于 `torch_custom_op` 工程,详细介绍如何使用 `TORCH_LIBRARY``TORCH_LIBRARY_IMPL``PyInit_【NAME】` 等技术开发和接入 PyTorch 自定义算子。

## 1. 概述

PyTorch 提供了强大的自定义算子开发能力,通过 `TORCH_LIBRARY``TORCH_LIBRARY_IMPL` 等宏,可以:
- 定义新的算子接口
- 为不同设备(CPU/CUDA)提供专门的实现
- 无缝集成到 PyTorch 生态系统中

## 2. 环境准备

### 2.1 依赖项
- PyTorch 1.7+
- C++ 编译器(支持 C++14)
- CMake (3.18+)
- Python 3.6+

### 2.2 项目结构
推荐的项目结构如下:

```
my_custom_ops/
├── setup.py            # Python 包配置
├── build.sh            # 编译脚本
├── my_ops.h            # 算子声明
├── my_ops_impl.cpp     # 算子实现
├── my_ops.cpp          # 模块初始化和算子定义
└── test_ops.py         # 测试脚本
```

## 3. 开发步骤

### 3.1 定义算子接口

使用 `TORCH_LIBRARY` 宏定义算子接口,在 `.cpp` 文件中:

```cpp
#include <torch/library.h>
#include <ATen/ATen.h>
#include <Python.h>
#include "my_ops.h"

// 定义算子库
TORCH_LIBRARY(my_ops, m) {
    m.def("add_one(Tensor input) -> Tensor");
    m.def("multiply_by_two(Tensor input) -> Tensor");
}

// Python 模块初始化函数
extern "C" {
PyObject *PyInit_my_ops(void) {
    static struct PyModuleDef module_def = {
        PyModuleDef_HEAD_INIT,
        "my_ops",  // 模块名
        "My custom operations module",  // 文档
        -1,
        NULL  // 方法定义
    };
    return PyModule_Create(&module_def);
}
}
```

### 3.2 实现算子逻辑

`.h` 文件中声明算子函数:

```cpp
#pragma once
#include <torch/library.h>
#include <ATen/ATen.h>

namespace my_ops_impl {
    at::Tensor add_one(at::Tensor input);
    at::Tensor multiply_by_two(at::Tensor input);
}
```

`.cpp` 文件中实现算子逻辑并使用 `TORCH_LIBRARY_IMPL` 注册:

```cpp
#include <torch/library.h>
#include <ATen/ATen.h>
#include "my_ops.h"

namespace my_ops_impl {
    // 操作符的具体实现
    at::Tensor add_one(at::Tensor input) {
        return input + 1;
    }
    
    at::Tensor multiply_by_two(at::Tensor input) {
        return input * 2;
    }
}

// 注册 CPU 实现
TORCH_LIBRARY_IMPL(my_ops, CPU, m) {
    m.impl("add_one", &my_ops_impl::add_one);
    m.impl("multiply_by_two", &my_ops_impl::multiply_by_two);
}

// 注册 CUDA 实现(如果有 CUDA)
#ifdef TORCH_HAS_CUDA
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
    m.impl("add_one", &my_ops_impl::add_one);
    m.impl("multiply_by_two", &my_ops_impl::multiply_by_two);
}
#endif
```

### 3.3 编译配置

#### 3.3.1 使用 CMake 配置

创建 `CMakeLists.txt` 文件:

```cmake
cmake_minimum_required(VERSION 3.18)
project(my_ops)

find_package(Torch REQUIRED)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_library(my_ops SHARED my_ops.cpp my_ops_impl.cpp)
target_link_libraries(my_ops "${TORCH_LIBRARIES}")
set_target_properties(my_ops PROPERTIES PREFIX "")
```

#### 3.3.2 使用 setup.py 配置

创建 `setup.py` 文件,支持自动检测 CUDA 并使用相应的扩展:

```python
import os
import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

library_name = "my_ops"
current_dir = os.path.dirname(os.path.abspath(__file__))

# 源文件列表
sources = [
    os.path.join(current_dir, "my_ops.cpp"),
    os.path.join(current_dir, "my_ops_impl.cpp"),
]

# 检查CUDA是否可用
use_cuda = torch.cuda.is_available()
extension = CUDAExtension if use_cuda else CppExtension

if use_cuda:
    import glob
    cuda_files = glob.glob(os.path.join(current_dir, "*.cu"))
    sources.extend(cuda_files)

extra_compile_args = {
    'cxx': ['-O2', '-std=c++17'],
}

if use_cuda:
    extra_compile_args['nvcc'] = ['-O2']

# 创建包目录和 __init__.py
package_dir = os.path.join(current_dir, library_name)
os.makedirs(package_dir, exist_ok=True)

init_py_path = os.path.join(package_dir, "__init__.py")
if not os.path.exists(init_py_path):
    with open(init_py_path, "w") as f:
        f.write("""
from ._C import *

__all__ = ['add_one', 'multiply_by_two']
""")

setup(
    name=library_name,
    version='0.1.0',
    description='Custom operations for PyTorch',
    author='Your Name',
    
    # 关键:指定包
    packages=[library_name],
    package_dir={library_name: library_name},
    
    # 扩展模块 - 注意命名格式
    ext_modules=[
        extension(
            name=f"{library_name}._C",
            sources=sources,
            extra_compile_args=extra_compile_args,
            include_dirs=[current_dir],
        )
    ],
    
    # 命令类
    cmdclass={
        'build_ext': BuildExtension
    },
    
    # 依赖
    install_requires=['torch>=1.10.0'],
    
    # 确保生成正确的 .dist-info
    zip_safe=False,
)
```

#### 3.3.3 编译脚本

创建 `build.sh` 脚本:

```bash
#!/bin/bash
rm -rf build
mkdir -p build
cd build
cmake ..
make -j
cp my_ops.so ..
cd ..
```

## 4. 编译和安装

### 4.1 使用 build.sh 编译

```bash
bash build.sh
```

### 4.2 使用 pip 安装

```bash
pip install --no-build-isolation .
```

## 5. 在 Python 中使用自定义算子

创建测试脚本 `test_ops.py`

```python
import torch
import os
import sys

# 添加当前目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# 加载自定义操作符
try:
    import my_ops
    print("成功加载 my_ops 扩展")
except ImportError as e:
    print(f"加载扩展失败: {e}")
    print("请先编译 C++ 扩展")
    sys.exit(1)

def test_add_one():
    """测试 add_one 操作符"""
    print("\n测试 add_one 操作符:")
    
    # 创建测试张量
    x = torch.tensor([1.0, 2.0, 3.0])
    print(f"输入: {x}")
    
    # 调用自定义操作符
    y = torch.ops.my_ops.add_one(x)
    print(f"输出: {y}")
    
    # 验证结果
    expected = x + 1
    assert torch.allclose(y, expected), f"结果不匹配: {y} vs {expected}"
    print("✓ 测试通过")

def test_multiply_by_two():
    """测试 multiply_by_two 操作符"""
    print("\n测试 multiply_by_two 操作符:")
    
    # 创建测试张量
    x = torch.tensor([1.0, 2.0, 3.0])
    print(f"输入: {x}")
    
    # 调用自定义操作符
    y = torch.ops.my_ops.multiply_by_two(x)
    print(f"输出: {y}")
    
    # 验证结果
    expected = x * 2
    assert torch.allclose(y, expected), f"结果不匹配: {y} vs {expected}"
    print("✓ 测试通过")

def test_cuda_support():
    """测试 CUDA 支持(如果可用)"""
    if torch.cuda.is_available():
        print("\n测试 CUDA 支持:")
        
        # 创建 CUDA 张量
        x = torch.tensor([1.0, 2.0, 3.0]).cuda()
        print(f"CUDA 输入: {x}")
        
        # 测试 add_one
        y = torch.ops.my_ops.add_one(x)
        print(f"add_one 输出: {y}")
        assert torch.allclose(y, x + 1), f"CUDA add_one 结果不匹配"
        
        # 测试 multiply_by_two
        z = torch.ops.my_ops.multiply_by_two(x)
        print(f"multiply_by_two 输出: {z}")
        assert torch.allclose(z, x * 2), f"CUDA multiply_by_two 结果不匹配"
        
        print("✓ CUDA 测试通过")
    else:
        print("\nCUDA 不可用,跳过 CUDA 测试")

if __name__ == "__main__":
    print("测试自定义算子")
    print("=" * 50)
    
    test_add_one()
    test_multiply_by_two()
    test_cuda_support()
    
    print("\n" + "=" * 50)
    print("所有测试通过!")
```

运行测试:

```bash
python test_ops.py
```

## 6. 常见问题和解决方案

### 6.1 模块初始化函数错误

**错误信息**
```
加载扩展失败: dynamic module does not define module export function (PyInit_my_ops)
```

**解决方案**
确保在 C++ 文件中正确定义了 `PyInit_my_ops` 函数,并且函数名与模块名一致。

### 6.2 CUDA 支持问题

**错误信息**
```
undefined reference to `torch::library::Library::impl(...)`
```

**解决方案**
确保正确定义了 `TORCH_HAS_CUDA` 宏,并且在 CUDA 环境下编译。

### 6.3 算子注册问题

**错误信息**
```
undefined symbol: _ZN5torch8library7Library4implERKNS_6SymbolEPFSt10unique_ptrINS_5autograd15AutogradContextESt14default_deleteIS3_EERKSt13unordered_mapISt14basic_stringIcSt11char_traitsIcESaIcEESt6vectorINS_5IValueESaIS7_EESt4hashIS9_ESt8equal_toIS9_ESaISB_EE
```

**解决方案**
确保使用了正确的 PyTorch 版本,并且编译时链接了正确的 PyTorch 库。

## 7. 高级功能

### 7.1 多设备支持

可以为不同设备(CPU、CUDA)提供不同的实现:

```cpp
// CPU 实现
TORCH_LIBRARY_IMPL(my_ops, CPU, m) {
    m.impl("add_one", &my_ops_impl::add_one_cpu);
}

// CUDA 实现
#ifdef TORCH_HAS_CUDA
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
    m.impl("add_one", &my_ops_impl::add_one_cuda);
}
#endif
```

### 7.2 复杂算子定义

可以定义更复杂的算子,支持多个输入和输出:

```cpp
TORCH_LIBRARY(my_ops, m) {
    m.def("add_scalar(Tensor input, Scalar scalar) -> Tensor");
    m.def("addmm(Tensor self, Tensor mat1, Tensor mat2, Scalar beta=1, Scalar alpha=1) -> Tensor");
}
```

### 7.3 算子重载

可以为不同类型的输入提供重载:

```cpp
TORCH_LIBRARY(my_ops, m) {
    m.def("add(Tensor a, Tensor b) -> Tensor");
    m.def("add(Tensor a, Scalar b) -> Tensor");
}
```

## 8. 总结

使用 `TORCH_LIBRARY``TORCH_LIBRARY_IMPL` 等技术,我们可以:

1. **定义清晰的算子接口**:使用 `TORCH_LIBRARY` 宏定义算子的签名
2. **实现设备特定的逻辑**:使用 `TORCH_LIBRARY_IMPL` 为不同设备提供专门的实现
3. **无缝集成到 PyTorch**:通过 `PyInit_【NAME】` 函数初始化 Python 模块
4. **灵活编译和安装**:支持使用 CMake 或 setup.py 编译和安装

这种方法为 PyTorch 自定义算子开发提供了一种简洁、灵活的方式,使得我们可以轻松扩展 PyTorch 的功能,满足特定的业务需求。

## 9. 代码示例

完整的代码示例可以参考 `torch_custom_op` 工程中的各个部分:

- **1part**:展示了在单个文件中同时定义和实现算子
- **2part**:展示了将定义和实现分离到不同文件
- **3part**:展示了如何打包为可安装的 Python 包
- **4part**:展示了不同的模块初始化方式

通过这些示例,您可以快速上手 PyTorch 自定义算子的开发和接入。