Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8b5a09f6
Commit
8b5a09f6
authored
Apr 08, 2025
by
zhuwenwen
Browse files
update triton_decode_attention.py
parent
f6044f1a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
38 deletions
+62
-38
examples/mla/triton_decode_attention.py
examples/mla/triton_decode_attention.py
+21
-10
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+41
-28
No files found.
examples/mla/triton_decode_attention.py
View file @
8b5a09f6
...
@@ -40,6 +40,7 @@ is_hip_ = current_platform.is_rocm()
...
@@ -40,6 +40,7 @@ is_hip_ = current_platform.is_rocm()
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"1"
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"1"
os
.
environ
[
"TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"
]
=
"0"
os
.
environ
[
"TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"
]
=
"0"
os
.
environ
[
"TRITON_DEFAULT_ENABLE_NUM_VGPRS512"
]
=
"1"
os
.
environ
[
"TRITON_DEFAULT_ENABLE_NUM_VGPRS512"
]
=
"1"
os
.
environ
[
"MLIR_ENABLE_DUMP"
]
=
"0"
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -830,15 +831,15 @@ def _decode_v2_kernel_stage1_use_tc(
...
@@ -830,15 +831,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if
BLOCK_DPE
>
0
:
#
if BLOCK_DPE > 0:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
#
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe
=
offs_dpe
<
Lk
#
mask_dpe = offs_dpe < Lk
off_qpe
=
(
#
off_qpe = (
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
#
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
#
)
qpe
=
tl
.
load
(
#
qpe = tl.load(
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
#
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
#
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
split_kv_start
=
kv_len_per_split
*
split_kv_id
split_kv_start
=
kv_len_per_split
*
split_kv_id
...
@@ -867,11 +868,19 @@ def _decode_v2_kernel_stage1_use_tc(
...
@@ -867,11 +868,19 @@ def _decode_v2_kernel_stage1_use_tc(
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
other
=
0.0
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
other
=
0.0
)
qk
+=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
qk
+=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
mask_dpe
=
offs_dpe
<
Lk
off_qpe
=
(
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
)
offs_buf_kpe
=
(
offs_buf_kpe
=
(
kv_loc
[
None
,
:]
*
stride_buf_kbs
kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
+
offs_dpe
[:,
None
]
)
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
)
kpe
=
tl
.
load
(
kpe
=
tl
.
load
(
K_Buffer
+
offs_buf_kpe
,
K_Buffer
+
offs_buf_kpe
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
...
@@ -1009,7 +1018,6 @@ def _decode_v2_stage1_use_tc(
...
@@ -1009,7 +1018,6 @@ def _decode_v2_stage1_use_tc(
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({},
num_warps
=
1
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
1
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
1
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
4
,
num_stages
=
1
),
...
@@ -1163,6 +1171,9 @@ def decode_attentionv2_fwd(
...
@@ -1163,6 +1171,9 @@ def decode_attentionv2_fwd(
#[TODO] The relationship between L and block is to be analyzed
#[TODO] The relationship between L and block is to be analyzed
if
L
>=
2048
:
if
L
>=
2048
:
num_kv_splits
=
(
2
*
cu_num
-
1
+
grid_num
)
//
grid_num
num_kv_splits
=
(
2
*
cu_num
-
1
+
grid_num
)
//
grid_num
if
L
>=
4096
:
num_kv_splits
=
(
4
*
cu_num
-
1
+
grid_num
)
//
grid_num
attn_logits_v1
=
torch
.
empty
(
attn_logits_v1
=
torch
.
empty
(
(
q
.
shape
[
0
],
q
.
shape
[
1
],
num_kv_splits
,
v_buffer
.
shape
[
-
1
]
+
1
),
(
q
.
shape
[
0
],
q
.
shape
[
1
],
num_kv_splits
,
v_buffer
.
shape
[
-
1
]
+
1
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
...
...
vllm/attention/ops/triton_decode_attention.py
View file @
8b5a09f6
...
@@ -42,6 +42,7 @@ is_hip_ = current_platform.is_rocm()
...
@@ -42,6 +42,7 @@ is_hip_ = current_platform.is_rocm()
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"1"
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"1"
os
.
environ
[
"TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"
]
=
"0"
os
.
environ
[
"TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"
]
=
"0"
os
.
environ
[
"TRITON_DEFAULT_ENABLE_NUM_VGPRS512"
]
=
"1"
os
.
environ
[
"TRITON_DEFAULT_ENABLE_NUM_VGPRS512"
]
=
"1"
os
.
environ
[
"MLIR_ENABLE_DUMP"
]
=
"0"
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -1141,15 +1142,15 @@ def _decode_v2_kernel_stage1_use_tc(
...
@@ -1141,15 +1142,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if
BLOCK_DPE
>
0
:
#
if BLOCK_DPE > 0:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
#
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe
=
offs_dpe
<
Lk
#
mask_dpe = offs_dpe < Lk
off_qpe
=
(
#
off_qpe = (
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
#
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
#
)
qpe
=
tl
.
load
(
#
qpe = tl.load(
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
#
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
#
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
split_kv_start
=
kv_len_per_split
*
split_kv_id
split_kv_start
=
kv_len_per_split
*
split_kv_id
...
@@ -1179,11 +1180,21 @@ def _decode_v2_kernel_stage1_use_tc(
...
@@ -1179,11 +1180,21 @@ def _decode_v2_kernel_stage1_use_tc(
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
other
=
0.0
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
other
=
0.0
)
qk
+=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
qk
+=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
mask_dpe
=
offs_dpe
<
Lk
off_qpe
=
(
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
)
offs_buf_kpe
=
(
offs_buf_kpe
=
(
kv_loc
[
None
,
:]
*
stride_buf_kbs
kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
+
offs_dpe
[:,
None
]
)
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
)
kpe
=
tl
.
load
(
kpe
=
tl
.
load
(
K_Buffer
+
offs_buf_kpe
,
K_Buffer
+
offs_buf_kpe
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
...
@@ -1335,7 +1346,6 @@ def _decode_v2_stage1_use_tc(
...
@@ -1335,7 +1346,6 @@ def _decode_v2_stage1_use_tc(
# @triton.autotune(
# @triton.autotune(
# configs=[
# configs=[
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1),
...
@@ -1500,7 +1510,26 @@ def decode_attention_fwd(
...
@@ -1500,7 +1510,26 @@ def decode_attention_fwd(
):
):
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
b_start_loc
=
torch
.
arange
(
0
,
req_to_token
.
shape
[
0
]
*
req_to_token
.
shape
[
1
],
req_to_token
.
shape
[
0
]
*
req_to_token
.
shape
[
1
]
//
q
.
shape
[
0
],
device
=
"cuda"
).
to
(
torch
.
int32
)
b_start_loc
=
torch
.
arange
(
0
,
req_to_token
.
shape
[
0
]
*
req_to_token
.
shape
[
1
]
*
page_size
,
req_to_token
.
shape
[
0
]
*
req_to_token
.
shape
[
1
]
*
page_size
//
q
.
shape
[
0
],
device
=
"cuda"
).
to
(
torch
.
int32
)
current_device
=
torch
.
cuda
.
current_device
()
props
=
torch
.
cuda
.
get_device_properties
(
current_device
)
cu_num
=
props
.
multi_processor_count
num_b
=
min
(
kv_group_num
,
16
)
grid_num
=
(
q
.
shape
[
1
]
+
num_b
-
1
)
//
num_b
*
q
.
shape
[
0
]
L
=
req_to_token
.
shape
[
1
]
*
page_size
if
grid_num
*
num_kv_splits
<
cu_num
:
num_kv_splits
=
(
cu_num
-
1
+
grid_num
)
//
grid_num
#[TODO] The relationship between L and block is to be analyzed
if
L
>=
2048
:
num_kv_splits
=
(
2
*
cu_num
-
1
+
grid_num
)
//
grid_num
if
L
>=
4096
:
num_kv_splits
=
(
4
*
cu_num
-
1
+
grid_num
)
//
grid_num
attn_logits_v2
=
torch
.
empty
(
(
q
.
shape
[
0
],
q
.
shape
[
1
],
num_kv_splits
,
v_buffer
.
shape
[
-
1
]
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
if
kv_group_num
==
1
:
if
kv_group_num
==
1
:
# MHA
# MHA
decode_attention_fwd_normal
(
decode_attention_fwd_normal
(
...
@@ -1551,22 +1580,6 @@ def decode_attention_fwd(
...
@@ -1551,22 +1580,6 @@ def decode_attention_fwd(
page_size,
page_size,
logit_cap,
logit_cap,
)'''
)'''
current_device
=
torch
.
cuda
.
current_device
()
props
=
torch
.
cuda
.
get_device_properties
(
current_device
)
cu_num
=
props
.
multi_processor_count
num_b
=
min
(
kv_group_num
,
16
)
grid_num
=
(
q
.
shape
[
1
]
+
num_b
-
1
)
//
num_b
*
q
.
shape
[
0
]
L
=
req_to_token
.
shape
[
1
]
*
page_size
if
grid_num
*
num_kv_splits
<
cu_num
:
num_kv_splits
=
(
cu_num
-
1
+
grid_num
)
//
grid_num
#[TODO] The relationship between L and block is to be analyzed
if
L
>=
2048
:
num_kv_splits
=
(
2
*
cu_num
-
1
+
grid_num
)
//
grid_num
attn_logits_v2
=
torch
.
empty
(
(
q
.
shape
[
0
],
q
.
shape
[
1
],
num_kv_splits
,
v_buffer
.
shape
[
-
1
]
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
if
best_config
[
'kernel_kind'
]
==
'v1_2stages_tc'
:
if
best_config
[
'kernel_kind'
]
==
'v1_2stages_tc'
:
attn_logits_v1
=
torch
.
empty
(
attn_logits_v1
=
torch
.
empty
(
...
@@ -1589,7 +1602,7 @@ def decode_attention_fwd(
...
@@ -1589,7 +1602,7 @@ def decode_attention_fwd(
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
)
)
elif
best_config
[
'kernel_kind'
]
==
'v2_tc'
:
elif
best_config
[
'kernel_kind'
]
==
'v2_tc'
:
decode_attention_v
2
(
decode_attention_v
1
(
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
v_buffer
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment