Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9679ba63
Commit
9679ba63
authored
Aug 08, 2023
by
letaoqin
Browse files
add verify d0 vector load
parent
5d6bfabb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
2 deletions
+30
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
.../device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
+26
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
...n/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
+4
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
View file @
9679ba63
...
@@ -358,6 +358,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -358,6 +358,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
B1Spec
,
B1Spec
,
CSpec
>
;
CSpec
>
;
using
RawTransform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
Sequence
<
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpecialization
::
Default
,
ASpec
,
BSpec
,
B1Spec
,
CSpec
>
;
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
{
...
@@ -563,8 +572,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -563,8 +572,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BBlockTransferDstScalarPerVector_BK1
,
true
,
true
,
Acc0BiasTransferSrcScalarPerVector
,
BBlockLdsExtraN
,
BBlockLdsExtraN
,
Acc0BiasTransferSrcScalarPerVector
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcAccessOrder
,
...
@@ -639,6 +648,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -639,6 +648,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// for gridwise gemm check
// for gridwise gemm check
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
// raw data
int
raw_d0_n_
;
};
};
// Argument
// Argument
...
@@ -820,6 +832,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -820,6 +832,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
z_random_matrix_offset
=
z_random_matrix_offset
=
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
auto
raw_d0_m_n
=
NumD0Tensor
==
0
?
RawTransform
::
MakeCGridDescriptor_M_N
({},
{})
:
RawTransform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
[
0
],
problem_desc
.
acc0_biases_gs_ms_ns_strides
[
0
]);
group_device_args_
.
push_back
(
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
b0_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
problem_desc
.
b0_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
...
@@ -833,7 +850,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -833,7 +850,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
problem_desc
.
b1_gs_os_ns_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
problem_desc
.
b1_gs_os_ns_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
{
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
{
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_grid_desc_m_n
});
c_grid_desc_m_n
,
NumD0Tensor
==
0
?
0
:
raw_d0_m_n
.
GetLength
(
I1
)});
}
}
is_dropout_
=
p_dropout
>
0.0
;
//
is_dropout_
=
p_dropout
>
0.0
;
//
...
@@ -1048,6 +1066,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -1048,6 +1066,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const
index_t
c_gemm1n
=
device_arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
);
const
index_t
c_gemm1n
=
device_arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
if
(
device_arg
.
raw_d0_n_
%
Acc0BiasTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
if
(
!
(
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
{
{
return
false
;
return
false
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
View file @
9679ba63
...
@@ -95,6 +95,10 @@ template <typename FloatAB,
...
@@ -95,6 +95,10 @@ template <typename FloatAB,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
struct
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{
{
static_assert
(
D0BlockTransferSrcScalarPerVector
==
1
||
D0BlockTransferSrcScalarPerVector
==
2
||
D0BlockTransferSrcScalarPerVector
==
4
,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4"
);
using
DDataType
=
FloatAB
;
using
DDataType
=
FloatAB
;
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
"Non-default loop scheduler is currently not supported"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment