Commit ebcba9f0 authored by zhangqha's avatar zhangqha
Browse files

support MLAttention

parent cb13d4d8
Pipeline #2431 canceled with stages
# MLAttention # MLAttention
## 简介
```
MLAttention is an efficient MLA decoding kernel , optimized for variable-length sequences serving.
目前支持的精度:
- BF16, FP16
目前支持的实现方式:
- OpenAI Triton
```
## 安装
### 源码方式安装
```bash
python3 -m pip install .
```
### 单测验证
```bash
pytest -s tests/test_triton_decode_attention.py
```
## 使用方式
```python
import triton
from triton_mla_op.triton_decode_attention import decode_attention_fwd
...
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
...
```
## MLAttention开发进度
```
目前,基于 Cutlass的MLAttention版本正在积极开发中。
我们会及时在项目仓库中更新开发进度。欢迎关注我们的开发者社区以获取最新信息。
```
from setuptools import setup, find_packages
setup(
name='MLAttention',
packages=find_packages("triton_mla_op"),
package_dir={"": "triton_mla_op"},
include_package_data=True,
install_requires=[
'triton',
'torch'
]
)
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import triton
from triton_mla_op.triton_decode_attention import decode_attention_fwd
def cdiv(a, b):
return (a + b - 1) // b
@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
assert CACHE_SIZE % PAGE_SIZE == 0
dtype = torch.bfloat16
seq_len = L # This represents the number of tokens already in the sequence
sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 8
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) #这里为向上取整,65,(1027+16-1)//16
req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1), #shape为(B, num_pages_per_batch, 1)的tensor,大小取值为0 至cache_size//page_size
device="cuda")
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) #这里是维度扩展,从torch.Size([3, 65, 1])扩展至torch.Size([3, 65, 16])
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
# k_buffer and v_buffer represent all previous tokens
# Page size is 1.
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
# Call the original implementation.
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
# Page size can be larger than 1.
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o)
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1)
from .triton_decode_attention import decode_attention_fwd
__all__ = ['decode_attention_fwd']
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment