Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6f49c1ed
Commit
6f49c1ed
authored
Apr 15, 2025
by
zhuwenwen
Browse files
back to mla v2
parent
cf28e5a4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
91 deletions
+37
-91
examples/mla/test_triton_decode_attention.py
examples/mla/test_triton_decode_attention.py
+1
-2
examples/mla/triton_decode_attention.py
examples/mla/triton_decode_attention.py
+14
-40
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+22
-49
No files found.
examples/mla/test_triton_decode_attention.py
View file @
6f49c1ed
...
@@ -213,4 +213,3 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
...
@@ -213,4 +213,3 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
with
open
(
file_name
,
'w'
)
as
file
:
with
open
(
file_name
,
'w'
)
as
file
:
json
.
dump
(
config_info
,
file
,
indent
=
1
)
json
.
dump
(
config_info
,
file
,
indent
=
1
)
#**************save config**************#
#**************save config**************#
examples/mla/triton_decode_attention.py
View file @
6f49c1ed
...
@@ -37,10 +37,7 @@ import triton.language as tl
...
@@ -37,10 +37,7 @@ import triton.language as tl
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
is_hip_
=
current_platform
.
is_rocm
()
is_hip_
=
current_platform
.
is_rocm
()
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"1"
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"0"
os
.
environ
[
"TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"
]
=
"0"
os
.
environ
[
"TRITON_DEFAULT_ENABLE_NUM_VGPRS512"
]
=
"1"
os
.
environ
[
"MLIR_ENABLE_DUMP"
]
=
"0"
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -760,12 +757,6 @@ def decode_attention_v1(
...
@@ -760,12 +757,6 @@ def decode_attention_v1(
triton
.
Config
({
"BLOCK_N"
:
32
,
"BLOCK_DIM"
:
64
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
,
"BLOCK_DIM"
:
64
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
,
"BLOCK_DIM"
:
32
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
,
"BLOCK_DIM"
:
32
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
,
"BLOCK_DIM"
:
32
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
,
"BLOCK_DIM"
:
32
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
16
,
"BLOCK_DIM"
:
64
},
num_warps
=
2
,
num_stages
=
2
),
triton
.
Config
({
"BLOCK_N"
:
16
,
"BLOCK_DIM"
:
64
},
num_warps
=
4
,
num_stages
=
2
),
triton
.
Config
({
"BLOCK_N"
:
32
,
"BLOCK_DIM"
:
64
},
num_warps
=
2
,
num_stages
=
2
),
triton
.
Config
({
"BLOCK_N"
:
32
,
"BLOCK_DIM"
:
64
},
num_warps
=
4
,
num_stages
=
2
),
triton
.
Config
({
"BLOCK_N"
:
64
,
"BLOCK_DIM"
:
32
},
num_warps
=
2
,
num_stages
=
2
),
triton
.
Config
({
"BLOCK_N"
:
64
,
"BLOCK_DIM"
:
32
},
num_warps
=
4
,
num_stages
=
2
),
triton
.
Config
({
"BLOCK_N"
:
128
,
"BLOCK_DIM"
:
32
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
,
"BLOCK_DIM"
:
32
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
,
"BLOCK_DIM"
:
32
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
,
"BLOCK_DIM"
:
32
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
,
"BLOCK_DIM"
:
32
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
,
"BLOCK_DIM"
:
32
},
num_warps
=
2
,
num_stages
=
1
),
...
@@ -831,15 +822,15 @@ def _decode_v2_kernel_stage1_use_tc(
...
@@ -831,15 +822,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
#
if BLOCK_DPE > 0:
if
BLOCK_DPE
>
0
:
#
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
#
mask_dpe = offs_dpe < Lk
mask_dpe
=
offs_dpe
<
Lk
#
off_qpe = (
off_qpe
=
(
#
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
#
)
)
#
qpe = tl.load(
qpe
=
tl
.
load
(
#
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
#
)
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
split_kv_start
=
kv_len_per_split
*
split_kv_id
split_kv_start
=
kv_len_per_split
*
split_kv_id
...
@@ -868,19 +859,11 @@ def _decode_v2_kernel_stage1_use_tc(
...
@@ -868,19 +859,11 @@ def _decode_v2_kernel_stage1_use_tc(
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
other
=
0.0
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
other
=
0.0
)
qk
+=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
qk
+=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
mask_dpe
=
offs_dpe
<
Lk
off_qpe
=
(
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
)
offs_buf_kpe
=
(
offs_buf_kpe
=
(
kv_loc
[
None
,
:]
*
stride_buf_kbs
kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
+
offs_dpe
[:,
None
]
)
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
)
kpe
=
tl
.
load
(
kpe
=
tl
.
load
(
K_Buffer
+
offs_buf_kpe
,
K_Buffer
+
offs_buf_kpe
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
...
@@ -1018,6 +1001,7 @@ def _decode_v2_stage1_use_tc(
...
@@ -1018,6 +1001,7 @@ def _decode_v2_stage1_use_tc(
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({},
num_warps
=
1
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
1
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
1
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
4
,
num_stages
=
1
),
...
@@ -1160,20 +1144,11 @@ def decode_attentionv2_fwd(
...
@@ -1160,20 +1144,11 @@ def decode_attentionv2_fwd(
):
):
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
current_device
=
torch
.
cuda
.
current_device
()
props
=
torch
.
cuda
.
get_device_properties
(
current_device
)
cu_num
=
props
.
multi_processor_count
num_b
=
min
(
kv_group_num
,
16
)
num_b
=
min
(
kv_group_num
,
16
)
grid_num
=
(
q
.
shape
[
1
]
+
num_b
-
1
)
//
num_b
*
q
.
shape
[
0
]
grid_num
=
(
q
.
shape
[
1
]
+
num_b
-
1
)
//
num_b
*
q
.
shape
[
0
]
L
=
req_to_token
.
shape
[
1
]
*
page_size
L
=
req_to_token
.
shape
[
1
]
*
page_size
if
grid_num
*
num_kv_splits
<
cu_num
:
if
grid_num
*
num_kv_splits
<
128
:
num_kv_splits
=
(
cu_num
-
1
+
grid_num
)
//
grid_num
num_kv_splits
=
(
127
+
grid_num
)
//
grid_num
#[TODO] The relationship between L and block is to be analyzed
if
L
>=
2048
:
num_kv_splits
=
(
2
*
cu_num
-
1
+
grid_num
)
//
grid_num
if
L
>=
4096
:
num_kv_splits
=
(
4
*
cu_num
-
1
+
grid_num
)
//
grid_num
attn_logits_v1
=
torch
.
empty
(
attn_logits_v1
=
torch
.
empty
(
(
q
.
shape
[
0
],
q
.
shape
[
1
],
num_kv_splits
,
v_buffer
.
shape
[
-
1
]
+
1
),
(
q
.
shape
[
0
],
q
.
shape
[
1
],
num_kv_splits
,
v_buffer
.
shape
[
-
1
]
+
1
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
...
@@ -1263,4 +1238,3 @@ def decode_attentionv1_fwd(
...
@@ -1263,4 +1238,3 @@ def decode_attentionv1_fwd(
logit_cap
,
logit_cap
,
)
)
return
v1_tc_stage1_best_config
,
v1_tc_stage2_best_config
return
v1_tc_stage1_best_config
,
v1_tc_stage2_best_config
\ No newline at end of file
vllm/attention/ops/triton_decode_attention.py
View file @
6f49c1ed
...
@@ -39,10 +39,7 @@ from vllm import envs
...
@@ -39,10 +39,7 @@ from vllm import envs
# from ..backends.triton_config import KERNLE_KINDS
# from ..backends.triton_config import KERNLE_KINDS
is_hip_
=
current_platform
.
is_rocm
()
is_hip_
=
current_platform
.
is_rocm
()
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"1"
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"0"
os
.
environ
[
"TRITON_ENABLE_GLOBAL_TO_LOCAL_AND_NUMSTAGE2"
]
=
"0"
os
.
environ
[
"TRITON_DEFAULT_ENABLE_NUM_VGPRS512"
]
=
"1"
os
.
environ
[
"MLIR_ENABLE_DUMP"
]
=
"0"
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -1071,12 +1068,6 @@ def decode_attention_v1(
...
@@ -1071,12 +1068,6 @@ def decode_attention_v1(
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=2),
# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=2),
# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
...
@@ -1142,15 +1133,15 @@ def _decode_v2_kernel_stage1_use_tc(
...
@@ -1142,15 +1133,15 @@ def _decode_v2_kernel_stage1_use_tc(
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
#
if BLOCK_DPE > 0:
if
BLOCK_DPE
>
0
:
#
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
#
mask_dpe = offs_dpe < Lk
mask_dpe
=
offs_dpe
<
Lk
#
off_qpe = (
off_qpe
=
(
#
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
#
)
)
#
qpe = tl.load(
qpe
=
tl
.
load
(
#
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
#
)
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
split_kv_start
=
kv_len_per_split
*
split_kv_id
split_kv_start
=
kv_len_per_split
*
split_kv_id
...
@@ -1180,21 +1171,11 @@ def _decode_v2_kernel_stage1_use_tc(
...
@@ -1180,21 +1171,11 @@ def _decode_v2_kernel_stage1_use_tc(
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
other
=
0.0
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
other
=
0.0
)
qk
+=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
qk
+=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
mask_dpe
=
offs_dpe
<
Lk
off_qpe
=
(
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
)
offs_buf_kpe
=
(
offs_buf_kpe
=
(
kv_loc
[
None
,
:]
*
stride_buf_kbs
kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
+
offs_dpe
[:,
None
]
)
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
)
kpe
=
tl
.
load
(
kpe
=
tl
.
load
(
K_Buffer
+
offs_buf_kpe
,
K_Buffer
+
offs_buf_kpe
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
...
@@ -1346,6 +1327,7 @@ def _decode_v2_stage1_use_tc(
...
@@ -1346,6 +1327,7 @@ def _decode_v2_stage1_use_tc(
# @triton.autotune(
# @triton.autotune(
# configs=[
# configs=[
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1),
...
@@ -1510,26 +1492,7 @@ def decode_attention_fwd(
...
@@ -1510,26 +1492,7 @@ def decode_attention_fwd(
):
):
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
b_start_loc
=
torch
.
arange
(
0
,
req_to_token
.
shape
[
0
]
*
req_to_token
.
shape
[
1
]
*
page_size
,
req_to_token
.
shape
[
0
]
*
req_to_token
.
shape
[
1
]
*
page_size
//
q
.
shape
[
0
],
device
=
"cuda"
).
to
(
torch
.
int32
)
b_start_loc
=
torch
.
arange
(
0
,
req_to_token
.
shape
[
0
]
*
req_to_token
.
shape
[
1
],
req_to_token
.
shape
[
0
]
*
req_to_token
.
shape
[
1
]
//
q
.
shape
[
0
],
device
=
"cuda"
).
to
(
torch
.
int32
)
current_device
=
torch
.
cuda
.
current_device
()
props
=
torch
.
cuda
.
get_device_properties
(
current_device
)
cu_num
=
props
.
multi_processor_count
num_b
=
min
(
kv_group_num
,
16
)
grid_num
=
(
q
.
shape
[
1
]
+
num_b
-
1
)
//
num_b
*
q
.
shape
[
0
]
L
=
req_to_token
.
shape
[
1
]
*
page_size
if
grid_num
*
num_kv_splits
<
cu_num
:
num_kv_splits
=
(
cu_num
-
1
+
grid_num
)
//
grid_num
#[TODO] The relationship between L and block is to be analyzed
if
L
>=
2048
:
num_kv_splits
=
(
2
*
cu_num
-
1
+
grid_num
)
//
grid_num
if
L
>=
4096
:
num_kv_splits
=
(
4
*
cu_num
-
1
+
grid_num
)
//
grid_num
attn_logits_v2
=
torch
.
empty
(
(
q
.
shape
[
0
],
q
.
shape
[
1
],
num_kv_splits
,
v_buffer
.
shape
[
-
1
]
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
if
kv_group_num
==
1
:
if
kv_group_num
==
1
:
# MHA
# MHA
decode_attention_fwd_normal
(
decode_attention_fwd_normal
(
...
@@ -1580,6 +1543,16 @@ def decode_attention_fwd(
...
@@ -1580,6 +1543,16 @@ def decode_attention_fwd(
page_size,
page_size,
logit_cap,
logit_cap,
)'''
)'''
num_b
=
min
(
kv_group_num
,
16
)
grid_num
=
(
q
.
shape
[
1
]
+
num_b
-
1
)
//
num_b
*
q
.
shape
[
0
]
L
=
req_to_token
.
shape
[
1
]
*
page_size
if
grid_num
*
num_kv_splits
<
128
:
num_kv_splits
=
(
127
+
grid_num
)
//
grid_num
attn_logits_v2
=
torch
.
empty
(
(
q
.
shape
[
0
],
q
.
shape
[
1
],
num_kv_splits
,
v_buffer
.
shape
[
-
1
]
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
if
best_config
[
'kernel_kind'
]
==
'v1_2stages_tc'
:
if
best_config
[
'kernel_kind'
]
==
'v1_2stages_tc'
:
attn_logits_v1
=
torch
.
empty
(
attn_logits_v1
=
torch
.
empty
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment