README.md 2.45 KB
Newer Older
laibao's avatar
laibao committed
1
2
# lightop_dcu

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
## 简介

`lightop_dcu` 是一个面向 DCU/ROCm 环境的轻量融合算子包,当前聚焦于 **RMSNorm + RoPE 融合前向**

当前版本提供的核心能力:

- `rms_rotary_embedding_fuse`
-`RMSNorm``Rotary Embedding` 融合在同一个自定义算子中执行
-`query` / `key` 做原地更新,减少中间访存

---

## 安装

### 环境依赖

- Python 3.10+
- PyTorch(带 ROCm/CUDA Extension 编译能力)
- 对应 DCU 驱动与编译工具链

> 说明:本仓库通过 `torch.utils.cpp_extension.CUDAExtension` 构建。

### 源码安装

在仓库目录执行:

```bash
python setup_lightop_dcu.py install
```

如果需要指定架构(示例):

```bash
PYTORCH_ROCM_ARCH='gfx906;gfx926' python setup_lightop_dcu.py install
```

### 构建 wheel

```bash
python setup_lightop_dcu.py bdist_wheel
```

构建完成后,wheel 位于 `dist/` 目录。

---

## 算子介绍

### 核心算子

| 算子 | 说明 |
| --- | --- |
| `rms_rotary_embedding_fuse` | 对 `query/key` 执行 RMSNorm 与 RoPE 融合计算(in-place) |

### Python 接口

```python
from lightop_dcu import rms_rotary_embedding_fuse

query, key = rms_rotary_embedding_fuse(
    positions,
    query,
    key,
    head_size,
    cos_sin_cache,
    is_neox,
    weight_q,
    weight_k,
    residual_q,
    residual_k,
    epsilon=1e-5,
)
```

### 参数说明

- `positions`: `int64`,形状 `[num_tokens]``[batch_size, seq_len]`
- `query`: 浮点张量,形状 `[num_tokens, num_heads, head_size]``[batch_size, seq_len, num_heads, head_size]`
- `key`: 浮点张量,形状 `[num_tokens, num_kv_heads, head_size]``[batch_size, seq_len, num_kv_heads, head_size]`
- `head_size`: 每个 head 的维度
- `cos_sin_cache`: RoPE cache,第二维为 `rot_dim`(要求 `rot_dim <= 512`
- `is_neox`: 是否使用 GPT-NeoX 风格旋转
- `weight_q` / `weight_k`: RMSNorm 的权重
- `residual_q` / `residual_k`: 残差输入(需同时提供,或同时不提供)
- `epsilon`: RMSNorm 数值稳定项,默认 `1e-5`

### 约束与注意事项

- 算子会 **原地修改** `query``key`
- `query/key``positions` 的 token 维度必须匹配。
- `num_heads` 必须能被 `num_kv_heads` 整除。
- 当前 kernel 分支覆盖 `head_size``64/128/256/512` 的常见场景。

---

## 安装验证

```bash
python -c "import lightop_dcu; print(lightop_dcu.rms_rotary_embedding_fuse)"
```

若能正常打印函数对象,说明安装成功。