Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f00dab9f
Commit
f00dab9f
authored
Mar 06, 2023
by
aska-0096
Browse files
conv A-skip lds ported
parent
a38ce024
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
161 additions
and
119 deletions
+161
-119
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
+5
-5
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
...ple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
+5
-5
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+66
-36
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+85
-73
No files found.
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
View file @
f00dab9f
...
@@ -82,13 +82,13 @@ using DeviceOpInstance =
...
@@ -82,13 +82,13 @@ using DeviceOpInstance =
GemmSpec
,
GemmSpec
,
256
,
256
,
128
,
128
,
256
,
128
,
8
,
32
,
8
,
8
,
16
,
16
,
16
,
16
,
4
,
1
,
4
,
8
,
S
<
4
,
64
,
1
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
@@ -105,7 +105,7 @@ using DeviceOpInstance =
...
@@ -105,7 +105,7 @@ using DeviceOpInstance =
true
,
true
,
1
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
1
,
128
,
1
,
2
>
,
8
>
;
8
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
...
...
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
View file @
f00dab9f
...
@@ -53,13 +53,13 @@ using DeviceConvFwdInstance =
...
@@ -53,13 +53,13 @@ using DeviceConvFwdInstance =
GemmSpec
,
// GemmSpecialization
GemmSpec
,
// GemmSpecialization
256
,
// BlockSize
256
,
// BlockSize
128
,
// MPerBlock
128
,
// MPerBlock
256
,
// NPerBlock
128
,
// NPerBlock
4
,
// K
0
PerBlock
32
,
// KPerBlock
8
,
// K1
8
,
// K1
16
,
// MPerWMMA
16
,
// MPerWMMA
16
,
// NPerWMMA
16
,
// NPerWMMA
4
,
// MRepeat
1
,
// MRepeat
4
,
// NRepeat
8
,
// NRepeat
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
...
@@ -76,7 +76,7 @@ using DeviceConvFwdInstance =
...
@@ -76,7 +76,7 @@ using DeviceConvFwdInstance =
true
,
// BBlockLdsExtraN
true
,
// BBlockLdsExtraN
1
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
1
,
128
,
1
,
2
>
,
8
>
;
8
>
;
template
<
ck
::
index_t
NDimSpatial
>
template
<
ck
::
index_t
NDimSpatial
>
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
View file @
f00dab9f
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -38,10 +40,10 @@ template <typename ALayout,
...
@@ -38,10 +40,10 @@ template <typename ALayout,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K
0
PerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
K1
,
ck
::
index_t
MPerW
MMA
,
ck
::
index_t
MPerW
mma
,
ck
::
index_t
NPerW
MMA
,
ck
::
index_t
NPerW
mma
,
ck
::
index_t
MRepeat
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
...
@@ -83,19 +85,35 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -83,19 +85,35 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
// K1 = Max Vector Access Pixels
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
{
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
assert
(
K
%
K1
==
0
)
;
static
constexpr
auto
WmmaK
=
16
;
const
index_t
K0
=
K
/
K1
;
static
constexpr
auto
AEnableLds
=
NWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds
=
MWaves
==
1
?
false
:
true
;
// Force enable LDS if uncommented following
// AEnableLds = true;
// BEnableLds = true;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
// Describe how data read from Global memory
static
auto
MakeAGridDescriptor
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
}
#ifdef ENABLE_COLMAJOR
#ifdef ENABLE_COLMAJOR
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
...
@@ -105,25 +123,35 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -105,25 +123,35 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
#endif
#endif
}();
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
if
constexpr
(
AEnableLds
)
{
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_
right_pad
_transform
(
M
,
PadM
)),
make_
pass_through
_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
else
else
{
{
constexpr
auto
A_KRow
=
WmmaK
/
K1
;
const
auto
A_KWmma
=
K
/
WmmaK
;
const
auto
M0
=
M
/
MPerBlock
;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
A_KWmma
,
Number
<
A_KRow
>
{},
K1Number
)),
make_pass_through_transform
(
M
)),
make_unmerge_transform
(
make_tuple
(
M0
*
MRepeat
,
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
3
,
5
>
{},
Sequence
<
1
,
2
,
4
>
{}));
}
}
}
}
...
@@ -216,7 +244,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -216,7 +244,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
// Gridwise descriptor, mapping to whole given provblem.
// Gridwise descriptor, mapping to whole given provblem.
using
AGridDesc
_K0_M_K1
=
decltype
(
MakeAGridDescriptor
_K0_M_K1
(
1
,
1
,
1
));
using
AGridDesc
=
decltype
(
MakeAGridDescriptor
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
...
@@ -231,7 +259,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -231,7 +259,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
DsDataType
,
DsDataType
,
EDataType
,
EDataType
,
// InMemory Data Descriptor
// InMemory Data Descriptor
AGridDesc
_K0_M_K1
,
AGridDesc
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
DsGridDesc_M_N
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
EGridDesc_M_N
,
...
@@ -243,9 +271,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -243,9 +271,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
// Tiling Family
// Tiling Family
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
K
0
PerBlock
,
KPerBlock
,
MPerW
MMA
,
MPerW
mma
,
NPerW
MMA
,
NPerW
mma
,
K1
,
K1
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
...
@@ -258,6 +286,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -258,6 +286,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds
,
ABlockLdsAddExtraM
,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
...
@@ -266,6 +295,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -266,6 +295,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds
,
BBlockLdsAddExtraN
,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
...
@@ -298,7 +328,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -298,7 +328,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc
_k0_m_k1_
{},
a_grid_desc
{},
b_grid_desc_k0_n_k1_
{},
b_grid_desc_k0_n_k1_
{},
ds_grid_desc_m_n_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{},
e_grid_desc_m_n_
{},
...
@@ -311,7 +341,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -311,7 +341,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
cde_element_op_
{
cde_element_op
}
{
{
a_grid_desc
_k0_m_k1_
=
DeviceOp
::
MakeAGridDescriptor
_K0_M_K1
(
M
,
K
,
StrideA
);
a_grid_desc
=
DeviceOp
::
MakeAGridDescriptor
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
b_grid_desc_k0_n_k1_
=
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
...
@@ -328,7 +358,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -328,7 +358,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
block_2_ctile_map_
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
);
if
(
GridwiseOp
::
CheckValidity
(
a_grid_desc
_k0_m_k1_
,
if
(
GridwiseOp
::
CheckValidity
(
a_grid_desc
,
b_grid_desc_k0_n_k1_
,
b_grid_desc_k0_n_k1_
,
ds_grid_desc_m_n_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
e_grid_desc_m_n_
,
...
@@ -351,7 +381,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -351,7 +381,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
EDataType
*
p_e_grid_
;
EDataType
*
p_e_grid_
;
// Tensor Descriptors
// Tensor Descriptors
AGridDesc
_K0_M_K1
a_grid_desc
_k0_m_k1_
;
AGridDesc
a_grid_desc
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
...
@@ -382,9 +412,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -382,9 +412,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
{
{
#if 0
#if 0
{
{
std::cout << "arg.a_grid_desc
_k0_m_k1_
{" << arg.a_grid_desc
_k0_m_k1_
.GetLength(I0)
std::cout << "arg.a_grid_desc{" << arg.a_grid_desc.GetLength(I0)
<< ", " << arg.a_grid_desc
_k0_m_k1_
.GetLength(I1) << ", "
<< ", " << arg.a_grid_desc.GetLength(I1) << ", "
<< arg.a_grid_desc
_k0_m_k1_
.GetLength(I2) << "}" << std::endl;
<< arg.a_grid_desc.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
...
@@ -396,7 +426,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -396,7 +426,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
#endif
#endif
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
_k0_m_k1_
,
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
...
@@ -410,7 +440,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -410,7 +440,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
const
auto
K
=
const
auto
K
=
arg
.
a_grid_desc
_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
_k0_m_k1_
.
GetLength
(
I2
);
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I2
);
float
ave_time
=
0
;
float
ave_time
=
0
;
...
@@ -422,7 +452,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -422,7 +452,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BDataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
EDataType
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc
_K0_M_K1
>
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
remove_reference_t
<
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
...
@@ -444,7 +474,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -444,7 +474,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
a_grid_desc
_k0_m_k1_
,
arg
.
a_grid_desc
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -461,7 +491,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -461,7 +491,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BDataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
EDataType
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc
_K0_M_K1
>
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
remove_reference_t
<
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
...
@@ -483,7 +513,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -483,7 +513,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
a_grid_desc
_k0_m_k1_
,
arg
.
a_grid_desc
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -524,7 +554,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -524,7 +554,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
return
false
;
return
false
;
}
}
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
_k0_m_k1_
,
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
...
@@ -630,10 +660,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -630,10 +660,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K
0
PerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
K1
<<
", "
<<
K1
<<
", "
<<
MPerW
MMA
<<
", "
<<
MPerW
mma
<<
", "
<<
NPerW
MMA
<<
", "
<<
NPerW
mma
<<
", "
<<
MRepeat
<<
", "
<<
MRepeat
<<
", "
<<
NRepeat
<<
NRepeat
<<
">"
<<
">"
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
f00dab9f
...
@@ -112,10 +112,10 @@ template <index_t NDimSpatial,
...
@@ -112,10 +112,10 @@ template <index_t NDimSpatial,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K
0
PerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
K1
,
ck
::
index_t
MPerW
MMA
,
ck
::
index_t
MPerW
mma
,
ck
::
index_t
NPerW
MMA
,
ck
::
index_t
NPerW
mma
,
ck
::
index_t
MRepeat
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
...
@@ -157,11 +157,25 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -157,11 +157,25 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
KPerBlock
=
K0PerBlock
*
K1
;
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
16
;
static
constexpr
auto
AEnableLds
=
NWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds
=
MWaves
==
1
?
false
:
true
;
// Force enable LDS if uncommented following
// AEnableLds = true;
// BEnableLds = true;
static
constexpr
auto
conv_to_gemm_transformer
=
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
...
@@ -171,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -171,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
template
<
typename
ALay
>
template
<
typename
ALay
>
static
auto
static
auto
MakeAGridDescriptor
_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
MakeAGridDescriptor
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
...
@@ -196,13 +210,42 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -196,13 +210,42 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const
auto
in_gemmm_gemmk_desc
=
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
const
auto
M
=
in_gemmm_gemmk_desc
.
GetLength
(
I0
);
const
auto
K
=
in_gemmm_gemmk_desc
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
return
in_gemmm_gemmk_desc
;
if
constexpr
(
AEnableLds
)
{
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
in_gemmm_gemmk_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
constexpr
auto
A_KRow
=
WmmaK
/
K1
;
const
auto
A_KWmma
=
K
/
WmmaK
;
const
auto
M0
=
M
/
MPerBlock
;
return
transform_tensor_descriptor
(
in_gemmm_gemmk_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
A_KWmma
,
Number
<
A_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_tuple
(
M0
*
MRepeat
,
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
5
>
{},
Sequence
<
1
,
2
,
4
>
{}));
}
}
}
template
<
typename
BLay
>
template
<
typename
BLay
>
static
auto
static
auto
MakeBGridDescriptor_
N_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
MakeBGridDescriptor_
BK0_N_BK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
{
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
const
auto
wei_gemmnraw_gemmkraw_desc
=
...
@@ -211,8 +254,18 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -211,8 +254,18 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const
auto
wei_gemmn_gemmk_desc
=
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
const
auto
N
=
wei_gemmn_gemmk_desc
.
GetLength
(
I0
);
const
auto
K
=
wei_gemmn_gemmk_desc
.
GetLength
(
I1
);
const
auto
BK1
=
K1
;
const
auto
BK0
=
K
/
BK1
;
return
wei_gemmn_gemmk_desc
;
return
transform_tensor_descriptor
(
wei_gemmn_gemmk_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
template
<
typename
ELay
>
template
<
typename
ELay
>
...
@@ -245,50 +298,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -245,50 +298,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
}
// desc for problem definition
// desc for problem definition
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
({},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
// A desc for source in blockwise copy
using
AGridDesc
=
decltype
(
DeviceOp
::
MakeAGridDescriptor
<
ALayout
>
({},
{},
{},
{},
{},
{},
{},
{},
{},
{}));
template
<
typename
AGridDesc_M_K
>
using
BGridDesc_BK0_N_BK1
=
decltype
(
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
({},
{}));
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK1
=
K1
;
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// B desc for source in blockwise copy
template
<
typename
BGridDesc_N_K
>
__host__
__device__
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK1
=
K1
;
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}));
// GridwiseOp
// GridwiseOp
using
GridwiseOp
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
<
using
GridwiseOp
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
<
...
@@ -300,7 +314,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -300,7 +314,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
DsDataType
,
DsDataType
,
EDataType
,
EDataType
,
// InMemory Data Descriptor
// InMemory Data Descriptor
AGridDesc
_AK0_M_AK1
,
AGridDesc
,
BGridDesc_BK0_N_BK1
,
BGridDesc_BK0_N_BK1
,
DsGridDesc_M_N
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
EGridDesc_M_N
,
...
@@ -312,9 +326,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -312,9 +326,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// Tiling Family
// Tiling Family
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
K
0
PerBlock
,
KPerBlock
,
MPerW
MMA
,
MPerW
mma
,
NPerW
MMA
,
NPerW
mma
,
K1
,
K1
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
...
@@ -327,6 +341,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -327,6 +341,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
false
,
AEnableLds
,
ABlockLdsExtraM
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
...
@@ -335,6 +350,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -335,6 +350,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
false
,
BEnableLds
,
BBlockLdsExtraN
,
BBlockLdsExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
...
@@ -375,7 +391,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -375,7 +391,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
p_ds_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_g_n_c_wis_lengths
,
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
a_grid_desc
{
DeviceOp
::
MakeAGridDescriptor
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
b_g_k_c_xs_strides
,
...
@@ -385,13 +404,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -385,13 +404,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
)},
input_right_pads
)},
b_grid_desc_
n_k
_
{
DeviceOp
::
MakeBGridDescriptor_
N_K
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_grid_desc_
bk0_n_bk1
_
{
DeviceOp
::
MakeBGridDescriptor_
BK0_N_BK1
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
b_g_k_c_xs_strides
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
)},
block_2_etile_map_
{
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
)},
...
@@ -443,8 +457,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -443,8 +457,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
void
Print
()
const
void
Print
()
const
{
{
std
::
cout
<<
"A[M, K]: "
<<
a_grid_desc
_m_k_
<<
std
::
endl
;
std
::
cout
<<
"A[M, K]: "
<<
a_grid_desc
<<
std
::
endl
;
std
::
cout
<<
"B[N, K]: "
<<
b_grid_desc_
n_k
_
<<
std
::
endl
;
std
::
cout
<<
"B[N, K]: "
<<
b_grid_desc_
bk0_n_bk1
_
<<
std
::
endl
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
std
::
cout
<<
"Ds[M, N]: "
<<
ds_grid_desc_m_n_
[
i
]
<<
std
::
endl
;
});
[
&
](
auto
i
)
{
std
::
cout
<<
"Ds[M, N]: "
<<
ds_grid_desc_m_n_
[
i
]
<<
std
::
endl
;
});
std
::
cout
<<
"E[M, N]: "
<<
e_grid_desc_m_n_
<<
std
::
endl
;
std
::
cout
<<
"E[M, N]: "
<<
e_grid_desc_m_n_
<<
std
::
endl
;
...
@@ -459,13 +473,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -459,13 +473,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
index_t
num_group_
;
index_t
num_group_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
// tensor descriptors for block/thread-wise copy
AGridDesc
_AK0_M_AK1
a_grid_desc
_ak0_m_ak1_
;
AGridDesc
a_grid_desc
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
...
@@ -514,7 +526,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -514,7 +526,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
)
*
arg
.
num_group_
;
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
)
*
arg
.
num_group_
;
const
auto
K
=
const
auto
K
=
arg
.
a_grid_desc
_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
_ak0_m_ak1_
.
GetLength
(
I2
);
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
...
@@ -528,7 +540,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -528,7 +540,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc
_AK0_M_AK1
,
DeviceOp
::
AGridDesc
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -549,7 +561,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -549,7 +561,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
arg
.
a_grid_desc
_ak0_m_ak1_
,
arg
.
a_grid_desc
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
...
@@ -719,7 +731,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -719,7 +731,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
}
// check Gridwise GEMM
// check Gridwise GEMM
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
_ak0_m_ak1_
,
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment