Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e4a9c2cd
Commit
e4a9c2cd
authored
Mar 06, 2025
by
zhuwenwen
Browse files
update mla optest
parent
52121d00
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
110 additions
and
26 deletions
+110
-26
setup.py
setup.py
+1
-1
tests/kernels/test_triton_decode_attention.py
tests/kernels/test_triton_decode_attention.py
+91
-7
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+18
-18
No files found.
setup.py
View file @
e4a9c2cd
...
@@ -688,7 +688,7 @@ package_data = {
...
@@ -688,7 +688,7 @@ package_data = {
"model_executor/layers/fused_moe/configs/*.json"
,
"model_executor/layers/fused_moe/configs/*.json"
,
"model_executor/layers/quantization/utils/configs/*.json"
,
"model_executor/layers/quantization/utils/configs/*.json"
,
"benchmarks/*.py"
,
"benchmarks/*.py"
,
"
model_executor/layers/quantization
/configs/
w8a8/
*.json"
,
"
attention/backends
/configs/*.json"
,
"model_executor/layers/quantization/configs/awq/*.json"
"model_executor/layers/quantization/configs/awq/*.json"
]
]
}
}
...
...
tests/kernels/test_triton_decode_attention.py
View file @
e4a9c2cd
...
@@ -2,9 +2,9 @@
...
@@ -2,9 +2,9 @@
import
pytest
import
pytest
import
torch
import
torch
import
triton
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
,
decode_attention_v1
,
decode_attention_v2
def
cdiv
(
a
,
b
):
def
cdiv
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
return
(
a
+
b
-
1
)
//
b
...
@@ -25,13 +25,13 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
...
@@ -25,13 +25,13 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
sm_scale
=
1.0
/
(
D_QK
**
0.5
)
sm_scale
=
1.0
/
(
D_QK
**
0.5
)
num_kv_splits
=
8
num_kv_splits
=
8
num_pages_per_batch
=
cdiv
(
seq_len
,
PAGE_SIZE
)
num_pages_per_batch
=
cdiv
(
seq_len
,
PAGE_SIZE
)
# 向上取整:65, (1027+16-1)//16
req_to_page
=
torch
.
randint
(
0
,
req_to_page
=
torch
.
randint
(
0
,
CACHE_SIZE
//
PAGE_SIZE
,
CACHE_SIZE
//
PAGE_SIZE
,
(
B
,
num_pages_per_batch
,
1
),
(
B
,
num_pages_per_batch
,
1
),
#shape为(B, num_pages_per_batch, 1)的tensor,大小取值为0 至cache_size//page_size
device
=
"cuda"
)
device
=
"cuda"
)
req_to_token
=
req_to_page
*
PAGE_SIZE
req_to_token
=
req_to_page
*
PAGE_SIZE
req_to_token
=
req_to_token
.
expand
(
B
,
num_pages_per_batch
,
PAGE_SIZE
)
req_to_token
=
req_to_token
.
expand
(
B
,
num_pages_per_batch
,
PAGE_SIZE
)
# 维度扩展,从torch.Size([3, 65, 1])扩展至torch.Size([3, 65, 16])
req_to_token
=
req_to_token
+
torch
.
arange
(
PAGE_SIZE
,
device
=
"cuda"
).
view
(
req_to_token
=
req_to_token
+
torch
.
arange
(
PAGE_SIZE
,
device
=
"cuda"
).
view
(
1
,
1
,
-
1
)
1
,
1
,
-
1
)
req_to_token
=
req_to_token
.
view
(
B
,
-
1
)
req_to_token
=
req_to_token
.
view
(
B
,
-
1
)
...
@@ -47,14 +47,22 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
...
@@ -47,14 +47,22 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# o will have the same shape as q
# o will have the same shape as q
o
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
o
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
b_seq_len
=
torch
.
full
((
B
,
),
seq_len
,
device
=
"cuda"
)
b_seq_len
=
torch
.
full
((
B
,
),
seq_len
,
device
=
"cuda"
)
b_start_loc
=
torch
.
arange
(
0
,
k_buffer
.
shape
[
0
]
*
PAGE_SIZE
,
k_buffer
.
shape
[
0
]
*
PAGE_SIZE
//
q
.
shape
[
0
],
device
=
"cuda"
).
to
(
torch
.
int32
)
attn_logits_v1
=
torch
.
empty
(
(
q
.
shape
[
1
],
k_buffer
.
shape
[
0
]
*
PAGE_SIZE
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
attn_logits
=
torch
.
empty
(
attn_logits
=
torch
.
empty
(
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
"cuda"
,
device
=
"cuda"
,
)
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
# Call the original implementation.
# Call the original implementation.
decode_attention_fwd
(
decode_attention_fwd
(
...
@@ -87,5 +95,81 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
...
@@ -87,5 +95,81 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
sm_scale
,
sm_scale
,
PAGE_SIZE
,
PAGE_SIZE
,
)
)
assert
torch
.
allclose
(
o
,
o1
)
assert
torch
.
allclose
(
o
,
o1
)
# v0_tc_ms, v0_tc_min_ms, v0_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_fwd(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention ori kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v0_tc_ms)
decode_attention_v1
(
q
,
k_buffer
,
v_buffer
,
o1
,
req_to_page
,
b_start_loc
,
b_seq_len
,
attn_logits_v1
,
num_kv_splits
,
sm_scale
,
PAGE_SIZE
,
)
assert
torch
.
allclose
(
o
,
o1
,
atol
=
1e-2
,
rtol
=
1e-2
)
# v1_tc_ms, v1_tc_min_ms, v1_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention v1 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v1_tc_ms)
decode_attention_v2
(
q
,
k_buffer
,
v_buffer
,
o1
,
req_to_page
,
b_seq_len
,
attn_logits
,
num_kv_splits
,
sm_scale
,
PAGE_SIZE
,
)
assert
torch
.
allclose
(
o
,
o1
,
atol
=
1e-2
,
rtol
=
1e-2
)
# v2_tc_ms, v2_tc_min_ms, v2_tc_max_ms = triton.testing.do_bench(lambda:
# decode_attention_v2(
# q,
# k_buffer,
# v_buffer,
# o1,
# req_to_page,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# PAGE_SIZE,
# ), quantiles=quantiles)
# print("print mla decode attention v2 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v2_tc_ms)
vllm/attention/ops/triton_decode_attention.py
View file @
e4a9c2cd
...
@@ -1420,7 +1420,7 @@ def decode_attention_v2(
...
@@ -1420,7 +1420,7 @@ def decode_attention_v2(
_decode_v2_stage2_best_config
=
_decode_v2_stage2_use_tc
(
attn_logits
,
q
,
o
,
v_buffer
,
b_seq_len
,
num_kv_splits
)
_decode_v2_stage2_best_config
=
_decode_v2_stage2_use_tc
(
attn_logits
,
q
,
o
,
v_buffer
,
b_seq_len
,
num_kv_splits
)
return
_decode_v2_stage1_best_config
,
_decode_v2_stage2_best_config
return
_decode_v2_stage1_best_config
,
_decode_v2_stage2_best_config
def
decode_attention_fwd
(
def
decode_attention_fwd
(
q
,
q
,
k_buffer
,
k_buffer
,
...
@@ -1455,21 +1455,7 @@ def decode_attention_fwd(
...
@@ -1455,21 +1455,7 @@ def decode_attention_fwd(
)
)
else
:
else
:
# GQA/MQA/MLA
# GQA/MQA/MLA
if
not
envs
.
VLLM_USE_TRITON_OPT_MLA
:
if
envs
.
VLLM_USE_TRITON_OPT_MLA
:
decode_attention_fwd_grouped
(
q
,
k_buffer
,
v_buffer
,
o
,
req_to_token
,
b_seq_len
,
attn_logits
,
num_kv_splits
,
sm_scale
,
page_size
,
logit_cap
,
)
else
:
decode_attention_v2
(
decode_attention_v2
(
q
,
q
,
k_buffer
,
k_buffer
,
...
@@ -1501,7 +1487,7 @@ def decode_attention_fwd(
...
@@ -1501,7 +1487,7 @@ def decode_attention_fwd(
# page_size,
# page_size,
# logit_cap,
# logit_cap,
# )
# )
# if best_config['kernel_kind'] == 'v1_2stages_tc':
# if best_config['kernel_kind'] == 'v1_2stages_tc':
# attn_logits_v1 = torch.empty(
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# (q.shape[1],k_buffer.shape[0]*page_size),
...
@@ -1538,4 +1524,18 @@ def decode_attention_fwd(
...
@@ -1538,4 +1524,18 @@ def decode_attention_fwd(
# logit_cap,
# logit_cap,
# )
# )
# else:
# else:
# print("Unknown mla kernel kind: ", best_config['kernel_kind'])
# print("Unknown mla kernel kind: ", best_config['kernel_kind'])
\ No newline at end of file
else
:
decode_attention_fwd_grouped
(
q
,
k_buffer
,
v_buffer
,
o
,
req_to_token
,
b_seq_len
,
attn_logits
,
num_kv_splits
,
sm_scale
,
page_size
,
logit_cap
,
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment