gaoqiong / composable_kernel

Commit 0b997ce4, authored Jul 18, 2022 by Chao Liu (parent 69d323de)

    adding conv multiple D

Showing 4 changed files with 373 additions and 257 deletions.
  example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp                                                      +38  -35
  include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp  +293 -211
  include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp                      +4   -4
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                           +38   -7
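For orientation (this note is not part of the commit): "multiple D" refers to fusing extra read-only tensors D0, D1, ... into the GEMM epilogue, so the device op computes E = cde_op(A*B, D0, D1, ...) in a single kernel. A minimal host-side sketch of that contract in plain C++, with made-up names (the real CK types are templated and run on the GPU):

    #include <cstddef>
    #include <vector>

    // E[m,n] = cde_op(sum_k A[m,k] * B[k,n], D[m,n]). One auxiliary D tensor is
    // shown; the CK op takes a tuple of zero or more Ds.
    template <typename T, typename CDEOp>
    void gemm_multiple_d_reference(const std::vector<T>& a, // M x K, row-major
                                   const std::vector<T>& b, // K x N, row-major
                                   const std::vector<T>& d, // M x N
                                   std::vector<T>& e,       // M x N
                                   std::size_t M, std::size_t N, std::size_t K,
                                   CDEOp cde_op)
    {
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                T acc{};
                for(std::size_t k = 0; k < K; ++k)
                    acc += a[m * K + k] * b[k * N + n];
                e[m * N + n] = cde_op(acc, d[m * N + n]); // fused epilogue
            }
    }

    // Example epilogue (bias-add): cde_op = [](float c, float d) { return c + d; }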
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
@@ -16,7 +16,7 @@ using S = ck::Sequence<Is...>;

 using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
 using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
-using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
+using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert;

 static constexpr auto ConvFwdDefault =
     ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
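The output elementwise op switches from PassThrough to UnaryConvert because the reworked PassThrough (see the last file in this commit) only copies between identical types, while UnaryConvert performs y = type_convert<Y>(x). A standalone analogue of the distinction, with illustrative names (not the CK definitions):

    struct PassThroughLike
    {
        template <typename T>
        void operator()(T& y, const T& x) const { y = x; } // same type in, same type out
    };

    struct UnaryConvertLike
    {
        template <typename Y, typename X>
        void operator()(Y& y, const X& x) const { y = static_cast<Y>(x); } // converts X -> Y
    };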
@@ -48,18 +48,18 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
     2,              // ABlockTransferSrcVectorDim
     8,              // ABlockTransferSrcScalarPerVector
     8,              // ABlockTransferDstScalarPerVector_K1
-    true,           // ABlockLdsAddExtraM
+    true,           // ABlockLdsExtraM
     S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_K0_N_K1
     S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
     S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
     2,              // BBlockTransferSrcVectorDim
     8,              // BBlockTransferSrcScalarPerVector
     8,              // BBlockTransferDstScalarPerVector_K1
-    true,           // BBlockLdsAddExtraN
+    true,           // BBlockLdsExtraN
     7,              // CThreadTransferSrcDstVectorDim
     1>;             // CThreadTransferDstScalarPerVector
 #else
-using CShuffleDataType = float;
+using CShuffleDataType = ck::half_t;

 template <ck::index_t NDimSpatial>
 using DeviceConvNDFwdInstance =
@@ -69,16 +69,17 @@ using DeviceConvNDFwdInstance =
         WeiDataType,        //
         AccDataType,        //
         CShuffleDataType,   //
-        ck::Tuple<>,
+        ck::Tuple<>,        //
         OutDataType,        //
         InElementOp,        // Input Elementwise Operation
         WeiElementOp,       // Weights Elementwise Operation
         OutElementOp,       // Output Elementwise Operation
         ConvFwdDefault,     // ConvForwardSpecialization
+        1,                  //
         256,                // BlockSize
         128,                // MPerBlock
         256,                // NPerBlock
-        4,                  // K0PerBlock
+        32,                 // KPerBlock
         8,                  // K1
         32,                 // MPerXdl
         32,                 // NPerXdl
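Note on the 4 → 32 change: the blocking parameter moves from K0PerBlock to KPerBlock. With K1 = 8, the old K0PerBlock = 4 covered 4 × 8 = 32 elements of the reduction dimension per block, so KPerBlock = 32 describes the same tile; the K0/K1 split is now derived inside the gridwise GEMM. A quick check (illustrative):

    static_assert(4 /* old K0PerBlock */ * 8 /* K1 */ == 32 /* new KPerBlock */,
                  "per-block K coverage is unchanged");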
@@ -90,16 +91,18 @@ using DeviceConvNDFwdInstance =
         2,              // ABlockTransferSrcVectorDim
         8,              // ABlockTransferSrcScalarPerVector
         8,              // ABlockTransferDstScalarPerVector_K1
-        true,           // ABlockLdsAddExtraM
+        1,              // ABlockLdsExtraM
         S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_K0_N_K1
         S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
         S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
         2,              // BBlockTransferSrcVectorDim
         8,              // BBlockTransferSrcScalarPerVector
         8,              // BBlockTransferDstScalarPerVector_K1
-        true,           // BBlockLdsAddExtraN
-        7,              // CThreadTransferSrcDstVectorDim
-        1>;             // CThreadTransferDstScalarPerVector
+        1,              // BBlockLdsExtraN
+        1,
+        1,
+        S<1, 32, 1, 8>,
+        8>;
 #endif

 int main(int argc, char* argv[])
include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp
@@ -16,6 +16,7 @@
 #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/device_utility/device_prop.hpp"
 #include "ck/device_utility/kernel_launch.hpp"
@@ -23,6 +24,73 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+namespace {
+
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatDsPointer,
+          typename FloatE,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation,
+          typename AGridDesc_AK0_M_AK1,
+          typename BGridDesc_BK0_N_BK1,
+          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename Block2ETileMap,
+          bool HasMainKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_gemm_multiple_d_xdl_cshuffle(
+            const FloatAB* __restrict__ p_a_grid,
+            const FloatAB* __restrict__ p_b_grid,
+            FloatDsPointer p_ds_grid,
+            FloatE* __restrict__ p_e_grid,
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CDEElementwiseOperation cde_element_op,
+            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                ds_grid_desc_mblock_mperblock_nblock_nperblock,
+            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                e_grid_desc_mblock_mperblock_nblock_nperblock,
+            const Block2ETileMap block_2_etile_map)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b_grid,
+                                                  p_ds_grid,
+                                                  p_e_grid,
+                                                  p_shared,
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  cde_element_op,
+                                                  a_grid_desc_ak0_m_ak1,
+                                                  b_grid_desc_bk0_n_bk1,
+                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  block_2_etile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_ds_grid;
+    ignore = p_e_grid;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = cde_element_op;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = block_2_etile_map;
+#endif
+}
+
+} // namespace
+
 //
 // @brief Device Convolution operation.
 //
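The wrapper above combines three idioms: optional __launch_bounds__, an architecture guard so the body only compiles for gfx908/gfx90a, and "ignore =" assignments to silence unused-parameter warnings on other targets. A stripped-down sketch of the same pattern with a hypothetical gridwise type (plain (void) casts stand in for CK's ignore):

    template <typename Gridwise, bool HasMainKBlockLoop>
    __global__ void
    #if CK_USE_LAUNCH_BOUNDS
        __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
    #endif
        sketch_kernel(const float* p_in, float* p_out)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
        // LDS is sized by the gridwise GEMM, then threaded through Run().
        __shared__ char p_shared[Gridwise::GetSharedMemoryNumberOfByte()];
        Gridwise::template Run<HasMainKBlockLoop>(p_in, p_out, p_shared);
    #else
        (void)p_in; // CK writes `ignore = p_in;` to the same effect
        (void)p_out;
    #endif
    }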
@@ -39,7 +107,7 @@ namespace device {
 // 3D:
 // out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
 //
-template <ck::index_t NDimSpatial,
+template <index_t NDimSpatial,
           typename ADataType,
           typename BDataType,
           typename AccDataType,
@@ -50,31 +118,35 @@ template <ck::index_t NDimSpatial,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
           ConvolutionForwardSpecialization ConvForwardSpecialization,
-          ck::index_t BlockSize,
-          ck::index_t MPerBlock,
-          ck::index_t NPerBlock,
-          ck::index_t K0PerBlock,
-          ck::index_t K1,
-          ck::index_t MPerXDL,
-          ck::index_t NPerXDL,
-          ck::index_t MXdlPerWave,
-          ck::index_t NXdlPerWave,
-          typename ABlockTransferThreadClusterLengths_K0_M_K1,
+          index_t NumGemmKPrefetchStage,
+          index_t BlockSize,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t K1,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MXdlPerWave,
+          index_t NXdlPerWave,
+          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
-          ck::index_t ABlockTransferSrcVectorDim,
-          ck::index_t ABlockTransferSrcScalarPerVector,
-          ck::index_t ABlockTransferDstScalarPerVector_K1,
-          bool ABlockLdsAddExtraM,
-          typename BBlockTransferThreadClusterLengths_K0_N_K1,
+          index_t ABlockTransferSrcVectorDim,
+          index_t ABlockTransferSrcScalarPerVector,
+          index_t ABlockTransferDstScalarPerVector_AK1,
+          index_t ABlockLdsExtraM,
+          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
-          ck::index_t BBlockTransferSrcVectorDim,
-          ck::index_t BBlockTransferSrcScalarPerVector,
-          ck::index_t BBlockTransferDstScalarPerVector_K1,
-          bool BBlockLdsAddExtraN,
-          ck::index_t CThreadTransferSrcDstVectorDim,
-          ck::index_t CThreadTransferDstScalarPerVector>
+          index_t BBlockTransferSrcVectorDim,
+          index_t BBlockTransferSrcScalarPerVector,
+          index_t BBlockTransferDstScalarPerVector_BK1,
+          index_t BBlockLdsExtraN,
+          index_t CShuffleMXdlPerWavePerShuffle,
+          index_t CShuffleNXdlPerWavePerShuffle,
+          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+          index_t CDEBlockTransferScalarPerVector_NPerBlock,
+          LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     : public DeviceConvFwd<NDimSpatial,
                            ck::tuple_element_t<NDimSpatial - 1,
@@ -96,8 +168,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                            BElementwiseOperation,
                            CDEElementwiseOperation>
 {
     using DeviceOp = DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle;
+
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -109,12 +184,12 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     static auto GetWeightTensorDescriptor(ck::index_t gemm_n, ck::index_t gemm_k)
     {
         const ck::index_t gemm_k0 = gemm_k / GemmK1Number;

-        const auto wei_k_yxc_grid_desc =
+        const auto wei_k_yxe_grid_desc =
             make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));

         // wei_gemmk0_gemmn_gemmk1_grid_desc
         return transform_tensor_descriptor(
-            wei_k_yxc_grid_desc,
+            wei_k_yxe_grid_desc,
             make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
                        make_pass_through_transform(gemm_n)),
             make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -149,6 +224,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         const std::vector<ck::index_t>& input_left_pads,
         const std::vector<ck::index_t>& input_right_pads)
     {
         const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
+
         const index_t Wi = input_spatial_lengths[0];
         const index_t Wo = output_spatial_lengths[0];
@@ -171,11 +247,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         else if constexpr(ConvForwardSpecialization ==
                           ConvolutionForwardSpecialization::Filter1x1Pad0)
         {
-            const auto in_n_wi_c_grid_desc =
+            const auto in_n_wi_e_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

-            const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
-                in_n_wi_c_grid_desc,
+            const auto in_n_wo_e_grid_desc = transform_tensor_descriptor(
+                in_n_wi_e_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                            make_pass_through_transform(C)),
@@ -183,7 +259,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

             const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_n_wo_c_grid_desc,
+                in_n_wo_e_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
                            make_merge_transform(make_tuple(N, Wo))),
                 make_tuple(Sequence<2>{}, Sequence<0, 1>{}),
@@ -205,19 +281,19 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const index_t InLeftPadW  = input_left_pads[0];
             const index_t InRightPadW = input_right_pads[0];

-            const auto in_n_wi_c_grid_desc =
+            const auto in_n_wi_e_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

-            const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
-                in_n_wi_c_grid_desc,
+            const auto in_n_wip_e_grid_desc = transform_tensor_descriptor(
+                in_n_wi_e_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_pad_transform(Wi, InLeftPadW, InRightPadW),
                            make_pass_through_transform(C)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

-            const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
-                in_n_wip_c_grid_desc,
+            const auto in_n_x_wo_e_grid_desc = transform_tensor_descriptor(
+                in_n_wip_e_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_embed_transform(make_tuple(X, Wo),
                                                 make_tuple(ConvDilationW, ConvStrideW)),
@@ -226,7 +302,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                 make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

             const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
-                in_n_x_wo_c_grid_desc,
+                in_n_x_wo_e_grid_desc,
                 make_tuple(make_merge_transform(make_tuple(X, C)),
                            make_merge_transform(make_tuple(N, Wo))),
                 make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
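The pad/embed/merge chain above is the implicit-GEMM (im2col) view of the input: it encodes, per output position, the index arithmetic wi = wo*stride + x*dilation - pad_left. A plain-C++ sketch of that math for the default 1D case (illustrative only; not CK code):

    #include <cstdio>

    // Input [N, Wi, C] read as an implicit GEMM A matrix of shape
    // [GemmK = X*C, GemmM = N*Wo].
    int main()
    {
        const int N = 1, Wi = 8, C = 2, X = 3, Wo = 6;
        const int stride = 1, dilation = 1, pad_left = 0;

        for(int n = 0; n < N; ++n)
            for(int wo = 0; wo < Wo; ++wo)       // gemm_m = n * Wo + wo
                for(int x = 0; x < X; ++x)
                    for(int c = 0; c < C; ++c)   // gemm_k = x * C + c
                    {
                        const int wi = wo * stride + x * dilation - pad_left;
                        if(wi >= 0 && wi < Wi)   // out-of-range reads are the padding
                            std::printf("A[%d][%d] <- in[%d][%d][%d]\n",
                                        x * C + c, n * Wo + wo, n, wi, c);
                    }
    }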
@@ -291,11 +367,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         else if constexpr(ConvForwardSpecialization ==
                           ConvolutionForwardSpecialization::Filter1x1Pad0)
         {
-            const auto in_n_hi_wi_c_grid_desc =
+            const auto in_n_hi_wi_e_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

-            const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
-                in_n_hi_wi_c_grid_desc,
+            const auto in_n_ho_wo_e_grid_desc = transform_tensor_descriptor(
+                in_n_hi_wi_e_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                            make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
@@ -304,7 +380,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

             const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_n_ho_wo_c_grid_desc,
+                in_n_ho_wo_e_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
                            make_merge_transform(make_tuple(N, Ho, Wo))),
                 make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
@@ -333,11 +409,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const index_t InRightPadH = input_right_pads[0];
             const index_t InRightPadW = input_right_pads[1];

-            const auto in_n_hi_wi_c_grid_desc =
+            const auto in_n_hi_wi_e_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

-            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
-                in_n_hi_wi_c_grid_desc,
+            const auto in_n_hip_wip_e_grid_desc = transform_tensor_descriptor(
+                in_n_hi_wi_e_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_pad_transform(Hi, InLeftPadH, InRightPadH),
                            make_pad_transform(Wi, InLeftPadW, InRightPadW),
@@ -345,8 +421,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

-            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
-                in_n_hip_wip_c_grid_desc,
+            const auto in_n_y_ho_x_wo_e_grid_desc = transform_tensor_descriptor(
+                in_n_hip_wip_e_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_embed_transform(make_tuple(Y, Ho),
                                                 make_tuple(ConvDilationH, ConvStrideH)),
@@ -356,7 +432,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                 make_tuple(
                     Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

             const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
-                in_n_y_ho_x_wo_c_grid_desc,
+                in_n_y_ho_x_wo_e_grid_desc,
                 make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                            make_merge_transform(make_tuple(N, Ho, Wo))),
                 make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
@@ -372,6 +448,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         // in_gemmk0_gemmm_gemmk1_grid_desc
         return transform_tensor_descriptor(
             in_gemmk0_gemmmraw_gemmk1_grid_desc,
             make_tuple(make_pass_through_transform(gemm_k0),
                        make_right_pad_transform(gemm_m, gemm_m_pad),
                        make_pass_through_transform(GemmK1Number)),
@@ -424,11 +501,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         else if constexpr(ConvForwardSpecialization ==
                           ConvolutionForwardSpecialization::Filter1x1Pad0)
         {
-            const auto in_n_di_hi_wi_c_grid_desc =
+            const auto in_n_di_hi_wi_e_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

-            const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
-                in_n_di_hi_wi_c_grid_desc,
+            const auto in_n_do_ho_wo_e_grid_desc = transform_tensor_descriptor(
+                in_n_di_hi_wi_e_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                            make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
@@ -440,9 +517,10 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
                            Sequence<4>{}));

             const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_n_do_ho_wo_c_grid_desc,
+                in_n_do_ho_wo_e_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
                            make_merge_transform(make_tuple(N, Do, Ho, Wo))),
                 make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -473,11 +551,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const index_t InRightPadH = input_right_pads[1];
             const index_t InRightPadW = input_right_pads[2];

-            const auto in_n_di_hi_wi_c_grid_desc =
+            const auto in_n_di_hi_wi_e_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

-            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
-                in_n_di_hi_wi_c_grid_desc,
+            const auto in_n_hip_wip_e_grid_desc = transform_tensor_descriptor(
+                in_n_di_hi_wi_e_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_pad_transform(Di, InLeftPadD, InRightPadD),
                            make_pad_transform(Hi, InLeftPadH, InRightPadH),
@@ -488,8 +566,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
                            Sequence<4>{}));

-            const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
-                in_n_hip_wip_c_grid_desc,
+            const auto in_n_z_do_y_ho_x_wo_e_grid_desc = transform_tensor_descriptor(
+                in_n_hip_wip_e_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_embed_transform(make_tuple(Z, Do),
                                                 make_tuple(ConvDilationD, ConvStrideD)),
@@ -505,7 +583,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                     Sequence<7>{}));

             const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
-                in_n_z_do_y_ho_x_wo_c_grid_desc,
+                in_n_z_do_y_ho_x_wo_e_grid_desc,
                 make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
                            make_merge_transform(make_tuple(N, Do, Ho, Wo))),
                 make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
@@ -547,7 +625,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     }

-    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
+    static auto MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
                                                                 ck::index_t K,
                                                                 ck::index_t C,
                                                                 std::vector<ck::index_t> input_spatial_lengths,
@@ -568,7 +646,6 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         assert(GemmK % GemmK1Number == 0);

-        // C = A^T*B
         // A:
         const auto in_gemmk0_gemmm_gemmk1_grid_desc =
             GetInputTensorDescriptor<NDimSpatial>(N,
@@ -585,7 +662,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                 input_right_pads);

         // B:
         const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK);

-        // C:
+        // E:
         const auto out_gemmm_gemmn_grid_desc =
             GetOutputTensorDescriptor(GemmMRaw, GemmN, GemmMPad);

         return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
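For reference, the GEMM problem that this descriptor triple (A, B, E) describes: GemmM = N x product of output spatial lengths, GemmN = K, GemmK = C x product of filter spatial lengths, with GemmK unmerged into [GemmK0, GemmK1]. A sketch for the 2D case with hypothetical helper names (not part of the commit):

    struct GemmSizes { long M, N, K, K0; };

    GemmSizes conv2d_fwd_gemm_sizes(long N, long K, long C,
                                    long Y, long X, long Ho, long Wo, long K1)
    {
        GemmSizes s;
        s.M  = N * Ho * Wo; // one GEMM row per output pixel
        s.N  = K;           // one GEMM column per output channel
        s.K  = Y * X * C;   // reduction over filter window and input channels
        s.K0 = s.K / K1;    // K is unmerged into [K0, K1]; K % K1 must be 0
        return s;
    }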
@@ -594,74 +671,84 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     }

     template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
-    static auto GetABCGridDesc()
+    static auto GetABEGridDesc()
     {
-        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+        return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
             1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
     }

     template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
-    static auto GetABCGridDesc()
+    static auto GetABEGridDesc()
     {
-        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+        return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
             1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
     }

     template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
-    static auto GetABCGridDesc()
+    static auto GetABEGridDesc()
     {
-        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+        return MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
             1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
     }

-    using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
+    using ABEGridDescs = decltype(GetABEGridDesc<NDimSpatial>());

-    using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
-    using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
-    using CGridDesc_M_N     = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
-
-    using Block2CTileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, CGridDesc_M_N>;
+    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
+    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
+    using EGridDesc_M_N       = remove_cvref_t<decltype(ABEGridDescs{}[I2])>;

     // GridwiseGemm
-    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
-        BlockSize,
+    using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CShuffleDataType,
+        DsDataType,
+        EDataType,
-        InMemoryDataOperationEnum::Set,
-        AGridDesc_K0_M_K1,
-        BGridDesc_K0_N_K1,
-        CGridDesc_M_N,
         AElementwiseOperation,
         BElementwiseOperation,
         CDEElementwiseOperation,
+        InMemoryDataOperationEnum::Set,
+        AGridDesc_AK0_M_AK1,
+        BGridDesc_BK0_N_BK1,
+        EGridDesc_M_N,
+        NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
         NPerBlock,
-        K0PerBlock,
+        KPerBlock,
+        K1,
+        K1,
         MPerXDL,
         NPerXDL,
-        K1,
         MXdlPerWave,
         NXdlPerWave,
-        ABlockTransferThreadClusterLengths_K0_M_K1,
-        Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
-        Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
-        2,                 // ABlockTransferSrcVectorDim,
+        ABlockTransferThreadClusterLengths_AK0_M_AK1,
+        ABlockTransferThreadClusterArrangeOrder,
+        ABlockTransferSrcAccessOrder,
+        ABlockTransferSrcVectorDim,
         ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_K1,
-        false, // AThreadTransferSrcResetCoordinateAfterRun,
-        ABlockLdsAddExtraM,
-        BBlockTransferThreadClusterLengths_K0_N_K1,
-        Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
-        Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
-        2,                 // BBlockTransferSrcVectorDim,
+        ABlockTransferDstScalarPerVector_AK1,
+        false,
+        ABlockLdsExtraM,
+        BBlockTransferThreadClusterLengths_BK0_N_BK1,
+        BBlockTransferThreadClusterArrangeOrder,
+        BBlockTransferSrcAccessOrder,
+        BBlockTransferSrcVectorDim,
         BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_K1,
-        false, // BThreadTransferSrcResetCoordinateAfterRun,
-        BBlockLdsAddExtraN,
-        Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
-        7,                                // CThreadTransferSrcDstVectorDim,
-        CThreadTransferDstScalarPerVector>;
+        BBlockTransferDstScalarPerVector_BK1,
+        false,
+        BBlockLdsExtraN,
+        CShuffleMXdlPerWavePerShuffle,
+        CShuffleNXdlPerWavePerShuffle,
+        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CDEBlockTransferScalarPerVector_NPerBlock,
+        LoopSched>;

+#if 0
+    using Block2ETileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, EGridDesc_M_N>;
+#else
+    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
+#endif

     // Argument
     struct Argument : public BaseArgument
@@ -682,17 +769,18 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                  AElementwiseOperation in_element_op,
                  BElementwiseOperation wei_element_op,
                  CDEElementwiseOperation out_element_op)
-            : p_a_grid_{p_in_grid},
-              p_b_grid_{p_wei_grid},
-              p_c_grid_{p_out_grid},
-              a_grid_desc_k0_m_k1_{},
-              b_grid_desc_k0_n_k1_{},
-              c_grid_desc_m_n_{},
-              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-              block_2_ctile_map_{},
-              in_element_op_{in_element_op},
-              wei_element_op_{wei_element_op},
-              out_element_op_{out_element_op},
+            : p_a_grid_{static_cast<const ADataType*>(p_in_grid)},
+              p_b_grid_{static_cast<const BDataType*>(p_wei_grid)},
+              p_ds_grid_{}, // FIXME
+              p_e_grid_{static_cast<EDataType*>(p_out_grid)},
+              a_grid_desc_ak0_m_ak1_{},
+              b_grid_desc_bk0_n_bk1_{},
+              e_grid_desc_m_n_{},
+              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
+              block_2_etile_map_{},
+              a_element_op_{in_element_op},
+              b_element_op_{wei_element_op},
+              cde_element_op_{out_element_op},
               Conv_N_{N},
               Conv_K_{K},
               Conv_C_{C},
@@ -702,7 +790,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
               input_right_pads_{input_right_pads}
         {
             const auto descs =
-                DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
+                DeviceOp::MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
                                                                           K,
                                                                           C,
                                                                           input_spatial_lengths,
@@ -713,35 +801,50 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                                                                           input_left_pads,
                                                                           input_right_pads);

-            a_grid_desc_k0_m_k1_ = descs[I0];
-            b_grid_desc_k0_n_k1_ = descs[I1];
-            c_grid_desc_m_n_     = descs[I2];
+            a_grid_desc_ak0_m_ak1_ = descs[I0];
+            b_grid_desc_bk0_n_bk1_ = descs[I1];
+            e_grid_desc_m_n_       = descs[I2];

-            block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_};
+            block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_};

-            if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
-                                           b_grid_desc_k0_n_k1_,
-                                           c_grid_desc_m_n_,
-                                           block_2_ctile_map_))
+            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
+                                           b_grid_desc_bk0_n_bk1_,
+                                           e_grid_desc_m_n_,
+                                           block_2_etile_map_))
             {
-                c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
-                    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
+                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        e_grid_desc_m_n_);
             }
         }

         //  private:
+        // pointers
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        EDataType* p_c_grid_;
-        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
-        CGridDesc_M_N c_grid_desc_m_n_;
-        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
-            c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-        Block2CTileMap block_2_ctile_map_;
-        AElementwiseOperation in_element_op_;
-        BElementwiseOperation wei_element_op_;
-        CDEElementwiseOperation out_element_op_;
+        typename GridwiseGemm::DsGridPointer p_ds_grid_;
+        EDataType* p_e_grid_;
+
+        // tensor descriptors
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        StaticallyIndexedArray<
+            typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+            NumDTensor>
+            ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
+                                                             // type from E
+        EGridDesc_M_N e_grid_desc_m_n_;
+        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock_;
+
+        // block-to-e-tile map
+        Block2ETileMap block_2_etile_map_;
+
+        // element-wise op
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
+        CDEElementwiseOperation cde_element_op_;

         // for checking IsSupportedArgument()
         index_t Conv_N_;
         index_t Conv_K_;
@@ -761,99 +864,84 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         {
 #if 0
             {
-                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
-                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+                std::cout << "arg.a_grid_desc_ak0_m_ak1_{" << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0)
+                          << ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
+                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;

-                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
-                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
+                std::cout << "arg.b_grid_desc_bk0_n_bk1_{" << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0)
+                          << ", " << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
+                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;

-                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+                std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
+                          << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
             }
 #endif

-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-                                            arg.b_grid_desc_k0_n_k1_,
-                                            arg.c_grid_desc_m_n_,
-                                            arg.block_2_ctile_map_))
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+                                            arg.b_grid_desc_bk0_n_bk1_,
+                                            arg.e_grid_desc_m_n_,
+                                            arg.block_2_etile_map_))
             {
                 throw std::runtime_error(
                     "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
             }

             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
+                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

             const auto K =
-                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
-
-            float ave_time = 0;
-
-            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
-            {
-                const auto kernel = kernel_gemm_xdlops_v2r3<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    EDataType,
-                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CDEElementwiseOperation,
-                    Block2CTileMap,
-                    true>;
-
-                ave_time = launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.in_element_op_,
-                                                  arg.wei_element_op_,
-                                                  arg.out_element_op_,
-                                                  arg.block_2_ctile_map_);
-            }
-            else
-            {
-                const auto kernel = kernel_gemm_xdlops_v2r3<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    EDataType,
-                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CDEElementwiseOperation,
-                    Block2CTileMap,
-                    false>;
-
-                ave_time = launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.in_element_op_,
-                                                  arg.wei_element_op_,
-                                                  arg.out_element_op_,
-                                                  arg.block_2_ctile_map_);
-            }
-
-            return ave_time;
+                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+
+            auto launch_kernel = [&](auto has_main_k_block_loop) {
+                constexpr bool has_main_loop = has_main_k_block_loop.value;
+
+                const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
+                    GridwiseGemm,
+                    ADataType, // TODO: distiguish A/B datatype
+                    typename GridwiseGemm::DsGridPointer,
+                    EDataType,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CDEElementwiseOperation,
+                    DeviceOp::AGridDesc_AK0_M_AK1,
+                    DeviceOp::BGridDesc_BK0_N_BK1,
+                    ck::StaticallyIndexedArray<
+                        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        NumDTensor>,
+                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    Block2ETileMap,
+                    has_main_loop>;
+
+                return launch_and_time_kernel(stream_config,
+                                              kernel,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              arg.p_a_grid_,
+                                              arg.p_b_grid_,
+                                              arg.p_ds_grid_,
+                                              arg.p_e_grid_,
+                                              arg.a_element_op_,
+                                              arg.b_element_op_,
+                                              arg.cde_element_op_,
+                                              arg.a_grid_desc_ak0_m_ak1_,
+                                              arg.b_grid_desc_bk0_n_bk1_,
+                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                              arg.block_2_etile_map_);
+            };
+
+            float avg_time = 0;
+
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            {
+                avg_time = launch_kernel(integral_constant<bool, true>{});
+            }
+            else
+            {
+                avg_time = launch_kernel(integral_constant<bool, false>{});
+            }
+
+            return avg_time;
         }

         float Run(const BaseArgument* p_arg,
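The rewritten Run folds the two duplicated launch paths into a single launch_kernel lambda dispatched on a compile-time bool. The same pattern in isolation, with stand-ins for the kernel launch (illustrative only):

    #include <cstdio>
    #include <type_traits>

    template <bool HasMainLoop>
    float run_impl() { return HasMainLoop ? 1.0f : 2.0f; } // stand-in for the launch

    int main()
    {
        auto launch = [](auto has_main_loop) {
            constexpr bool v = decltype(has_main_loop)::value; // compile-time constant
            return run_impl<v>();
        };

        const bool k_big_enough = true; // stand-in for CalculateHasMainKBlockLoop(K)
        float t = k_big_enough ? launch(std::integral_constant<bool, true>{})
                               : launch(std::integral_constant<bool, false>{});
        std::printf("%f\n", t);
    }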
@@ -863,12 +951,6 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         }
     };

-    static constexpr bool IsValidCompilationParameter()
-    {
-        // TODO: properly implement this check
-        return true;
-    }
-
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(ck::get_device_name() == "gfx908")
@@ -892,12 +974,12 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             return false;
         }

-        // Input tensors can't be bigger than 2GB each.
+        // tensors can't be bigger than 2GB each.
         constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);

-        if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
-           arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
-           arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > GB2)
+        if(arg.a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
+           arg.b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
+           arg.e_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > GB2)
         {
             return false;
         }
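The 2 GB bound comes from 32-bit index arithmetic: long_index_t{1} << 31 is 2^31 bytes = 2 GiB, and each grid buffer's element-space size times its element size must stay within it. Spelled out with an illustrative helper:

    #include <cstdint>

    // 1LL << 31 bytes == 2147483648 bytes == 2 GiB.
    static_assert((std::int64_t{1} << 31) == 2147483648LL, "2 GiB in bytes");

    inline bool fits_in_2gb(std::int64_t num_elements, std::int64_t bytes_per_element)
    {
        return num_elements * bytes_per_element <= (std::int64_t{1} << 31);
    }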
@@ -937,17 +1019,17 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             return false;
         }

-        // vector store C matrix into global memory
-        if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0))
+        // vector store D/E matrix into global memory
+        if(!(arg.Conv_K_ % CDEBlockTransferScalarPerVector_NPerBlock == 0))
         {
             return false;
         }

         // Gridwise GEMM size
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-                                           arg.b_grid_desc_k0_n_k1_,
-                                           arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.e_grid_desc_m_n_,
+                                           arg.block_2_etile_map_);
     }

     bool IsSupportedArgument(const BaseArgument* p_arg) override
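The store check changes with the new epilogue: output vectorization is now governed by CDEBlockTransferScalarPerVector_NPerBlock, and since K is the fastest-varying dimension of the N[W]K output layout, K must be a multiple of the vector width for the stores to stay contiguous. The condition in isolation (illustrative):

    // Vector stores of width `v` along the fastest-varying K axis only stay
    // aligned if K % v == 0.
    inline bool can_vector_store(int K, int scalar_per_vector)
    {
        return K % scalar_per_vector == 0;
    }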
@@ -1043,7 +1125,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
-            << K0PerBlock << ", "
+            << KPerBlock << ", "
             << getConvForwardSpecializationString(ConvForwardSpecialization) << ">";
         // clang-format on
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
@@ -618,18 +618,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                                               arg.block_2_etile_map_);
             };

-            float ave_time = 0;
+            float avg_time = 0;

             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
-                ave_time = launch_kernel(integral_constant<bool, true>{});
+                avg_time = launch_kernel(integral_constant<bool, true>{});
             }
             else
             {
-                ave_time = launch_kernel(integral_constant<bool, false>{});
+                avg_time = launch_kernel(integral_constant<bool, false>{});
             }

-            return ave_time;
+            return avg_time;
         }

         // polymorphic
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
@@ -12,16 +12,47 @@ namespace element_wise {
 struct PassThrough
 {
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<double, double>(double& y, const double& x) const
     {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, bhalf_t>::value ||
-                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-
         y = x;
     }
+
+    template <>
+    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
+    {
+        y = x;
+    }
 };

+struct UnaryConvert
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const
+    {
+        y = type_convert<Y>(x);
+    }
+};
+
 struct Scale
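A host-side usage sketch of the two functors defined above (values are made up; the ck::half_t conversion goes through type_convert exactly as UnaryConvert shows):

    #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

    void example()
    {
        ck::tensor_operation::element_wise::PassThrough pass;
        ck::tensor_operation::element_wise::UnaryConvert cvt;

        float x = 1.5f, y = 0.f;
        pass(y, x); // same-type copy: y = x

        ck::half_t h;
        cvt(h, x);  // converting copy: h = type_convert<ck::half_t>(x)
    }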