Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3474c777
Commit
3474c777
authored
Jul 20, 2022
by
Chao Liu
Browse files
add gemm padding to convnd
parent
7cc806d8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
219 additions
and
167 deletions
+219
-167
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
+8
-4
include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp
...device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp
+211
-163
No files found.
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
View file @
3474c777
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#include "convnd_fwd_common.hpp"
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
...
@@ -20,10 +18,10 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough;
...
@@ -20,10 +18,10 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryConvert
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryConvert
;
#if 0
static constexpr auto ConvFwdDefault =
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
#if 0
template <ck::index_t NDimSpatial>
template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl<
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl<
NDimSpatial, //
NDimSpatial, //
...
@@ -63,6 +61,11 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
...
@@ -63,6 +61,11 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
#else
#else
using
CShuffleDataType
=
ck
::
half_t
;
using
CShuffleDataType
=
ck
::
half_t
;
static
constexpr
auto
ConvSpec
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
template
<
ck
::
index_t
NDimSpatial
>
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvNDFwdInstance
=
using
DeviceConvNDFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
<
ck
::
tensor_operation
::
device
::
DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
<
...
@@ -76,7 +79,8 @@ using DeviceConvNDFwdInstance =
...
@@ -76,7 +79,8 @@ using DeviceConvNDFwdInstance =
InElementOp
,
// Input Elementwise Operation
InElementOp
,
// Input Elementwise Operation
WeiElementOp
,
// Weights Elementwise Operation
WeiElementOp
,
// Weights Elementwise Operation
OutElementOp
,
// Output Elementwise Operation
OutElementOp
,
// Output Elementwise Operation
ConvFwdDefault
,
// ConvForwardSpecialization
ConvSpec
,
// ConvForwardSpecialization
GemmSpec
,
// GemmSpecialization
1
,
//
1
,
//
256
,
// BlockSize
256
,
// BlockSize
128
,
// MPerBlock
128
,
// MPerBlock
...
...
include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp
View file @
3474c777
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/device_prop.hpp"
...
@@ -118,6 +120,7 @@ template <index_t NDimSpatial,
...
@@ -118,6 +120,7 @@ template <index_t NDimSpatial,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
CDEElementwiseOperation
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -181,15 +184,25 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -181,15 +184,25 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
GemmK1Number
=
K1Number
;
static
constexpr
auto
GemmK1Number
=
K1Number
;
static
auto
GetWeightTensorDescriptor
(
index_t
GemmN
,
index_t
GemmK
)
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
auto
GetWeightTensorDescriptor
(
index_t
GemmNRaw
,
index_t
GemmKRaw
)
{
{
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
const
auto
wei_k_yxc_grid_desc
=
const
auto
wei_k_yxc_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmN
,
GemmK
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmNRaw
,
GemmKRaw
));
const
auto
wei_gemmn_gemmk_grid_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_k_yxc_grid_desc
);
const
auto
GemmN
=
wei_gemmn_gemmk_grid_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
wei_gemmn_gemmk_grid_desc
.
GetLength
(
I1
);
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
// wei_gemmk0_gemmn_gemmk1_grid_desc
// wei_gemmk0_gemmn_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
wei_
k_yxc
_grid_desc
,
wei_
gemmn_gemmk
_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
...
@@ -198,25 +211,22 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -198,25 +211,22 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static
auto
GetOutputTensorDescriptor
(
index_t
GemmMRaw
,
index_t
GemmN
)
static
auto
GetOutputTensorDescriptor
(
index_t
GemmMRaw
,
index_t
GemmN
)
{
{
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
index_t
GemmMPad
=
GemmM
-
GemmMRaw
;
const
auto
out_gemmmraw_gemmn_grid_desc
=
const
auto
out_gemmmraw_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmM
,
GemmN
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmM
,
GemmN
));
// out_gemmm_gemmn_grid_desc
const
auto
out_gemmm_gemmn_grid_desc
=
return
transform_tensor_descriptor
(
out_gemmmraw_gemmn_grid_desc
,
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmn_grid_desc
);
make_tuple
(
make_right_pad_transform
(
GemmM
,
GemmMPad
),
make_pass_through_transform
(
GemmN
)),
return
out_gemmm_gemmn_grid_desc
;
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
GetInputTensorDescriptor
(
index_t
N
,
static
auto
GetInputTensorDescriptor
(
index_t
N
,
index_t
C
,
index_t
C
,
index_t
GemmMRaw
,
index_t
GemmMRaw
,
index_t
GemmK
,
index_t
GemmK
Raw
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
...
@@ -225,10 +235,6 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -225,10 +235,6 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_right_pads
)
const
std
::
vector
<
index_t
>&
input_right_pads
)
{
{
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
index_t
GemmMPad
=
GemmM
-
GemmMRaw
;
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
...
@@ -237,45 +243,60 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -237,45 +243,60 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmk_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmMRaw
,
GemmK
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmMRaw
,
GemmKRaw
));
const
auto
in_gemmm_gemmk_grid_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmk_grid_desc
);
const
auto
GemmM
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK0
=
GemmK
/
GemmK1Number
;
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmm
raw
_gemmk_grid_desc
,
in_gemmm_gemmk_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_
right_pad
_transform
(
GemmM
Raw
,
GemmMPad
)),
make_
pass_through
_transform
(
GemmM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
const
auto
in_n_wi_
e
_grid_desc
=
const
auto
in_n_wi_
c
_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
const
auto
in_n_wo_
e
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_wo_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_wi_
e
_grid_desc
,
in_n_wi_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmmraw_gemmkraw_grid_desc
=
transform_tensor_descriptor
(
in_n_wo_e_grid_desc
,
in_n_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmm_gemmk_grid_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_grid_desc
);
const
auto
GemmM
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK0
=
GemmK
/
GemmK1Number
;
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmm_gemmk_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_
merge_transform
(
make_tuple
(
N
,
Wo
)
)),
make_
pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
2
>
{},
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return
in_gemmk0_gemmm_gemmk1_grid_desc
;
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
make_right_pad_transform
(
GemmM
,
GemmMPad
),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
}
else
else
{
{
...
@@ -284,19 +305,19 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -284,19 +305,19 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
const
auto
in_n_wi_
e
_grid_desc
=
const
auto
in_n_wi_
c
_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
const
auto
in_n_wip_
e
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_wip_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_wi_
e
_grid_desc
,
in_n_wi_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_n_x_wo_
e
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_x_wo_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_wip_
e
_grid_desc
,
in_n_wip_
c
_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
...
@@ -304,28 +325,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -304,28 +325,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemm
k
_gemm
mraw
_grid_desc
=
const
auto
in_gemm
mraw
_gemm
k
_grid_desc
=
transform_tensor_descriptor
(
in_n_x_wo_
e
_grid_desc
,
transform_tensor_descriptor
(
in_n_x_wo_
c
_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
X
,
C
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_merge_transform
(
make_tuple
(
N
,
Wo
))),
make_merge_transform
(
make_tuple
(
X
,
C
))),
make_tuple
(
Sequence
<
1
,
3
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmm_gemmk_grid_desc
=
in_gemmk_gemmmraw_grid_desc
,
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmk_grid_desc
);
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
const
auto
GemmM
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK0
=
GemmK
/
GemmK1Number
;
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmm_gemmk_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmM
),
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return
in_gemmk0_gemmm_gemmk1_grid_desc
;
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
make_right_pad_transform
(
GemmM
,
GemmMPad
),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
}
}
}
...
@@ -333,7 +355,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -333,7 +355,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static
auto
GetInputTensorDescriptor
(
index_t
N
,
static
auto
GetInputTensorDescriptor
(
index_t
N
,
index_t
C
,
index_t
C
,
index_t
GemmMRaw
,
index_t
GemmMRaw
,
index_t
GemmK
,
index_t
GemmK
Raw
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
...
@@ -342,12 +364,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -342,12 +364,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_right_pads
)
const
std
::
vector
<
index_t
>&
input_right_pads
)
{
{
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
GemmMPad
=
GemmM
-
GemmMRaw
;
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
1
];
const
index_t
Wo
=
output_spatial_lengths
[
1
];
...
@@ -358,25 +376,33 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -358,25 +376,33 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmkraw_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmM
,
GemmK
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmMRaw
,
GemmKRaw
));
const
auto
in_gemmm_gemmk_grid_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_grid_desc
);
const
auto
GemmM
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK0
=
GemmK
/
GemmK1Number
;
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmm
raw
_gemmk_grid_desc
,
in_gemmm_gemmk_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_
right_pad
_transform
(
GemmM
,
GemmMPad
)),
make_
pass_through
_transform
(
GemmM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
const
auto
in_n_hi_wi_
e
_grid_desc
=
const
auto
in_n_hi_wi_
c
_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
const
auto
in_n_ho_wo_
e
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_ho_wo_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_
e
_grid_desc
,
in_n_hi_wi_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
...
@@ -384,21 +410,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -384,21 +410,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmmraw_gemmk_grid_desc
=
in_n_ho_wo_e_grid_desc
,
transform_tensor_descriptor
(
in_n_ho_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmm_gemmk_grid_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmk_grid_desc
);
const
auto
GemmM
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK0
=
GemmK
/
GemmK1Number
;
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmm_gemmk_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_
merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)
)),
make_
pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return
in_gemmk0_gemmm_gemmk1_grid_desc
;
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
make_right_pad_transform
(
GemmM
,
GemmMPad
),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
}
else
else
{
{
...
@@ -414,11 +448,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -414,11 +448,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
auto
in_n_hi_wi_
e
_grid_desc
=
const
auto
in_n_hi_wi_
c
_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
const
auto
in_n_hip_wip_
e
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_hip_wip_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_
e
_grid_desc
,
in_n_hi_wi_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
...
@@ -426,8 +460,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -426,8 +460,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_
e
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_y_ho_x_wo_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_
e
_grid_desc
,
in_n_hip_wip_
c
_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
...
@@ -436,29 +470,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -436,29 +470,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemm
k
_gemm
mraw
_grid_desc
=
const
auto
in_gemm
mraw
_gemm
k
_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_
e
_grid_desc
,
transform_tensor_descriptor
(
in_n_y_ho_x_wo_
c
_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmm_gemmk_grid_desc
=
in_gemmk_gemmmraw_grid_desc
,
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmk_grid_desc
);
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
const
auto
GemmM
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I0
);
return
transform_tensor_descriptor
(
const
auto
GemmK
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I1
);
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
const
auto
GemmK0
=
GemmK
/
GemmK1Number
;
make_right_pad_transform
(
GemmM
,
GemmMPad
),
make_pass_through_transform
(
GemmK1Number
)),
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
in_gemmm_gemmk_grid_desc
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
make_pass_through_transform
(
GemmM
),
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
return
in_gemmk0_gemmm_gemmk1_grid_desc
;
}
}
}
}
...
@@ -466,7 +500,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -466,7 +500,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
static
auto
GetInputTensorDescriptor
(
index_t
N
,
static
auto
GetInputTensorDescriptor
(
index_t
N
,
index_t
C
,
index_t
C
,
index_t
GemmMRaw
,
index_t
GemmMRaw
,
index_t
GemmK
,
index_t
GemmK
Raw
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
...
@@ -475,13 +509,9 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -475,13 +509,9 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_right_pads
)
const
std
::
vector
<
index_t
>&
input_right_pads
)
{
{
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
GemmMPad
=
GemmM
-
GemmMRaw
;
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
...
@@ -494,25 +524,33 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -494,25 +524,33 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmkraw_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmM
,
GemmK
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmMRaw
,
GemmKRaw
));
const
auto
in_gemmm_gemmk_grid_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_grid_desc
);
const
auto
GemmM
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK0
=
GemmK
/
GemmK1Number
;
// in_gemmk0_gemmm_gemmk1_grid_desc
// in_gemmk0_gemmm_gemmk1_grid_desc
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
in_gemmm
raw
_gemmk_grid_desc
,
in_gemmm_gemmk_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_
right_pad
_transform
(
GemmM
,
GemmMPad
)),
make_
pass_through
_transform
(
GemmM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
const
auto
in_n_di_hi_wi_
e
_grid_desc
=
const
auto
in_n_di_hi_wi_
c
_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
const
auto
in_n_do_ho_wo_
e
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_do_ho_wo_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_
e
_grid_desc
,
in_n_di_hi_wi_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Do
),
make_tuple
(
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Do
),
make_tuple
(
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
...
@@ -523,22 +561,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -523,22 +561,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmmraw_gemmkraw_grid_desc
=
transform_tensor_descriptor
(
in_n_do_ho_wo_e_grid_desc
,
in_n_do_ho_wo_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
))),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmm_gemmk_grid_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_grid_desc
);
const
auto
GemmM
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK0
=
GemmK
/
GemmK1Number
;
make_tuple
(
Sequence
<
4
>
{},
Sequence
<
0
,
1
,
2
,
3
>
{}),
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmm_gemmk_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return
in_gemmk0_gemmm_gemmk1_grid_desc
;
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
make_right_pad_transform
(
GemmM
,
GemmMPad
),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
}
else
else
{
{
...
@@ -558,11 +603,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -558,11 +603,11 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
const
auto
in_n_di_hi_wi_
e
_grid_desc
=
const
auto
in_n_di_hi_wi_
c
_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
const
auto
in_n_hip_wip_
e
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_hip_wip_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_
e
_grid_desc
,
in_n_di_hi_wi_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
@@ -573,8 +618,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -573,8 +618,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_z_do_y_ho_x_wo_
e
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_z_do_y_ho_x_wo_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_
e
_grid_desc
,
in_n_hip_wip_
c
_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
...
@@ -589,28 +634,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -589,28 +634,29 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
Sequence
<
5
,
6
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
Sequence
<
7
>
{}));
const
auto
in_gemm
k
_gemm
m
raw_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemm
mraw
_gemm
k
raw_grid_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_
e
_grid_desc
,
in_n_z_do_y_ho_x_wo_
c
_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
))),
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
))),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmmraw_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmm_gemmk_grid_desc
=
in_gemmk_gemmmraw_grid_desc
,
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_grid_desc
);
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
const
auto
GemmM
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
in_gemmm_gemmk_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK0
=
GemmK
/
GemmK1Number
;
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmm_gemmk_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmM
),
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return
in_gemmk0_gemmm_gemmk1_grid_desc
;
return
transform_tensor_descriptor
(
in_gemmk0_gemmmraw_gemmk1_grid_desc
,
make_tuple
(
make_pass_through_transform
(
GemmK0
),
make_right_pad_transform
(
GemmM
,
GemmMPad
),
make_pass_through_transform
(
GemmK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
}
}
}
...
@@ -871,12 +917,14 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
...
@@ -871,12 +917,14 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
{
{
#if 0
#if 0
{
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0)
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{" << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0)
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< ", " << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment