Commit 01e32e53 authored by liuys's avatar liuys 🏸
Browse files

update

parent 77afe6c8
# Build trees anywhere in the repository
**/build/
# Compiled extension modules
**/*.so
# Python packaging artifacts
**/*.egg-info/
**/dist/
\ No newline at end of file
cmake_minimum_required(VERSION 3.18)
project(test_torch_library_expand LANGUAGES CXX)

# C++ standard for the whole project.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

# Locate Python through imported targets instead of hardcoded
# include_directories/link_directories (the original added the same Python
# include dir twice and leaked directory-scoped settings to every target).
find_package(Python 3.10 COMPONENTS Interpreter Development REQUIRED)

# Default Torch package location; overridable with -DTorch_DIR=... .
if(NOT DEFINED Torch_DIR)
    set(Torch_DIR /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch)
endif()
find_package(Torch REQUIRED)

# The Python extension module.
add_library(test_ops SHARED test_torch_library_expand.cpp)

# Link PyTorch and Python; Python::Python carries its include directories
# and library path as usage requirements.
target_link_libraries(test_ops PRIVATE ${TORCH_LIBRARIES} Python::Python)
target_include_directories(test_ops PRIVATE ${TORCH_INCLUDE_DIRS})

# Name the artifact exactly "test_ops.so" so `import test_ops` resolves it.
set_target_properties(test_ops PROPERTIES
    PREFIX ""
    SUFFIX ".so"
)

# CUDA architectures (consulted only by CUDA-enabled builds).
if(TORCH_CUDA_ARCH_LIST)
    set(CUDA_ARCH_LIST ${TORCH_CUDA_ARCH_LIST})
else()
    set(CUDA_ARCH_LIST "6.0;6.1;7.0;7.5;8.0;8.6")
endif()

# Configure-time diagnostics.
message(STATUS "PyTorch 版本: ${Torch_VERSION}")
message(STATUS "CUDA 可用: ${TORCH_CUDA_AVAILABLE}")
if(TORCH_CUDA_AVAILABLE)
    message(STATUS "CUDA 版本: ${CUDA_VERSION_STRING}")
    message(STATUS "CUDA 架构: ${CUDA_ARCH_LIST}")
endif()
#!/bin/bash
# Build the test_ops extension with CMake and run the Python test script.
# Fail fast: without -e a failed cd/cmake would let later commands run in
# the wrong directory or copy a stale library.
set -euo pipefail

BUILD_DIR="./build"

# Create and enter the build directory.
mkdir -p "$BUILD_DIR"
cd "$BUILD_DIR"

# Configure and compile.
cmake ..
cmake --build .

# Copy the built library next to the test script.
cp test_ops.so ..
cd ..

# Run the tests.
python test_torch_library_expand.py
#include <torch/library.h>
#include <ATen/ATen.h>
#include <Python.h>
#include <torch/all.h>
// A version of the TORCH_LIBRARY macro that expands the NAME
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
#define TORCH_HAS_CUDA
// 简单的自定义操作符实现
at::Tensor add_one(at::Tensor input) {
    // Elementwise increment: returns a new tensor equal to input + 1.
    return input.add(1);
}
// 另一个自定义操作符实现
at::Tensor multiply_by_two(at::Tensor input) {
    // Elementwise doubling: returns a new tensor equal to input * 2.
    return input.mul(2);
}
// Module initializer with C linkage so CPython's importer can resolve the
// symbol PyInit_test_ops by name (must match the shared-object file name).
extern "C" {
PyObject *PyInit_test_ops(void) {
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"test_ops", // module name
"Test operations module", // module docstring
-1, // per-interpreter state size: -1 keeps state in globals
NULL // no method table; operators are registered via TORCH_LIBRARY below
};
return PyModule_Create(&module_def);
}
}
// Register the "test_ops" operator namespace: schemas plus CPU kernels,
// and CUDA kernels when TORCH_HAS_CUDA is defined (it is, above).
TORCH_LIBRARY_EXPAND(test_ops, ops) {
ops.def("add_one(Tensor input) -> Tensor");
ops.def("multiply_by_two(Tensor input) -> Tensor");
ops.impl("add_one", at::kCPU, &add_one);
ops.impl("multiply_by_two", at::kCPU, &multiply_by_two);
#ifdef TORCH_HAS_CUDA
// NOTE(review): the CPU functions are reused as CUDA implementations;
// presumably the ATen arithmetic dispatches on the input's device — confirm.
if (torch::cuda::is_available()) {
ops.impl("add_one", at::kCUDA, &add_one);
ops.impl("multiply_by_two", at::kCUDA, &multiply_by_two);
}
#endif
}
\ No newline at end of file
import torch
import os
import sys
# 添加当前目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Try to load the compiled custom-operator extension; abort with a hint
# when it has not been built yet.
try:
    import test_ops
    print("成功加载 test_ops 扩展")
except ImportError as e:
    print(f"加载扩展失败: {e}")
    print("请先编译 C++ 扩展")
    sys.exit(1)
def test_add_one():
    """Verify the custom add_one operator: y == x + 1 elementwise."""
    print("\n测试 add_one 操作符:")
    # Fixed CPU input tensor.
    x = torch.tensor([1.0, 2.0, 3.0])
    print(f"输入: {x}")
    # Compute the reference first, then invoke the registered operator.
    expected = x + 1
    y = torch.ops.test_ops.add_one(x)
    print(f"输出: {y}")
    assert torch.allclose(y, expected), f"结果不匹配: {y} vs {expected}"
    print("✓ 测试通过")
def test_multiply_by_two():
    """Verify the custom multiply_by_two operator: y == x * 2 elementwise."""
    print("\n测试 multiply_by_two 操作符:")
    # Fixed CPU input tensor.
    x = torch.tensor([1.0, 2.0, 3.0])
    print(f"输入: {x}")
    # Compute the reference first, then invoke the registered operator.
    expected = x * 2
    y = torch.ops.test_ops.multiply_by_two(x)
    print(f"输出: {y}")
    assert torch.allclose(y, expected), f"结果不匹配: {y} vs {expected}"
    print("✓ 测试通过")
def test_cuda_support():
    """Run both custom operators on GPU tensors when CUDA is present."""
    # Guard clause: skip entirely on CPU-only machines.
    if not torch.cuda.is_available():
        print("\nCUDA 不可用,跳过 CUDA 测试")
        return
    print("\n测试 CUDA 支持:")
    x = torch.tensor([1.0, 2.0, 3.0]).cuda()
    print(f"CUDA 输入: {x}")
    y = torch.ops.test_ops.add_one(x)
    print(f"add_one 输出: {y}")
    assert torch.allclose(y, x + 1), f"CUDA add_one 结果不匹配"
    z = torch.ops.test_ops.multiply_by_two(x)
    print(f"multiply_by_two 输出: {z}")
    assert torch.allclose(z, x * 2), f"CUDA multiply_by_two 结果不匹配"
    print("✓ CUDA 测试通过")
if __name__ == "__main__":
    print("测试 TORCH_LIBRARY_EXPAND 示例")
    print("=" * 50)
    # Run each case in order; any mismatch raises AssertionError.
    for case in (test_add_one, test_multiply_by_two, test_cuda_support):
        case()
    print("\n" + "=" * 50)
    print("所有测试通过!")
cmake_minimum_required(VERSION 3.18)
project(test_torch_library_expand LANGUAGES CXX)

# C++ standard for the whole project.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

# Locate Python through imported targets.
find_package(Python 3.10 COMPONENTS Interpreter Development REQUIRED)

# Default Torch package location; overridable with -DTorch_DIR=... instead
# of being unconditionally hardcoded.
if(NOT DEFINED Torch_DIR)
    set(Torch_DIR /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch)
endif()
find_package(Torch REQUIRED)

# Extension library: schema translation unit + kernel implementations.
add_library(test_ops SHARED
    test_torch_library_expand.cpp
    test_ops_impl.cpp
)

# Python::Python already propagates the Python include directories as a
# usage requirement, so listing ${Python_INCLUDE_DIRS} again is redundant.
target_link_libraries(test_ops PRIVATE
    ${TORCH_LIBRARIES}
    Python::Python
)
target_include_directories(test_ops PRIVATE ${TORCH_INCLUDE_DIRS})

# Name the artifact exactly "test_ops.so" so `import test_ops` resolves it.
set_target_properties(test_ops PROPERTIES
    PREFIX ""
    SUFFIX ".so"
)

# CUDA architectures (consulted only by CUDA-enabled builds).
if(TORCH_CUDA_ARCH_LIST)
    set(CUDA_ARCH_LIST ${TORCH_CUDA_ARCH_LIST})
else()
    set(CUDA_ARCH_LIST "6.0;6.1;7.0;7.5;8.0;8.6")
endif()

# Configure-time diagnostics.
message(STATUS "PyTorch 版本: ${Torch_VERSION}")
message(STATUS "CUDA 可用: ${TORCH_CUDA_AVAILABLE}")
if(TORCH_CUDA_AVAILABLE)
    message(STATUS "CUDA 版本: ${CUDA_VERSION_STRING}")
    message(STATUS "CUDA 架构: ${CUDA_ARCH_LIST}")
endif()
\ No newline at end of file
Processing /data/wkx/develop/llm-infer-opt/vllm/torch_library_impl/3part
Preparing metadata (pyproject.toml): started
Preparing metadata (pyproject.toml): finished with status 'done'
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from test_ops==0.1.0) (2.5.1+das.opt1.dtk25042)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (3.20.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (4.15.0)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (3.4.2)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (3.1.6)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (2025.10.0)
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.10.0->test_ops==0.1.0) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->test_ops==0.1.0) (3.0.3)
Building wheels for collected packages: test_ops
Building wheel for test_ops (pyproject.toml): started
Building wheel for test_ops (pyproject.toml): finished with status 'done'
Created wheel for test_ops: filename=test_ops-0.1.0-cp310-cp310-linux_x86_64.whl size=2414269 sha256=5d859d94506823faa6aae358f3a53d32c9fd4d452d644dd03f7197f200b28cd0
Stored in directory: /tmp/pip-ephem-wheel-cache-42ohfny6/wheels/9a/5e/82/70908f48d44dc12346f4254ac62f7b855ccc92965febfda330
Successfully built test_ops
Installing collected packages: test_ops
Attempting uninstall: test_ops
Found existing installation: test_ops 0.1.0
Can't uninstall 'test_ops'. No files were found to uninstall.
Successfully installed test_ops-0.0.1
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
[notice] A new release of pip is available: 25.3 -> 26.0.1
[notice] To update, run: python3 -m pip install --upgrade pip
import glob  # hoisted to the top: module-level import was buried mid-script
import os

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

# Name shared by the distribution and the compiled extension module.
library_name = "test_ops"

# Directory containing this script; sources are addressed absolutely so the
# build works regardless of the invocation cwd.
current_dir = os.path.dirname(os.path.abspath(__file__))

# C++ translation units: operator schemas + kernel implementations.
sources = [
    os.path.join(current_dir, "test_torch_library_expand.cpp"),
    os.path.join(current_dir, "test_ops_impl.cpp"),
]

# Choose the CUDA-aware extension type only when a GPU is usable.
# NOTE(review): torch.cuda.is_available() probes the GPU at *build* time;
# cross-building CUDA kernels on a CPU-only box would need a CUDA_HOME check.
use_cuda = torch.cuda.is_available()
extension = CUDAExtension if use_cuda else CppExtension

if use_cuda:
    # Pick up any .cu kernels living next to this script (may be empty).
    cuda_files = glob.glob(os.path.join(current_dir, "*.cu"))
    sources.extend(cuda_files)
    print(f"CUDA files found: {cuda_files}")

# Per-toolchain compiler flags.
extra_compile_args = {
    'cxx': ['-O2', '-std=c++17'],
}
if use_cuda:
    extra_compile_args['nvcc'] = ['-O2']

setup(
    name=library_name,
    version='0.1.0',
    ext_modules=[
        extension(
            name=library_name,
            sources=sources,
            extra_compile_args=extra_compile_args,
            include_dirs=[current_dir],
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    },
    install_requires=['torch>=1.10.0'],
    options={
        'egg_info': {
            'egg_base': '/tmp'  # keep egg-info out of the source tree
        }
    },
)
\ No newline at end of file
#include <torch/library.h>
#include <ATen/ATen.h>
#include <Python.h>
#include <torch/all.h>
#include "test_ops.h"
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
#define TORCH_HAS_CUDA
// Module initializer with C linkage so CPython's importer can resolve the
// symbol PyInit_test_ops by name (must match the shared-object file name).
extern "C" {
PyObject *PyInit_test_ops(void) {
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"test_ops", // module name
"Test operations module", // module docstring
-1, // per-interpreter state size: -1 keeps state in globals
NULL // no method table; operators are registered via TORCH_LIBRARY below
};
return PyModule_Create(&module_def);
}
}
// Define only the operator schemas here; the CPU/CUDA kernel implementations
// are registered separately (see test_ops_impl.cpp, TORCH_LIBRARY_IMPL).
TORCH_LIBRARY_EXPAND(test_ops, ops) {
ops.def("add_one(Tensor input) -> Tensor");
ops.def("multiply_by_two(Tensor input) -> Tensor");
}
\ No newline at end of file
import torch
import os
import sys
# 添加当前目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Try to load the compiled custom-operator extension; abort with a hint
# when it has not been built yet.
try:
    import test_ops
    print("成功加载 test_ops 扩展")
except ImportError as e:
    print(f"加载扩展失败: {e}")
    print(f"请先编译 C++ 扩展")
    sys.exit(1)
def test_add_one():
    """Verify the custom add_one operator: y == x + 1 elementwise."""
    print("\n测试 add_one 操作符:")
    # Fixed CPU input tensor.
    x = torch.tensor([1.0, 2.0, 3.0])
    print(f"输入: {x}")
    # Compute the reference first, then invoke the registered operator.
    expected = x + 1
    y = torch.ops.test_ops.add_one(x)
    print(f"输出: {y}")
    assert torch.allclose(y, expected), f"结果不匹配: {y} vs {expected}"
    print("✓ 测试通过")
def test_multiply_by_two():
    """Verify the custom multiply_by_two operator: y == x * 2 elementwise."""
    print("\n测试 multiply_by_two 操作符:")
    # Fixed CPU input tensor.
    x = torch.tensor([1.0, 2.0, 3.0])
    print(f"输入: {x}")
    # Compute the reference first, then invoke the registered operator.
    expected = x * 2
    y = torch.ops.test_ops.multiply_by_two(x)
    print(f"输出: {y}")
    assert torch.allclose(y, expected), f"结果不匹配: {y} vs {expected}"
    print("✓ 测试通过")
def test_cuda_support():
    """Run both custom operators on GPU tensors when CUDA is present."""
    # Guard clause: skip entirely on CPU-only machines.
    if not torch.cuda.is_available():
        print("\nCUDA 不可用,跳过 CUDA 测试")
        return
    print("\n测试 CUDA 支持:")
    x = torch.tensor([1.0, 2.0, 3.0]).cuda()
    print(f"CUDA 输入: {x}")
    y = torch.ops.test_ops.add_one(x)
    print(f"add_one 输出: {y}")
    assert torch.allclose(y, x + 1), f"CUDA add_one 结果不匹配"
    z = torch.ops.test_ops.multiply_by_two(x)
    print(f"multiply_by_two 输出: {z}")
    assert torch.allclose(z, x * 2), f"CUDA multiply_by_two 结果不匹配"
    print("✓ CUDA 测试通过")
if __name__ == "__main__":
    print("测试 TORCH_LIBRARY_EXPAND 示例")
    print("=" * 50)
    # Run each case in order; any mismatch raises AssertionError.
    for case in (test_add_one, test_multiply_by_two, test_cuda_support):
        case()
    print("\n" + "=" * 50)
    print("所有测试通过!")
#!/bin/bash
# Build the test_ops extension with CMake and run the Python test script.
# Fail fast: without -e a failed cd/cmake would let later commands run in
# the wrong directory or copy a stale library.
set -euo pipefail

BUILD_DIR="./build"

# Create and enter the build directory.
mkdir -p "$BUILD_DIR"
cd "$BUILD_DIR"

# Configure and compile.
cmake ..
cmake --build .

# Copy the built library next to the test script.
cp test_ops.so ..
cd ..

# Run the tests.
python test_torch_library_expand.py
#pragma once
#include <torch/library.h>
#include <ATen/ATen.h>
// Declarations of the operator kernels; definitions live in test_ops_impl.cpp.
namespace test_ops_impl {
// Elementwise: returns input + 1.
at::Tensor add_one(at::Tensor input);
// Elementwise: returns input * 2.
at::Tensor multiply_by_two(at::Tensor input);
}
\ No newline at end of file
#include <torch/library.h>
#include <ATen/ATen.h>
#include "test_ops.h"
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
#define TORCH_HAS_CUDA
namespace test_ops_impl {

// Elementwise increment: produces a new tensor equal to input + 1.
at::Tensor add_one(at::Tensor input) {
    return input.add(1);
}

// Elementwise doubling: produces a new tensor equal to input * 2.
at::Tensor multiply_by_two(at::Tensor input) {
    return input.mul(2);
}

}  // namespace test_ops_impl
// Register the CPU implementations for the "test_ops" schemas.
TORCH_LIBRARY_IMPL_EXPAND(test_ops, CPU, cpu_ops) {
cpu_ops.impl("add_one", &test_ops_impl::add_one);
cpu_ops.impl("multiply_by_two", &test_ops_impl::multiply_by_two);
}
// Register the CUDA implementations (TORCH_HAS_CUDA is defined above).
#ifdef TORCH_HAS_CUDA
TORCH_LIBRARY_IMPL_EXPAND(test_ops, CUDA, cuda_ops) {
// NOTE(review): CPU and CUDA reuse the same host functions; presumably the
// underlying ATen arithmetic dispatches on the input's device — confirm.
// Define dedicated CUDA versions if device-specific kernels are ever needed.
cuda_ops.impl("add_one", &test_ops_impl::add_one);
cuda_ops.impl("multiply_by_two", &test_ops_impl::multiply_by_two);
}
#endif
\ No newline at end of file
cmake_minimum_required(VERSION 3.18)
project(test_torch_library_expand LANGUAGES CXX)

# C++ standard for the whole project.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

# Locate Python through imported targets.
find_package(Python 3.10 COMPONENTS Interpreter Development REQUIRED)

# Default Torch package location; overridable with -DTorch_DIR=... instead
# of being unconditionally hardcoded.
if(NOT DEFINED Torch_DIR)
    set(Torch_DIR /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch)
endif()
find_package(Torch REQUIRED)

# Extension library: schema translation unit + kernel implementations.
add_library(test_ops SHARED
    test_torch_library_expand.cpp
    test_ops_impl.cpp
)

# Python::Python already propagates the Python include directories as a
# usage requirement, so listing ${Python_INCLUDE_DIRS} again is redundant.
target_link_libraries(test_ops PRIVATE
    ${TORCH_LIBRARIES}
    Python::Python
)
target_include_directories(test_ops PRIVATE ${TORCH_INCLUDE_DIRS})

# Name the artifact exactly "test_ops.so" so `import test_ops` resolves it.
set_target_properties(test_ops PROPERTIES
    PREFIX ""
    SUFFIX ".so"
)

# CUDA architectures (consulted only by CUDA-enabled builds).
if(TORCH_CUDA_ARCH_LIST)
    set(CUDA_ARCH_LIST ${TORCH_CUDA_ARCH_LIST})
else()
    set(CUDA_ARCH_LIST "6.0;6.1;7.0;7.5;8.0;8.6")
endif()

# Configure-time diagnostics.
message(STATUS "PyTorch 版本: ${Torch_VERSION}")
message(STATUS "CUDA 可用: ${TORCH_CUDA_AVAILABLE}")
if(TORCH_CUDA_AVAILABLE)
    message(STATUS "CUDA 版本: ${CUDA_VERSION_STRING}")
    message(STATUS "CUDA 架构: ${CUDA_ARCH_LIST}")
endif()
\ No newline at end of file
#!/bin/bash
# Build the test_ops extension with CMake and run the Python test script.
# Fail fast: without -e a failed cd/cmake would let later commands run in
# the wrong directory or copy a stale library.
set -euo pipefail

BUILD_DIR="./build"

# Create and enter the build directory.
mkdir -p "$BUILD_DIR"
cd "$BUILD_DIR"

# Configure and compile.
cmake ..
cmake --build .

# Copy the built library next to the test script.
cp test_ops.so ..
cd ..

# Run the tests.
python test_torch_library_expand.py
Processing /data/wkx/develop/llm-infer-opt/vllm/workspace_4part
Preparing metadata (pyproject.toml): started
Preparing metadata (pyproject.toml): finished with status 'done'
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from test_ops==0.1.0) (2.5.1+das.opt1.dtk25042)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (3.20.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (4.15.0)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (3.4.2)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (3.1.6)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (2025.10.0)
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.10.0->test_ops==0.1.0) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->test_ops==0.1.0) (3.0.3)
Building wheels for collected packages: test_ops
Building wheel for test_ops (pyproject.toml): started
Building wheel for test_ops (pyproject.toml): finished with status 'done'
Created wheel for test_ops: filename=test_ops-0.1.0-cp310-cp310-linux_x86_64.whl size=2411654 sha256=6a501c3539b689e504aee2a10cfe217131ba065c353c5670ae987181b4a19390
Stored in directory: /tmp/pip-ephem-wheel-cache-iorcxf73/wheels/c0/05/cc/eead000af8b8cafeb5b86d18cf5c5281da267de35757b7851d
Successfully built test_ops
Installing collected packages: test_ops
Attempting uninstall: test_ops
Found existing installation: test_ops 0.1.0
Can't uninstall 'test_ops'. No files were found to uninstall.
Successfully installed test_ops-0.0.1
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
[notice] A new release of pip is available: 25.3 -> 26.0.1
[notice] To update, run: python3 -m pip install --upgrade pip
import glob  # hoisted to the top: module-level import was buried mid-script
import os

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

# Name shared by the distribution and the compiled extension module.
library_name = "test_ops"

# Directory containing this script; sources are addressed absolutely so the
# build works regardless of the invocation cwd.
current_dir = os.path.dirname(os.path.abspath(__file__))

# C++ translation units: operator schemas + kernel implementations.
sources = [
    os.path.join(current_dir, "test_torch_library_expand.cpp"),
    os.path.join(current_dir, "test_ops_impl.cpp"),
]

# Choose the CUDA-aware extension type only when a GPU is usable.
# NOTE(review): torch.cuda.is_available() probes the GPU at *build* time;
# cross-building CUDA kernels on a CPU-only box would need a CUDA_HOME check.
use_cuda = torch.cuda.is_available()
extension = CUDAExtension if use_cuda else CppExtension

if use_cuda:
    # Pick up any .cu kernels living next to this script (may be empty).
    cuda_files = glob.glob(os.path.join(current_dir, "*.cu"))
    sources.extend(cuda_files)
    print(f"CUDA files found: {cuda_files}")

# Per-toolchain compiler flags.
extra_compile_args = {
    'cxx': ['-O2', '-std=c++17'],
}
if use_cuda:
    extra_compile_args['nvcc'] = ['-O2']

setup(
    name=library_name,
    version='0.1.0',
    ext_modules=[
        extension(
            name=library_name,
            sources=sources,
            extra_compile_args=extra_compile_args,
            include_dirs=[current_dir],
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    },
    install_requires=['torch>=1.10.0'],
    options={
        'egg_info': {
            'egg_base': '/tmp'  # keep egg-info out of the source tree
        }
    },
)
\ No newline at end of file
#pragma once
#include <torch/library.h>
#include <ATen/ATen.h>
// Declarations of the operator kernels; definitions live in test_ops_impl.cpp.
namespace test_ops_impl {
// Elementwise: returns input + 1.
at::Tensor add_one(at::Tensor input);
// Elementwise: returns input * 2.
at::Tensor multiply_by_two(at::Tensor input);
}
\ No newline at end of file
#include <torch/library.h>
#include <ATen/ATen.h>
#include "test_ops.h"
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
#define TORCH_HAS_CUDA
namespace test_ops_impl {

// Elementwise increment: produces a new tensor equal to input + 1.
at::Tensor add_one(at::Tensor input) {
    return input.add(1);
}

// Elementwise doubling: produces a new tensor equal to input * 2.
at::Tensor multiply_by_two(at::Tensor input) {
    return input.mul(2);
}

}  // namespace test_ops_impl
// Register the CPU implementations for the "test_ops" schemas.
TORCH_LIBRARY_IMPL_EXPAND(test_ops, CPU, cpu_ops) {
cpu_ops.impl("add_one", &test_ops_impl::add_one);
cpu_ops.impl("multiply_by_two", &test_ops_impl::multiply_by_two);
}
// Register the CUDA implementations (TORCH_HAS_CUDA is defined above).
#ifdef TORCH_HAS_CUDA
TORCH_LIBRARY_IMPL_EXPAND(test_ops, CUDA, cuda_ops) {
// NOTE(review): CPU and CUDA reuse the same host functions; presumably the
// underlying ATen arithmetic dispatches on the input's device — confirm.
// Define dedicated CUDA versions if device-specific kernels are ever needed.
cuda_ops.impl("add_one", &test_ops_impl::add_one);
cuda_ops.impl("multiply_by_two", &test_ops_impl::multiply_by_two);
}
#endif
\ No newline at end of file
#include <torch/library.h>
#include <ATen/ATen.h>
#include <Python.h>
#include <torch/all.h>
#include "test_ops.h"
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
#define TORCH_HAS_CUDA
// Schema-only translation unit: declare the operator signatures for the
// "test_ops" namespace; kernels are registered separately in
// test_ops_impl.cpp via TORCH_LIBRARY_IMPL.
// (Removed the dead, commented-out PyInit_test_ops initializer — delete
// dead code rather than keeping it commented. NOTE(review): confirm the
// library is loaded via torch.ops / torch.ops.load_library and no longer
// needs a Python module init symbol.)
TORCH_LIBRARY_EXPAND(test_ops, ops) {
    ops.def("add_one(Tensor input) -> Tensor");
    ops.def("multiply_by_two(Tensor input) -> Tensor");
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment