OpenDAS / MLAttention · Commits · ebcba9f0

Commit ebcba9f0, authored Feb 26, 2025 by zhangqha

support MLAttention

Parent: cb13d4d8
Pipeline #2431: canceled with stages
Showing 5 changed files with 1381 additions and 0 deletions
- README.md (+54, -0)
- setup.py (+12, -0)
- tests/test_triton_decode_attention.py (+90, -0)
- triton_mla_op/__init__.py (+3, -0)
- triton_mla_op/triton_decode_attention.py (+1222, -0)
README.md
# MLAttention

## Introduction

MLAttention is an efficient MLA decoding kernel optimized for variable-length sequence serving.

Currently supported precisions:
- BF16, FP16

Currently supported implementations:
- OpenAI Triton
## Installation

### Install from source

```bash
python3 -m pip install .
```

### Unit-test verification

```bash
pytest -s tests/test_triton_decode_attention.py
```
## Usage

```python
import triton
from triton_mla_op.triton_decode_attention import decode_attention_fwd

...
decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
)
...
```
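The call above assumes the caller has already allocated the query, KV-cache, output, and scratch tensors. The snippet below is a minimal sketch of that setup, with shapes taken from `tests/test_triton_decode_attention.py` in this commit; the batch size, head counts, head dimensions, and cache size are placeholder values, not requirements of the kernel.

```python
import torch
from triton_mla_op.triton_decode_attention import decode_attention_fwd

# Placeholder sizes; choose values that match your model and KV-cache layout.
B, H_Q, H_KV = 4, 32, 32           # batch size, query heads, KV heads
D_QK, D_V = 576, 512               # QK head dim and V head dim
CACHE_SIZE, seq_len = 16384, 1024  # total KV-cache slots, tokens already in each sequence
num_kv_splits = 8                  # number of KV splits for the split-KV reduction
sm_scale = 1.0 / (D_QK ** 0.5)

q = torch.randn(B, H_Q, D_QK, dtype=torch.bfloat16, device="cuda")
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=torch.bfloat16, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=torch.bfloat16, device="cuda")
o = torch.zeros(B, H_Q, D_V, dtype=torch.bfloat16, device="cuda")

# Per-request KV-cache slot indices and per-request sequence lengths.
req_to_token = torch.randint(0, CACHE_SIZE, (B, seq_len), device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")

# Scratch buffer for split-KV partial results: float32, D_V accumulators plus one extra slot per split.
attn_logits = torch.empty(B, H_Q, num_kv_splits, D_V + 1, dtype=torch.float32, device="cuda")

decode_attention_fwd(q, k_buffer, v_buffer, o, req_to_token, b_seq_len,
                     attn_logits, num_kv_splits, sm_scale)
```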
## MLAttention development progress

A Cutlass-based version of MLAttention is under active development. We will post progress updates in this repository; follow our developer community for the latest news.
setup.py (new file, mode 100644)
from setuptools import setup, find_packages

setup(
    name='MLAttention',
    packages=find_packages("triton_mla_op"),
    package_dir={"": "triton_mla_op"},
    include_package_data=True,
    install_requires=['triton', 'torch']
)
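A quick post-install sanity check (a sketch, not part of this commit) is to import the entry point that the README and tests use:

```python
# Hypothetical check after `python3 -m pip install .`:
# the installed package should expose decode_attention_fwd (see triton_mla_op/__init__.py).
from triton_mla_op import decode_attention_fwd

print(decode_attention_fwd)
```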
tests/test_triton_decode_attention.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import triton

from triton_mla_op.triton_decode_attention import decode_attention_fwd


def cdiv(a, b):
    return (a + b - 1) // b


@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L  # This represents the number of tokens already in the sequence
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    # Ceiling division, e.g. (1027 + 16 - 1) // 16 = 65
    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    # Tensor of shape (B, num_pages_per_batch, 1) with values in [0, CACHE_SIZE // PAGE_SIZE)
    req_to_page = torch.randint(
        0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
    )
    req_to_token = req_to_page * PAGE_SIZE
    # Expand the last dimension, e.g. from torch.Size([3, 65, 1]) to torch.Size([3, 65, 16])
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens
    # Page size is 1.
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # o has the same batch and head dimensions as q, with head dim D_V
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B,), seq_len, device="cuda")

    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device="cuda",
    )

    # Call the original implementation.
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    # Page size can be larger than 1.
    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

    o1 = torch.zeros_like(o)

    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o1,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )

    assert torch.allclose(o, o1)
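The test only checks that the paged (PAGE_SIZE > 1) path reproduces the PAGE_SIZE == 1 path. A naive PyTorch reference along the lines below could serve as an extra sanity check; this is a sketch, not part of the commit, and it assumes grouped-query decode semantics in which query head h reads KV head h // (H_Q // H_KV). bfloat16 outputs would need loose tolerances.

```python
def naive_decode_attention(q, k_buffer, v_buffer, req_to_token, b_seq_len, sm_scale):
    """Unpaged reference for one decode step (sanity check only, assumed semantics)."""
    B, H_Q, _ = q.shape
    H_KV = k_buffer.shape[1]
    group = H_Q // H_KV                               # query heads per KV head (GQA-style mapping)
    out = torch.zeros(B, H_Q, v_buffer.shape[-1], dtype=q.dtype, device=q.device)
    for b in range(B):
        idx = req_to_token[b, : int(b_seq_len[b])]    # KV-cache slots used by this request
        k = k_buffer[idx]                             # (L, H_KV, D_QK)
        v = v_buffer[idx]                             # (L, H_KV, D_V)
        for h in range(H_Q):
            kv_h = h // group
            logits = (k[:, kv_h] @ q[b, h]) * sm_scale             # (L,)
            probs = torch.softmax(logits.float(), dim=0).to(v.dtype)
            out[b, h] = probs @ v[:, kv_h]                         # (D_V,)
    return out

# Example use inside the test, before k_buffer/v_buffer are reshaped for paging:
# ref = naive_decode_attention(q, k_buffer, v_buffer, req_to_token, b_seq_len, sm_scale)
# torch.testing.assert_close(o, ref, atol=2e-2, rtol=2e-2)
```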
triton_mla_op/__init__.py (new file, mode 100644)
from .triton_decode_attention import decode_attention_fwd

__all__ = ['decode_attention_fwd']
triton_mla_op/triton_decode_attention.py (new file, mode 100644)
(Diff collapsed; the 1222 added lines of the Triton kernel are not shown here.)