README.md 2.45 KB
Newer Older
1
# lmcustomop
laibao's avatar
laibao committed
2

3
4
## 简介

5
`lmcustomop` 是一个面向 DCU/ROCm 环境的轻量融合算子包,当前聚焦于 **RMSNorm + RoPE 融合前向**
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

当前版本提供的核心能力:

- `rms_rotary_embedding_fuse`
-`RMSNorm``Rotary Embedding` 融合在同一个自定义算子中执行
-`query` / `key` 做原地更新,减少中间访存

---

## 安装

### 环境依赖

- Python 3.10+
- PyTorch(带 ROCm/CUDA Extension 编译能力)
- 对应 DCU 驱动与编译工具链

> 说明:本仓库通过 `torch.utils.cpp_extension.CUDAExtension` 构建。

### 源码安装

在仓库目录执行:

```bash
30
python setup_lmcustomop.py install
31
32
33
34
35
```

如果需要指定架构(示例):

```bash
36
PYTORCH_ROCM_ARCH='gfx906;gfx926' python setup_lmcustomop.py install
37
38
39
40
41
```

### 构建 wheel

```bash
42
python setup_lmcustomop.py bdist_wheel
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
```

构建完成后,wheel 位于 `dist/` 目录。

---

## 算子介绍

### 核心算子

| 算子 | 说明 |
| --- | --- |
| `rms_rotary_embedding_fuse` | 对 `query/key` 执行 RMSNorm 与 RoPE 融合计算(in-place) |

### Python 接口

```python
60
from lmcustomop import rms_rotary_embedding_fuse
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100

query, key = rms_rotary_embedding_fuse(
    positions,
    query,
    key,
    head_size,
    cos_sin_cache,
    is_neox,
    weight_q,
    weight_k,
    residual_q,
    residual_k,
    epsilon=1e-5,
)
```

### 参数说明

- `positions`: `int64`,形状 `[num_tokens]``[batch_size, seq_len]`
- `query`: 浮点张量,形状 `[num_tokens, num_heads, head_size]``[batch_size, seq_len, num_heads, head_size]`
- `key`: 浮点张量,形状 `[num_tokens, num_kv_heads, head_size]``[batch_size, seq_len, num_kv_heads, head_size]`
- `head_size`: 每个 head 的维度
- `cos_sin_cache`: RoPE cache,第二维为 `rot_dim`(要求 `rot_dim <= 512`
- `is_neox`: 是否使用 GPT-NeoX 风格旋转
- `weight_q` / `weight_k`: RMSNorm 的权重
- `residual_q` / `residual_k`: 残差输入(需同时提供,或同时不提供)
- `epsilon`: RMSNorm 数值稳定项,默认 `1e-5`

### 约束与注意事项

- 算子会 **原地修改** `query``key`
- `query/key``positions` 的 token 维度必须匹配。
- `num_heads` 必须能被 `num_kv_heads` 整除。
- 当前 kernel 分支覆盖 `head_size``64/128/256/512` 的常见场景。

---

## 安装验证

```bash
101
python -c "import lmcustomop; print(lmcustomop.rms_rotary_embedding_fuse)"
102
103
104
```

若能正常打印函数对象,说明安装成功。