Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tsoc
torch_custom_op
Commits
77afe6c8
Commit
77afe6c8
authored
Feb 27, 2026
by
liuys
🏸
Browse files
add detail
parent
a22ec42e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
433 additions
and
0 deletions
+433
-0
detail.md
detail.md
+433
-0
No files found.
detail.md
0 → 100644
View file @
77afe6c8
# PyTorch 自定义算子开发与接入指南
本文档基于
`torch_custom_op`
工程,详细介绍如何使用
`TORCH_LIBRARY`
、
`TORCH_LIBRARY_IMPL`
和
`PyInit_【NAME】`
等技术开发和接入 PyTorch 自定义算子。
## 1. 概述
PyTorch 提供了强大的自定义算子开发能力,通过
`TORCH_LIBRARY`
和
`TORCH_LIBRARY_IMPL`
等宏,可以:
-
定义新的算子接口
-
为不同设备(CPU/CUDA)提供专门的实现
-
无缝集成到 PyTorch 生态系统中
## 2. 环境准备
### 2.1 依赖项
-
PyTorch 1.10+(与 setup.py 中 `install_requires=['torch>=1.10.0']` 保持一致)
-
C++ 编译器(支持 C++17,与 setup.py 中的 `-std=c++17` 一致)
-
CMake (3.18+)
-
Python 3.6+
### 2.2 项目结构
推荐的项目结构如下:
```
my_custom_ops/
├── setup.py # Python 包配置
├── build.sh # 编译脚本
├── my_ops.h # 算子声明
├── my_ops_impl.cpp # 算子实现
├── my_ops.cpp # 模块初始化和算子定义
└── test_ops.py # 测试脚本
```
## 3. 开发步骤
### 3.1 定义算子接口
使用
`TORCH_LIBRARY`
宏定义算子接口,在
`.cpp`
文件中:
```
cpp
#include <torch/library.h>
#include <ATen/ATen.h>
#include <Python.h>
#include "my_ops.h"

// Define the operator library "my_ops" and declare the operator schemas.
// The actual kernels are registered separately with TORCH_LIBRARY_IMPL.
TORCH_LIBRARY(my_ops, m) {
    m.def("add_one(Tensor input) -> Tensor");
    m.def("multiply_by_two(Tensor input) -> Tensor");
}

// Python module initialization function.
// NOTE: the name PyInit_my_ops must match the extension module name
// ("my_ops") exactly, otherwise importing fails with
// "dynamic module does not define module export function".
extern "C" {
    PyObject* PyInit_my_ops(void) {
        static struct PyModuleDef module_def = {
            PyModuleDef_HEAD_INIT,
            "my_ops",                       // module name
            "My custom operations module",  // module docstring
            -1,                             // per-interpreter state size (-1: global state)
            NULL                            // method table (ops are exposed via torch.ops, not module methods)
        };
        return PyModule_Create(&module_def);
    }
}
```
### 3.2 实现算子逻辑
在
`.h`
文件中声明算子函数:
```
cpp
#pragma once
#include <torch/library.h>
#include <ATen/ATen.h>

// Declarations of the operator kernels; definitions live in my_ops_impl.cpp.
namespace my_ops_impl {
    // Element-wise: returns input + 1.
    at::Tensor add_one(at::Tensor input);
    // Element-wise: returns input * 2.
    at::Tensor multiply_by_two(at::Tensor input);
}
```
在
`.cpp`
文件中实现算子逻辑并使用
`TORCH_LIBRARY_IMPL`
注册:
```
cpp
#include <torch/library.h>
#include <ATen/ATen.h>
#include "my_ops.h"

namespace my_ops_impl {

// Element-wise: out[i] = input[i] + 1.
at::Tensor add_one(at::Tensor input) {
    at::Tensor result = input + 1;
    return result;
}

// Element-wise: out[i] = input[i] * 2.
at::Tensor multiply_by_two(at::Tensor input) {
    at::Tensor result = input * 2;
    return result;
}

}  // namespace my_ops_impl

// Register the CPU kernels for the schemas declared with TORCH_LIBRARY.
TORCH_LIBRARY_IMPL(my_ops, CPU, m) {
    m.impl("add_one", &my_ops_impl::add_one);
    m.impl("multiply_by_two", &my_ops_impl::multiply_by_two);
}

// Register the CUDA kernels (compiled only when CUDA support is enabled).
// NOTE(review): TORCH_HAS_CUDA is a project-defined macro, not one supplied
// by PyTorch — the build must define it explicitly (see section 6.2).
#ifdef TORCH_HAS_CUDA
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
    m.impl("add_one", &my_ops_impl::add_one);
    m.impl("multiply_by_two", &my_ops_impl::multiply_by_two);
}
#endif
```
### 3.3 编译配置
#### 3.3.1 使用 CMake 配置
创建
`CMakeLists.txt`
文件:
```
cmake
cmake_minimum_required(VERSION 3.18)
project(my_ops)

find_package(Torch REQUIRED)

# Use C++17 to stay consistent with setup.py's "-std=c++17" flag
# (recent libtorch releases require C++17 for extensions).
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Propagate the compile flags exported by the Torch package (ABI flags etc.),
# as recommended by the PyTorch CMake documentation.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_library(my_ops SHARED my_ops.cpp my_ops_impl.cpp)
target_link_libraries(my_ops "${TORCH_LIBRARIES}")

# Drop the "lib" prefix so the output is my_ops.so, matching the
# PyInit_my_ops module name expected by the Python importer.
set_target_properties(my_ops PROPERTIES PREFIX "")
```
#### 3.3.2 使用 setup.py 配置
创建
`setup.py`
文件,支持自动检测 CUDA 并使用相应的扩展:
```
python
import os
import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

library_name = "my_ops"
current_dir = os.path.dirname(os.path.abspath(__file__))

# C++ source files for the extension.
sources = [
    os.path.join(current_dir, name)
    for name in ("my_ops.cpp", "my_ops_impl.cpp")
]

# Pick the extension type based on CUDA availability at build time.
use_cuda = torch.cuda.is_available()
extension = CUDAExtension if use_cuda else CppExtension

if use_cuda:
    import glob
    # Pull in any .cu kernels alongside the C++ sources.
    cuda_files = glob.glob(os.path.join(current_dir, "*.cu"))
    sources.extend(cuda_files)

extra_compile_args = {
    'cxx': ['-O2', '-std=c++17'],
}
if use_cuda:
    extra_compile_args['nvcc'] = ['-O2']

# Create the Python package directory and seed its __init__.py so that
# `import my_ops` re-exports the compiled operators from ._C.
package_dir = os.path.join(current_dir, library_name)
os.makedirs(package_dir, exist_ok=True)
init_py_path = os.path.join(package_dir, "__init__.py")
if not os.path.exists(init_py_path):
    with open(init_py_path, "w") as f:
        f.write("""
from ._C import *
__all__ = ['add_one', 'multiply_by_two']
""")

setup(
    name=library_name,
    version='0.1.0',
    description='Custom operations for PyTorch',
    author='Your Name',
    # The pure-Python package wrapping the compiled module.
    packages=[library_name],
    package_dir={library_name: library_name},
    # The compiled module is named <package>._C so __init__.py can find it.
    ext_modules=[
        extension(
            name=f"{library_name}._C",
            sources=sources,
            extra_compile_args=extra_compile_args,
            include_dirs=[current_dir],
        )
    ],
    # BuildExtension wires in the PyTorch-specific compiler/linker flags.
    cmdclass={'build_ext': BuildExtension},
    install_requires=['torch>=1.10.0'],
    # Must stay unzipped so the shared library is loadable.
    zip_safe=False,
)
```
#### 3.3.3 编译脚本
创建
`build.sh`
脚本:
```
bash
#!/bin/bash
# Build the extension out-of-tree with CMake and copy the resulting
# shared library back to the project root.
# Fail fast: without this, a failed cmake/make still ran `cp` and the
# script exited 0 with a stale (or missing) my_ops.so.
set -euo pipefail

rm -rf build
mkdir -p build
cd build

cmake ..
make -j

# PREFIX is "" in CMakeLists.txt, so the artifact is my_ops.so.
cp my_ops.so ..
cd ..
```
## 4. 编译和安装
### 4.1 使用 build.sh 编译
```
bash
bash build.sh
```
### 4.2 使用 pip 安装
```
bash
pip
install
--no-build-isolation
.
```
## 5. 在 Python 中使用自定义算子
创建测试脚本
`test_ops.py`
:
```
python
import torch
import os
import sys

# Add this script's directory to the Python path so the locally built
# extension can be imported without installing it.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Load the custom operator extension. Importing it triggers the
# TORCH_LIBRARY registration, exposing the ops under torch.ops.my_ops.
try:
    import my_ops
    print("成功加载 my_ops 扩展")
except ImportError as e:
    print(f"加载扩展失败:{e}")
    print("请先编译 C++ 扩展")
    sys.exit(1)


def test_add_one():
    """Test the add_one operator (expects output == input + 1)."""
    print("\n测试 add_one 操作符:")
    # Create a test tensor
    x = torch.tensor([1.0, 2.0, 3.0])
    print(f"输入:{x}")
    # Invoke the custom operator through the torch.ops namespace
    y = torch.ops.my_ops.add_one(x)
    print(f"输出:{y}")
    # Verify against the reference computation
    expected = x + 1
    assert torch.allclose(y, expected), f"结果不匹配:{y} vs {expected}"
    print("✓ 测试通过")


def test_multiply_by_two():
    """Test the multiply_by_two operator (expects output == input * 2)."""
    print("\n测试 multiply_by_two 操作符:")
    # Create a test tensor
    x = torch.tensor([1.0, 2.0, 3.0])
    print(f"输入:{x}")
    # Invoke the custom operator
    y = torch.ops.my_ops.multiply_by_two(x)
    print(f"输出:{y}")
    # Verify against the reference computation
    expected = x * 2
    assert torch.allclose(y, expected), f"结果不匹配:{y} vs {expected}"
    print("✓ 测试通过")


def test_cuda_support():
    """Test the CUDA kernels when a CUDA device is available; skip otherwise."""
    if torch.cuda.is_available():
        print("\n测试 CUDA 支持:")
        # Create a CUDA tensor
        x = torch.tensor([1.0, 2.0, 3.0]).cuda()
        print(f"CUDA 输入:{x}")
        # Test add_one on the CUDA device
        y = torch.ops.my_ops.add_one(x)
        print(f"add_one 输出:{y}")
        assert torch.allclose(y, x + 1), f"CUDA add_one 结果不匹配"
        # Test multiply_by_two on the CUDA device
        z = torch.ops.my_ops.multiply_by_two(x)
        print(f"multiply_by_two 输出:{z}")
        assert torch.allclose(z, x * 2), f"CUDA multiply_by_two 结果不匹配"
        print("✓ CUDA 测试通过")
    else:
        print("\nCUDA 不可用,跳过 CUDA 测试")


if __name__ == "__main__":
    print("测试自定义算子")
    print("=" * 50)
    test_add_one()
    test_multiply_by_two()
    test_cuda_support()
    print("\n" + "=" * 50)
    print("所有测试通过!")
```
运行测试:
```
bash
python test_ops.py
```
## 6. 常见问题和解决方案
### 6.1 模块初始化函数错误
**错误信息**
:
```
加载扩展失败: dynamic module does not define module export function (PyInit_my_ops)
```
**解决方案**
:
确保在 C++ 文件中正确定义了
`PyInit_my_ops`
函数,并且函数名与模块名一致。
### 6.2 CUDA 支持问题
**错误信息**
:
```
undefined reference to `torch::library::Library::impl(...)`
```
**解决方案**
:
确保正确定义了
`TORCH_HAS_CUDA`
宏,并且在 CUDA 环境下编译。
### 6.3 算子注册问题
**错误信息**
:
```
undefined symbol: _ZN5torch8library7Library4implERKNS_6SymbolEPFSt10unique_ptrINS_5autograd15AutogradContextESt14default_deleteIS3_EERKSt13unordered_mapISt14basic_stringIcSt11char_traitsIcESaIcEESt6vectorINS_5IValueESaIS7_EESt4hashIS9_ESt8equal_toIS9_ESaISB_EE
```
**解决方案**
:
确保使用了正确的 PyTorch 版本,并且编译时链接了正确的 PyTorch 库。
## 7. 高级功能
### 7.1 多设备支持
可以为不同设备(CPU、CUDA)提供不同的实现:
```
cpp
// CPU kernel registration: binds a CPU-specific implementation.
TORCH_LIBRARY_IMPL(my_ops, CPU, m) {
    m.impl("add_one", &my_ops_impl::add_one_cpu);
}

// CUDA kernel registration: binds a CUDA-specific implementation,
// compiled only when the (project-defined) TORCH_HAS_CUDA macro is set.
#ifdef TORCH_HAS_CUDA
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
    m.impl("add_one", &my_ops_impl::add_one_cuda);
}
#endif
```
### 7.2 复杂算子定义
可以定义更复杂的算子,支持多个输入和输出:
```
cpp
TORCH_LIBRARY(my_ops, m) {
    // Operator taking a Tensor plus a Scalar argument.
    m.def("add_scalar(Tensor input, Scalar scalar) -> Tensor");
    // Operator with multiple Tensor inputs and defaulted Scalar arguments.
    m.def("addmm(Tensor self, Tensor mat1, Tensor mat2, Scalar beta=1, Scalar alpha=1) -> Tensor");
}
```
### 7.3 算子重载
可以为不同类型的输入提供重载:
```
cpp
TORCH_LIBRARY(my_ops, m) {
    // Overloads must carry distinct overload names ("name.overload").
    // Registering two schemas that are both plainly "add" is rejected by
    // the dispatcher (duplicate name + overload name).
    m.def("add.Tensor(Tensor a, Tensor b) -> Tensor");
    m.def("add.Scalar(Tensor a, Scalar b) -> Tensor");
}
```
## 8. 总结
使用
`TORCH_LIBRARY`
和
`TORCH_LIBRARY_IMPL`
等技术,我们可以:
1.
**定义清晰的算子接口**
:使用
`TORCH_LIBRARY`
宏定义算子的签名
2.
**实现设备特定的逻辑**
:使用
`TORCH_LIBRARY_IMPL`
为不同设备提供专门的实现
3.
**无缝集成到 PyTorch**
:通过
`PyInit_【NAME】`
函数初始化 Python 模块
4.
**灵活编译和安装**
:支持使用 CMake 或 setup.py 编译和安装
这种方法为 PyTorch 自定义算子开发提供了一种简洁、灵活的方式,使得我们可以轻松扩展 PyTorch 的功能,满足特定的业务需求。
## 9. 代码示例
完整的代码示例可以参考
`torch_custom_op`
工程中的各个部分:
-
**1part**
:展示了在单个文件中同时定义和实现算子
-
**2part**
:展示了将定义和实现分离到不同文件
-
**3part**
:展示了如何打包为可安装的 Python 包
-
**4part**
:展示了不同的模块初始化方式
通过这些示例,您可以快速上手 PyTorch 自定义算子的开发和接入。
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment