Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1e5c712b
Commit
1e5c712b
authored
Apr 29, 2022
by
wangshaojie6
Browse files
add template to distinguish the instance that need lds padding for wrw
parent
93871ca1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
142 additions
and
129 deletions
+142
-129
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+6
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+136
-127
No files found.
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
1e5c712b
...
...
@@ -244,7 +244,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
true
,
true
>
;
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
...
...
@@ -285,7 +287,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
true
,
true
>
;
// Argument
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
1e5c712b
...
...
@@ -11,9 +11,6 @@
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#define A_BLOCK_BANK_CONFLICT_FREE_WRW 1
#define B_BLOCK_BANK_CONFLICT_FREE_WRW 1
namespace
ck
{
template
<
typename
GridwiseGemm
,
...
...
@@ -112,7 +109,9 @@ template <index_t BlockSize,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
bool
ABlockLdsExtraM1Wrw
=
false
,
bool
BBlockLdsExtraN1Wrw
=
false
>
struct
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -126,21 +125,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// Bytes per 32 lds bank: 32 * 4 bytes
static
constexpr
auto
BankLength
=
Number
<
128
>
{};
static
constexpr
auto
ElePerBank
=
Number
<
BankLength
/
sizeof
(
FloatAB
)
>
{};
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
// M1 & N1
static
constexpr
auto
ElePerBank
=
Number
<
BankLength
/
sizeof
(
FloatAB
)
>
{};
// M1 & M0
static
constexpr
auto
M1PerBlock
=
Number
<
ElePerBank
/
K1Value
>
{};
static
constexpr
auto
N1PerBlock
=
Number
<
ElePerBank
/
K1Value
>
{};
// M0 & N0
static
constexpr
auto
M0PerBlock
=
Number
<
MPerBlock
/
M1PerBlock
>
{};
static
constexpr
auto
N0PerBlock
=
Number
<
NPerBlock
/
M1PerBlock
>
{}
;
static
constexpr
auto
M1Padding
=
I4
;
// M1 padding num
static
constexpr
auto
M1Padding
=
Number
<
4
>
{};
static
constexpr
auto
N1Padding
=
M1Padding
;
// N1 & N0
static
constexpr
auto
N1PerBlock
=
Number
<
ElePerBank
/
K1Value
>
{};
static
constexpr
auto
N0PerBlock
=
Number
<
NPerBlock
/
M1PerBlock
>
{};
static
constexpr
auto
N1Padding
=
I4
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
...
...
@@ -150,30 +148,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr
auto
a_block_desc_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
#if A_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr
auto
a_block_desc_k0_m0_m1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
M0PerBlock
>
{}
*
(
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
),
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
,
K1
,
I1
));
constexpr
auto
a_block_desc_k0_m_k1_tmp
=
transform_tensor_descriptor
(
a_block_desc_k0_m0_m1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
a_block_desc_k0_m_k1_tmp
;
#else
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
#endif
if
constexpr
(
ABlockLdsExtraM1Wrw
)
{
constexpr
auto
a_block_desc_k0_m0_m1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
M0PerBlock
>
{}
*
(
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
),
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
,
K1
,
I1
));
constexpr
auto
a_block_desc_k0_m_k1_tmp
=
transform_tensor_descriptor
(
a_block_desc_k0_m0_m1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
a_block_desc_k0_m_k1_tmp
;
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
{
...
...
@@ -193,39 +194,42 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr
auto
a_block_desc_b_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
#if A_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr
auto
a_block_desc_b_k0_m0_m1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
M0PerBlock
>
{}
*
(
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
),
Number
<
M0PerBlock
>
{}
*
(
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
),
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
,
K1
,
I1
));
constexpr
auto
a_block_desc_b_k0_m_k1_tmp
=
transform_tensor_descriptor
(
a_block_desc_b_k0_m0_m1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
1
>
{}),
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod_for_wrw
(
make_tuple
(
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
return
a_block_desc_b_k0_m_k1_tmp
;
#else
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
MPerBlock
+
1
>
{}
*
K1
,
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
#endif
if
constexpr
(
ABlockLdsExtraM1Wrw
)
{
constexpr
auto
a_block_desc_b_k0_m0_m1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
M0PerBlock
>
{}
*
(
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
),
Number
<
M0PerBlock
>
{}
*
(
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
),
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
,
K1
,
I1
));
constexpr
auto
a_block_desc_b_k0_m_k1_tmp
=
transform_tensor_descriptor
(
a_block_desc_b_k0_m0_m1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
1
>
{}),
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod_for_wrw
(
make_tuple
(
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
return
a_block_desc_b_k0_m_k1_tmp
;
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
MPerBlock
+
1
>
{}
*
K1
,
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
{
...
...
@@ -246,31 +250,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
#if B_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr
auto
b_block_desc_k0_n0_n1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
N0PerBlock
>
{}
*
(
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
),
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
,
K1
,
I1
));
constexpr
auto
b_block_desc_k0_n_k1_tmp
=
transform_tensor_descriptor
(
b_block_desc_k0_n0_n1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
b_block_desc_k0_n_k1_tmp
;
#else
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
#endif
if
constexpr
(
BBlockLdsExtraN1Wrw
)
{
constexpr
auto
b_block_desc_k0_n0_n1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
N0PerBlock
>
{}
*
(
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
),
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
,
K1
,
I1
));
constexpr
auto
b_block_desc_k0_n_k1_tmp
=
transform_tensor_descriptor
(
b_block_desc_k0_n0_n1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
b_block_desc_k0_n_k1_tmp
;
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
{
...
...
@@ -290,39 +296,42 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr
auto
b_block_desc_b_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
#if B_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr
auto
b_block_desc_b_k0_n0_n1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
N0PerBlock
>
{}
*
(
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
),
Number
<
N0PerBlock
>
{}
*
(
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
),
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
,
K1
,
I1
));
constexpr
auto
b_block_desc_b_k0_n_k1_tmp
=
transform_tensor_descriptor
(
b_block_desc_b_k0_n0_n1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
1
>
{}),
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod_for_wrw
(
make_tuple
(
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
return
b_block_desc_b_k0_n_k1_tmp
;
#else
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
NPerBlock
+
1
>
{}
*
K1
,
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
#endif
if
constexpr
(
BBlockLdsExtraN1Wrw
)
{
constexpr
auto
b_block_desc_b_k0_n0_n1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
N0PerBlock
>
{}
*
(
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
),
Number
<
N0PerBlock
>
{}
*
(
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
),
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
,
K1
,
I1
));
constexpr
auto
b_block_desc_b_k0_n_k1_tmp
=
transform_tensor_descriptor
(
b_block_desc_b_k0_n0_n1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
1
>
{}),
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod_for_wrw
(
make_tuple
(
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
return
b_block_desc_b_k0_n_k1_tmp
;
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
NPerBlock
+
1
>
{}
*
K1
,
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment