Commit a22ec42e authored by wangkaixiong's avatar wangkaixiong 🚴🏼
Browse files

update setup.py

parent 12a8520e
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
*/build */build
*/*.so */*.so
*/test_ops.egg-info */test_ops.egg-info
*/dist
...@@ -48,4 +48,9 @@ setup( ...@@ -48,4 +48,9 @@ setup(
'build_ext': BuildExtension 'build_ext': BuildExtension
}, },
install_requires=['torch>=1.10.0'], install_requires=['torch>=1.10.0'],
options={
'egg_info': {
'egg_base': '/tmp' # 将 egg-info 生成到临时目录
}
},
) )
\ No newline at end of file
Processing /data/wkx/develop/llm-infer-opt/vllm/torch_library_impl/2part
Preparing metadata (pyproject.toml): started
Preparing metadata (pyproject.toml): finished with status 'done'
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from test_ops==0.1.0) (2.5.1+das.opt1.dtk25042)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (3.20.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (4.15.0)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (3.4.2)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (3.1.6)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (2025.10.0)
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->test_ops==0.1.0) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.10.0->test_ops==0.1.0) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->test_ops==0.1.0) (3.0.3)
Building wheels for collected packages: test_ops
Building wheel for test_ops (pyproject.toml): started
Building wheel for test_ops (pyproject.toml): finished with status 'done'
Created wheel for test_ops: filename=test_ops-0.1.0-cp310-cp310-linux_x86_64.whl size=2412250 sha256=23a8e05cc60f2abe8c8e378ccb4fca84bce4a149c1e92aa6822df953351f1989
Stored in directory: /tmp/pip-ephem-wheel-cache-x6afqwi7/wheels/c0/4a/97/20d696f54d65e72f1150ebdb5f11b99879a886c7ee012ed307
Successfully built test_ops
Installing collected packages: test_ops
Attempting uninstall: test_ops
Found existing installation: test_ops 0.1.0
Can't uninstall 'test_ops'. No files were found to uninstall.
Successfully installed test_ops-0.1.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
[notice] A new release of pip is available: 25.3 -> 26.0.1
[notice] To update, run: python3 -m pip install --upgrade pip
import os import os
import torch import torch
from setuptools import setup from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
library_name = "test_ops" library_name = "test_ops"
# 获取当前目录
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
# 源文件列表 # 源文件列表
...@@ -19,13 +17,10 @@ use_cuda = torch.cuda.is_available() ...@@ -19,13 +17,10 @@ use_cuda = torch.cuda.is_available()
extension = CUDAExtension if use_cuda else CppExtension extension = CUDAExtension if use_cuda else CppExtension
if use_cuda: if use_cuda:
# 如果有CUDA文件,可以添加
import glob import glob
cuda_files = glob.glob(os.path.join(current_dir, "*.cu")) cuda_files = glob.glob(os.path.join(current_dir, "*.cu"))
sources.extend(cuda_files) sources.extend(cuda_files)
print(f"CUDA files found: {cuda_files}")
# 编译参数
extra_compile_args = { extra_compile_args = {
'cxx': ['-O2', '-std=c++17'], 'cxx': ['-O2', '-std=c++17'],
} }
...@@ -33,19 +28,53 @@ extra_compile_args = { ...@@ -33,19 +28,53 @@ extra_compile_args = {
if use_cuda: if use_cuda:
extra_compile_args['nvcc'] = ['-O2'] extra_compile_args['nvcc'] = ['-O2']
# 创建包目录和 __init__.py
package_dir = os.path.join(current_dir, library_name)
os.makedirs(package_dir, exist_ok=True)
init_py_path = os.path.join(package_dir, "__init__.py")
if not os.path.exists(init_py_path):
with open(init_py_path, "w") as f:
f.write("""
from ._C import *
__all__ = ['add_one', 'multiply_by_two']
""")
setup( setup(
name=library_name, name=library_name,
version='0.1.0', version='0.1.1',
description='Test operations for PyTorch',
author='Your Name',
# 关键:指定包
packages=[library_name],
package_dir={library_name: library_name},
# 扩展模块 - 注意命名格式
ext_modules=[ ext_modules=[
extension( extension(
name=library_name + "._C", name=f"{library_name}._C",
sources=sources, sources=sources,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
include_dirs=[current_dir], include_dirs=[current_dir],
) )
], ],
# 命令类
cmdclass={ cmdclass={
'build_ext': BuildExtension 'build_ext': BuildExtension
}, },
# 依赖
install_requires=['torch>=1.10.0'], install_requires=['torch>=1.10.0'],
# 确保生成正确的 .dist-info
zip_safe=False,
# 添加以下参数来避免生成 .egg-info 在当前目录
options={
'egg_info': {
'egg_base': '/tmp' # 将 egg-info 生成到临时目录
}
},
) )
\ No newline at end of file
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
import torch import torch
from ._C import * from ._C import *
# 导出函数
__all__ = ['add_one', 'multiply_by_two']
# 注册操作 # 注册操作
def add_one(input): def add_one(input):
return torch.ops.test_ops.add_one(input) return torch.ops.test_ops.add_one(input)
def multiply_by_two(input): def multiply_by_two(input):
return torch.ops.test_ops.multiply_by_two(input) return torch.ops.test_ops.multiply_by_two(input)
\ No newline at end of file
# 导出函数
__all__ = ['add_one', 'multiply_by_two']
\ No newline at end of file
...@@ -48,4 +48,9 @@ setup( ...@@ -48,4 +48,9 @@ setup(
'build_ext': BuildExtension 'build_ext': BuildExtension
}, },
install_requires=['torch>=1.10.0'], install_requires=['torch>=1.10.0'],
options={
'egg_info': {
'egg_base': '/tmp' # 将 egg-info 生成到临时目录
}
},
) )
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment