Commit 98ac1648 authored by chenyue3's avatar chenyue3
Browse files

"feat(lightop_dcu): 新增 RMSNorm+RoPE 融合算子与工程骨架" \

    -m "新增 lightop_dcu 包入口与 Python 封装接口" \
    -m "新增 csrc/export.cpp 与 fuse_rms_roped.cu 扩展实现" \
    -m "新增 setup_lightop_dcu.py,支持安装与 wheel 构建" \
    -m "补充 README 使用说明并新增 .gitignore 忽略规则"
parent 742e2e74
__pycache__/
**/__pycache__/
*.py[cod]
build/
dist/
*.egg-info/
.eggs/
pip-wheel-metadata/
*.so
*.pyd
*.dylib
*.o
*.obj
*.a
*.lib
*.ninja
.ninja_deps
.ninja_log
compile_commands.json
csrc/*.hip
csrc/*/*.hip
.vscode/
.idea/
.venv/
venv/
env/
*.log
# lightop_dcu # lightop_dcu
## 简介
`lightop_dcu` 是一个面向 DCU/ROCm 环境的轻量融合算子包,当前聚焦于 **RMSNorm + RoPE 融合前向**
当前版本提供的核心能力:
- `rms_rotary_embedding_fuse`
-`RMSNorm``Rotary Embedding` 融合在同一个自定义算子中执行
-`query` / `key` 做原地更新,减少中间访存
---
## 安装
### 环境依赖
- Python 3.10+
- PyTorch(带 ROCm/CUDA Extension 编译能力)
- 对应 DCU 驱动与编译工具链
> 说明:本仓库通过 `torch.utils.cpp_extension.CUDAExtension` 构建。
### 源码安装
在仓库目录执行:
```bash
python setup_lightop_dcu.py install
```
如果需要指定架构(示例):
```bash
PYTORCH_ROCM_ARCH='gfx906;gfx926' python setup_lightop_dcu.py install
```
### 构建 wheel
```bash
python setup_lightop_dcu.py bdist_wheel
```
构建完成后,wheel 位于 `dist/` 目录。
---
## 算子介绍
### 核心算子
| 算子 | 说明 |
| --- | --- |
| `rms_rotary_embedding_fuse` | 对 `query/key` 执行 RMSNorm 与 RoPE 融合计算(in-place) |
### Python 接口
```python
from lightop_dcu import rms_rotary_embedding_fuse
query, key = rms_rotary_embedding_fuse(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox,
weight_q,
weight_k,
residual_q,
residual_k,
epsilon=1e-5,
)
```
### 参数说明
- `positions`: `int64`,形状 `[num_tokens]``[batch_size, seq_len]`
- `query`: 浮点张量,形状 `[num_tokens, num_heads, head_size]``[batch_size, seq_len, num_heads, head_size]`
- `key`: 浮点张量,形状 `[num_tokens, num_kv_heads, head_size]``[batch_size, seq_len, num_kv_heads, head_size]`
- `head_size`: 每个 head 的维度
- `cos_sin_cache`: RoPE cache,第二维为 `rot_dim`(要求 `rot_dim <= 512`
- `is_neox`: 是否使用 GPT-NeoX 风格旋转
- `weight_q` / `weight_k`: RMSNorm 的权重
- `residual_q` / `residual_k`: 残差输入(需同时提供,或同时不提供)
- `epsilon`: RMSNorm 数值稳定项,默认 `1e-5`
### 约束与注意事项
- 算子会 **原地修改** `query``key`
- `query/key``positions` 的 token 维度必须匹配。
- `num_heads` 必须能被 `num_kv_heads` 整除。
- 当前 kernel 分支覆盖 `head_size``64/128/256/512` 的常见场景。
---
## 安装验证
```bash
python -c "import lightop_dcu; print(lightop_dcu.rms_rotary_embedding_fuse)"
```
若能正常打印函数对象,说明安装成功。
import importlib
op = importlib.import_module('.op', __name__)
from .fuse_rmsnorm_rope import rms_rotary_embedding_fuse
__all__ = [
"rms_rotary_embedding_fuse",
]
#include <torch/extension.h>
#include <optional>
using torch::Tensor;
namespace at {
namespace native {
void rms_rotary_embedding_fuse(
Tensor& positions, Tensor& query, Tensor& key, int64_t head_size,
Tensor& cos_sin_cache, bool is_neox, Tensor weight_q, Tensor weight_k,
std::optional<Tensor> residual_q, std::optional<Tensor> residual_k,
double epsilon);
} // namespace native
} // namespace at
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rms_rotary_embedding_fuse", &at::native::rms_rotary_embedding_fuse,
"rms_rotary_embedding_fuse");
}
This diff is collapsed.
import torch
from typing import Optional, Tuple
from . import op
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
weight_q: torch.Tensor,
weight_k: torch.Tensor,
residual_q: Optional[torch.Tensor],
residual_k: Optional[torch.Tensor],
epsilon: float = 1e-5,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
op.rms_rotary_embedding_fuse(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox,
weight_q,
weight_k,
residual_q,
residual_k,
epsilon,
)
return query, key
import os
from pathlib import Path
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
ROOT_DIR = Path(__file__).parent.resolve()
def get_extensions():
extra_compile_args = {
"cxx": ["-O3", "-w"],
"nvcc": [
"-O3",
"-w",
"-mllvm",
"-enable-num-vgprs-512=true",
"-DHIP_ENABLE_WARP_SYNC_BUILTINS",
],
}
sources = [
str(ROOT_DIR / "csrc/export.cpp"),
str(ROOT_DIR / "csrc/fuse_rms_roped.cu"),
]
include_dirs = [str(ROOT_DIR / "csrc")]
extension = CUDAExtension(
name="lightop_dcu.op",
sources=sources,
include_dirs=include_dirs,
extra_compile_args=extra_compile_args,
)
return [extension]
setup(
name="lightop_dcu",
version=os.getenv("LIGHTOP_DCU_VERSION", "0.0.1"),
description="Minimal lightop package",
packages=["lightop_dcu"],
package_dir={"lightop_dcu": "."},
ext_modules=get_extensions(),
cmdclass={"build_ext": BuildExtension},
zip_safe=False,
install_requires=["torch"],
)
#/public/home/zhuww/laibao/pkg/rms_rope_laibao_260204/lightop_dcu/README.md
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment