Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d8998cbb
Commit
d8998cbb
authored
Sep 26, 2023
by
letaoqin
Browse files
add check code for vectorload
parent
13a0c55d
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
49 additions
and
18 deletions
+49
-18
example/52_flash_atten_bias/grouped_mutihead_attention_bias_forward.cpp
...sh_atten_bias/grouped_mutihead_attention_bias_forward.cpp
+1
-0
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
...ten_bias/run_batched_multihead_attention_bias_forward.inc
+1
-1
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc
...ten_bias/run_grouped_multihead_attention_bias_forward.inc
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp
...n/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp
+5
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+5
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp
...n/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp
+29
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+6
-5
No files found.
example/52_flash_atten_bias/grouped_mutihead_attention_bias_forward.cpp
View file @
d8998cbb
...
...
@@ -119,6 +119,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
...
...
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
View file @
d8998cbb
...
...
@@ -192,7 +192,7 @@ int run(int argc, char* argv[])
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
BatchCount
;
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
+
size_t
(
M
)
*
N
)
*
BatchCount
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
Acc0BiasDataType
)
*
M
*
N
)
*
BatchCount
;
...
...
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc
View file @
d8998cbb
...
...
@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
bool
time_kernel
=
true
;
bool
input_permute
=
false
;
bool
output_permute
=
true
;
...
...
@@ -129,7 +129,7 @@ int run(int argc, char* argv[])
Tensor
<
CDataType
>
c_gs_ms_os_device_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
int
Batch
=
G0
*
G1
;
flop
+=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
Batch
;
flop
+=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
+
size_t
(
M
)
*
N
)
*
Batch
;
num_byte
+=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
Acc0BiasDataType
)
*
M
*
N
)
*
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp
View file @
d8998cbb
...
...
@@ -742,12 +742,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
if
(
arg
.
d0_n_length_stride_
[
1
]
==
1
&&
arg
.
d0_n_length_stride_
[
0
]
%
Acc0BiasTransferSrcScalarPerVector
!=
0
)
if
(
arg
.
d0_n_length_stride_
[
1
]
==
1
)
{
return
false
;
if
(
!
(
arg
.
d0_n_length_stride_
[
0
]
%
Acc0BiasTransferSrcScalarPerVector
==
0
||
Transform
::
matrix_padder
.
PadN
))
return
false
;
}
if
(
arg
.
d0_n_length_stride_
[
1
]
!=
1
&&
Acc0BiasTransferSrcScalarPerVector
!=
1
)
else
if
(
Acc0BiasTransferSrcScalarPerVector
!=
1
)
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
d8998cbb
...
...
@@ -1026,12 +1026,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
if
(
arg
.
d0_n_length_stride_
[
1
]
==
1
&&
arg
.
d0_n_length_stride_
[
0
]
%
Acc0BiasTransferSrcScalarPerVector
!=
0
)
if
(
arg
.
d0_n_length_stride_
[
1
]
==
1
)
{
return
false
;
if
(
!
(
arg
.
d0_n_length_stride_
[
0
]
%
Acc0BiasTransferSrcScalarPerVector
==
0
||
Transform
::
matrix_padder
.
PadN
))
return
false
;
}
if
(
arg
.
d0_n_length_stride_
[
1
]
!=
1
&&
Acc0BiasTransferSrcScalarPerVector
!=
1
)
else
if
(
Acc0BiasTransferSrcScalarPerVector
!=
1
)
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp
View file @
d8998cbb
...
...
@@ -180,6 +180,7 @@ template <index_t NumDimG,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
Acc0BiasTransferSrcScalarPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
...
@@ -429,7 +430,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
4
,
Acc0BiasTransferSrcScalarPerVector
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
...
...
@@ -493,6 +494,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
// for gridwise gemm check
C1GridDesc_M_N
c1_grid_desc_m_n_
;
// raw data
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride_
;
};
// Argument
...
...
@@ -625,6 +629,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
BlockStart
,
BlockEnd
});
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride
;
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_strides
[
NumDimG
+
NumDimM
]);
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
b0_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
...
...
@@ -638,7 +646,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
problem_desc
.
b1_gs_os_ns_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
{
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_grid_desc_m_n
});
c_grid_desc_m_n
,
d0_n_length_stride
});
}
}
...
...
@@ -774,6 +783,24 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
return
false
;
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
if
(
device_arg
.
d0_n_length_stride_
[
1
]
==
1
)
{
if
(
!
(
device_arg
.
d0_n_length_stride_
[
0
]
%
Acc0BiasTransferSrcScalarPerVector
==
0
||
Transform
::
matrix_padder
.
PadN
))
{
return
false
;
}
}
else
if
(
Acc0BiasTransferSrcScalarPerVector
!=
1
)
{
return
false
;
}
}
// Check if having main loop
const
auto
K
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
d8998cbb
...
...
@@ -1102,13 +1102,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
if
(
device_arg
.
d0_n_length_stride_
[
1
]
==
1
&&
device_arg
.
d0_n_length_stride_
[
0
]
%
Acc0BiasTransferSrcScalarPerVector
!=
0
)
if
(
device_arg
.
d0_n_length_stride_
[
1
]
==
1
)
{
return
false
;
if
(
!
(
device_arg
.
d0_n_length_stride_
[
0
]
%
Acc0BiasTransferSrcScalarPerVector
==
0
||
Transform
::
matrix_padder
.
PadN
))
return
false
;
}
if
(
device_arg
.
d0_n_length_stride_
[
1
]
!=
1
&&
Acc0BiasTransferSrcScalarPerVector
!=
1
)
else
if
(
Acc0BiasTransferSrcScalarPerVector
!=
1
)
{
return
false
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment