Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9dce6851
Commit
9dce6851
authored
Mar 10, 2022
by
Jing Zhang
Browse files
merge develop
parents
3cc57101
5d37d7bf
Changes
473
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1910 additions
and
0 deletions
+1910
-0
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r3.hpp
...ration/gpu/block/blockwise_tensor_slice_transfer_v6r3.hpp
+0
-0
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
...sor_operation/gpu/block/reduction_functions_blockwise.hpp
+185
-0
include/ck/tensor_operation/gpu/device/conv_utils.hpp
include/ck/tensor_operation/gpu/device/conv_utils.hpp
+0
-0
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
...n/gpu/device/convolution_backward_data_specialization.hpp
+17
-0
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
...eration/gpu/device/convolution_forward_specialization.hpp
+0
-0
include/ck/tensor_operation/gpu/device/convolution_utility.hpp
...de/ck/tensor_operation/gpu/device/convolution_utility.hpp
+0
-0
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+0
-0
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+0
-0
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+710
-0
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+814
-0
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
..._fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+29
-0
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
...nv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
+29
-0
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+30
-0
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
...ation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
+2
-0
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp
.../gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp
+0
-0
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...on/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+0
-0
include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp
...nsor_operation/gpu/device/device_conv_backward_weight.hpp
+47
-0
include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp
...e/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp
+47
-0
include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp
include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp
+0
-0
include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp
..._operation/gpu/device/device_conv_fwd_bias_activation.hpp
+0
-0
No files found.
composable_kernel/
include/tensor_operation/blockwise_tensor_slice_transfer_v6r3.hpp
→
include/
ck/
tensor_operation/
gpu/block/
blockwise_tensor_slice_transfer_v6r3.hpp
View file @
9dce6851
File moved
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
0 → 100644
View file @
9dce6851
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
namespace
ck
{
template
<
typename
Buffer1dDescType
,
typename
AccDataType
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
bool
ReorderThreadClusters
,
typename
OpReduce
,
bool
PropagateNan
>
struct
PartitionedBlockwiseReductionOn1dBuffer
{
static
constexpr
auto
buffer_1d_desc
=
Buffer1dDescType
{};
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
"The product of cluster lengths should be same as BlockSize!"
);
static_assert
(
KThreadClusterSize
>
1
,
"Parallel reduction need work on at least two elements"
);
static_assert
(
buffer_1d_desc
.
GetElementSize
()
==
BlockSize
,
"The buffer size should be the same as BlockSize!"
);
using
Accumulation
=
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
>
;
template
<
typename
BufferType
>
__device__
static
void
Reduce
(
BufferType
&
block_buffer
,
AccDataType
&
accuData
,
index_t
thread_m_cluster_id
,
index_t
thread_k_cluster_id
)
{
constexpr
auto
cluster_len_shift
=
get_shift
<
KThreadClusterSize
>
();
static_for
<
0
,
cluster_len_shift
,
1
>
{}([
&
](
auto
I
)
{
constexpr
index_t
indOffset
=
1
<<
(
cluster_len_shift
-
1
-
I
());
if
(
thread_k_cluster_id
<
indOffset
)
{
// consider the thread clusters order, ensure the contiguous locations are accessed
// by contiguous Thread-ID
index_t
offset1
=
ReorderThreadClusters
?
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_k_cluster_id
*
MThreadClusterSize
+
thread_m_cluster_id
))
:
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
*
KThreadClusterSize
+
thread_k_cluster_id
));
index_t
offset2
=
ReorderThreadClusters
?
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
(
thread_k_cluster_id
+
indOffset
)
*
MThreadClusterSize
+
thread_m_cluster_id
))
:
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
*
KThreadClusterSize
+
(
thread_k_cluster_id
+
indOffset
)));
AccDataType
opData1
=
type_convert
<
AccDataType
>
(
block_buffer
[
offset1
]);
AccDataType
opData2
=
type_convert
<
AccDataType
>
(
block_buffer
[
offset2
]);
Accumulation
::
Calculate
(
opData1
,
opData2
);
block_buffer
(
offset1
)
=
type_convert
<
AccDataType
>
(
opData1
);
}
__syncthreads
();
});
index_t
offset
=
ReorderThreadClusters
?
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
))
:
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
*
KThreadClusterSize
));
accuData
=
type_convert
<
AccDataType
>
(
block_buffer
[
offset
]);
};
};
template
<
typename
Buffer1dDescType
,
typename
AccDataType
,
typename
IndexDataType
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
bool
ReorderThreadClusters
,
typename
OpReduce
,
bool
PropagateNan
>
struct
PartitionedBlockwiseReductionWithIndexOn1dBuffer
{
static
constexpr
auto
buffer_1d_desc
=
Buffer1dDescType
{};
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
"The product of cluster lengths should be same as BlockSize!"
);
static_assert
(
KThreadClusterSize
>
1
,
"Parallel reduction need work on at least two elements"
);
static_assert
(
buffer_1d_desc
.
GetElementSize
()
==
BlockSize
,
"The buffer size should be the same as BlockSize!"
);
using
Accumulation
=
detail
::
AccumulateWithIndexAndNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
,
IndexDataType
>
;
// This interface accumulates on both data values and indices
template
<
typename
BufferType
,
typename
IdxBufferType
>
__device__
static
void
Reduce
(
BufferType
&
block_val_buffer
,
IdxBufferType
&
block_idx_buffer
,
AccDataType
&
accuData
,
IndexDataType
&
accuIndex
,
index_t
thread_m_cluster_id
,
index_t
thread_k_cluster_id
)
{
constexpr
auto
cluster_len_shift
=
get_shift
<
KThreadClusterSize
>
();
static_for
<
0
,
cluster_len_shift
,
1
>
{}([
&
](
auto
I
)
{
constexpr
index_t
indOffset
=
1
<<
I
();
if
(
thread_k_cluster_id
%
(
indOffset
*
2
)
==
0
)
{
// consider the thread clusters order, ensure the contiguous locations are accessed
// by contiguous Thread-ID
index_t
offset1
=
ReorderThreadClusters
?
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_k_cluster_id
*
MThreadClusterSize
+
thread_m_cluster_id
))
:
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
*
KThreadClusterSize
+
thread_k_cluster_id
));
index_t
offset2
=
ReorderThreadClusters
?
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
(
thread_k_cluster_id
+
indOffset
)
*
MThreadClusterSize
+
thread_m_cluster_id
))
:
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
*
KThreadClusterSize
+
(
thread_k_cluster_id
+
indOffset
)));
AccDataType
opData1
=
type_convert
<
AccDataType
>
(
block_val_buffer
[
offset1
]);
AccDataType
opData2
=
type_convert
<
AccDataType
>
(
block_val_buffer
[
offset2
]);
IndexDataType
currIndex1
=
block_idx_buffer
[
offset1
];
IndexDataType
currIndex2
=
block_idx_buffer
[
offset2
];
Accumulation
::
Calculate
(
opData1
,
opData2
,
currIndex1
,
currIndex2
);
block_val_buffer
(
offset1
)
=
type_convert
<
AccDataType
>
(
opData1
);
block_idx_buffer
(
offset1
)
=
currIndex1
;
}
__syncthreads
();
});
index_t
offset
=
ReorderThreadClusters
?
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
))
:
buffer_1d_desc
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
*
KThreadClusterSize
));
accuData
=
type_convert
<
AccDataType
>
(
block_val_buffer
[
offset
]);
accuIndex
=
block_idx_buffer
[
offset
];
}
};
};
// end of namespace ck
#endif
device
_operation/
includ
e/conv_utils.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/conv_utils.hpp
View file @
9dce6851
File moved
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
0 → 100644
View file @
9dce6851
#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
enum
ConvolutionBackwardDataSpecialization_t
{
Default
,
Filter1x1Stride1Pad0
,
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
device
_operation/
includ
e/convolution_forward_specialization.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/convolution_forward_specialization.hpp
View file @
9dce6851
File moved
device
_operation/
includ
e/convolution_utility.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/convolution_utility.hpp
View file @
9dce6851
File moved
device
_operation/
includ
e/device_base.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_base.hpp
View file @
9dce6851
File moved
device
_operation/
includ
e/device_batched_gemm_xdl.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_batched_gemm_xdl.hpp
View file @
9dce6851
File moved
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
0 → 100644
View file @
9dce6851
#ifndef DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "device_conv_backward_weight.hpp"
#include "convolution_forward_specialization.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerXdl
,
ck
::
index_t
NPerXdl
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvWrw
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
using
ADataType
=
OutDataType
;
using
BDataType
=
InDataType
;
using
CDataType
=
WeiDataType
;
using
AElementwiseOperation
=
OutElementwiseOperation
;
using
BElementwiseOperation
=
InElementwiseOperation
;
using
CElementwiseOperation
=
WeiElementwiseOperation
;
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
static
constexpr
index_t
NDimSpatial
=
2
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
GemmK1Number
=
K1Number
;
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
1
];
const
index_t
Y
=
filter_spatial_lengths
[
0
];
const
index_t
X
=
filter_spatial_lengths
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
index_t
GemmKTotal
=
N
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
X
*
Y
;
const
index_t
GemmKBatch
=
batch_k
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
*
GemmKBatch
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmKBatch
*
GemmK0
*
GemmK1Number
;
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
// A: output tensor
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
using
ABCGridDescs
=
decltype
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
1
));
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
InMemoryDataOperationEnum_t
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerXdl
,
NPerXdl
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
InMemoryDataOperationEnum_t
::
AtomicAdd
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerXdl
,
NPerXdl
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
// Argument
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
CGridDesc_M_N
{},
1
,
1
,
1
));
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_in_grid
},
p_c_grid_
{
p_wei_grid
},
a_grid_desc_kbatch_k0_m_k1_
{},
b_grid_desc_kbatch_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
out_element_op
},
b_element_op_
{
in_element_op
},
c_element_op_
{
wei_element_op
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_C_
{
C
},
output_spatial_lengths_
{
output_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
conv_filter_strides_
{
conv_filter_strides
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
k_batch_
{
split_k
}
{
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
k_batch_
);
a_grid_desc_kbatch_k0_m_k1_
=
descs
[
I0
];
b_grid_desc_kbatch_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_kbatch_k0_m_k1_
,
b_grid_desc_kbatch_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
}
}
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_kbatch_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_kbatch_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
InElementwiseOperation
a_element_op_
;
OutElementwiseOperation
b_element_op_
;
WeiElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
index_t
k_batch_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
void
ShowInfo
(
const
Argument
&
arg
)
{
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
ShowInfo
(
arg
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"
);
}
const
auto
kbatch
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
,
kbatch
);
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
float
ave_time
=
0
;
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
nrepeat
>
0
)
{
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
}
if
(
kbatch
>
1
||
nrepeat
<=
0
)
{
hipGetErrorString
(
hipMemset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
CDataType
)));
launch_kernel
(
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
}
};
if
(
has_main_k0_block_loop
)
{
if
(
kbatch
==
1
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
true
>
;
Run
(
kernel
);
}
}
else
{
if
(
kbatch
==
1
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
}
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
int
nrepeat
=
1
)
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
nrepeat
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
// vector load A/B matrix from global memory
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
BBlockTransferSrcVectorDim
==
2
&&
arg
.
Conv_K_
%
ABlockTransferSrcScalarPerVector
==
0
&&
arg
.
Conv_C_
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
// vector store C matrix into global memory
if
(
!
(
arg
.
Conv_C_
%
CBlockTransferScalarPerVector_NWaveNPerXdl
==
0
))
{
return
false
;
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
{
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_out_grid
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
0 → 100644
View file @
9dce6851
#ifndef DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "device_conv_bwd_data.hpp"
#include "convolution_backward_data_specialization.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionBackwardDataSpecialization_t
ConvBackwardDataSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerXdl
,
ck
::
index_t
NPerXdl
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvBwdData
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
using
ADataType
=
OutDataType
;
using
BDataType
=
WeiDataType
;
using
CDataType
=
InDataType
;
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
static
constexpr
index_t
NDimSpatial
=
2
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static_assert
((
K1
%
ABlockTransferThreadClusterLengths_K0_M_K1
{}[
I2
])
%
ABlockTransferSrcScalarPerVector
==
0
);
static_assert
((
NPerBlock
/
BBlockTransferThreadClusterLengths_K0_N_K1
{}[
I1
])
%
BBlockTransferSrcScalarPerVector
==
0
);
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
GemmK1Number
=
K1Number
;
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
index_t
i_ytilda
,
index_t
i_xtilda
)
{
using
namespace
ck
;
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
1
];
const
index_t
Y
=
filter_spatial_lengths
[
0
];
const
index_t
X
=
filter_spatial_lengths
[
1
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
auto
K0
=
K
/
K1
;
const
auto
out_n_ho_wo_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Ho
,
Wo
,
K
));
const
auto
wei_k_y_x_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
,
X
,
C
));
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization_t
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
// B: weight tensor
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: input tensor
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
I1
,
Ho
),
make_tuple
(
I1
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
I1
,
Wo
),
make_tuple
(
I1
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
);
}
else
{
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilda
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilda
);
const
auto
HTilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
const
auto
IHTildaSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilda
-
I1
)),
ConvStrideH
);
const
auto
IWTildaSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilda
-
I1
)),
ConvStrideW
);
const
auto
IHTildaSliceEnd
=
math
::
min
(
HTilda
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTildaSliceEnd
=
math
::
min
(
WTilda
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
HTildaSlice
=
IHTildaSliceEnd
-
IHTildaSliceBegin
;
const
auto
WTildaSlice
=
IWTildaSliceEnd
-
IWTildaSliceBegin
;
// GemmK is different for each GEMM
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilda
,
YTilda
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilda
,
XTilda
);
// A: output tensor
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTilda
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTilda
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
HTilda
,
IHTildaSliceBegin
,
HTildaSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilda
,
IWTildaSliceBegin
,
WTildaSlice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// B weight tensor
const
auto
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTilda
),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
XTilda
),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ytilda
),
make_freeze_transform
(
i_xtilda
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// C: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTilda
,
HTilda
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTilda
,
WTilda
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htildaslice_wtildaslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_ytilda
),
make_slice_transform
(
HTilda
,
IHTildaSliceBegin
,
HTildaSlice
),
make_freeze_transform
(
i_xtilda
),
make_slice_transform
(
WTilda
,
IWTildaSliceBegin
,
WTildaSlice
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<>
{},
Sequence
<
1
>
{},
Sequence
<>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_htildaslice_wtildaslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
);
}
}
// function end
using
ABCGridDescs
=
decltype
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
0
,
0
));
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
InMemoryDataOperationEnum_t
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerXdl
,
NPerXdl
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
// CThreadTransferSrcDstAccessOrder,
7
,
// CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
InDataType
*
p_in_grid
,
const
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_wei_grid
},
p_c_grid_
{
p_in_grid
},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_C_
{
C
},
input_spatial_lengths_
{
input_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
output_spatial_lengths_
{
output_spatial_lengths
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
for
(
index_t
i_ytilda
=
0
;
i_ytilda
<
YTilda
;
++
i_ytilda
)
{
for
(
index_t
i_xtilda
=
0
;
i_xtilda
<
XTilda
;
++
i_xtilda
)
{
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
i_ytilda
,
i_xtilda
);
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
M01_
,
N01_
))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
],
M01
,
N01
));
}
}
}
}
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
std
::
vector
<
AGridDesc_K0_M_K1
>
a_grid_desc_k0_m_k1_container_
;
std
::
vector
<
BGridDesc_K0_N_K1
>
b_grid_desc_k0_n_k1_container_
;
std
::
vector
<
CGridDesc_M_N
>
c_grid_desc_m_n_container_
;
std
::
vector
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
;
std
::
vector
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
block_2_ctile_map_container_
;
index_t
M01_
;
index_t
N01_
;
OutElementwiseOperation
a_element_op_
;
WeiElementwiseOperation
b_element_op_
;
InElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
conv_filter_strides_
;
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations_
;
std
::
vector
<
ck
::
index_t
>
input_left_pads_
;
std
::
vector
<
ck
::
index_t
>
input_right_pads_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
nrepeat
=
1
;
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_container_{"
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_container_{"
<<
arg
.
b_grid_desc_k0_n_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_container_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_container_{ "
<<
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I2
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I3
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I4
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I5
)
<<
" ) "
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m_n_container_
[
i
],
arg
.
M01_
,
arg
.
N01_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"
);
}
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_container_
[
i
]);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
OutElementwiseOperation
,
WeiElementwiseOperation
,
InElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
true
>
;
ave_time
+=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
],
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_container_
[
i
]);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
OutElementwiseOperation
,
WeiElementwiseOperation
,
InElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
false
>
;
ave_time
+=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
],
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_container_
[
i
]);
}
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
int
nrepeat
=
1
)
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
nrepeat
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization_t
::
Filter1x1Stride1Pad0
)
{
// check if it's 1x1, stride=1 pad = 0 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
arg
.
conv_filter_strides_
[
0
]
==
1
&&
arg
.
conv_filter_strides_
[
1
]
==
1
&&
arg
.
input_left_pads_
[
0
]
==
0
&&
arg
.
input_left_pads_
[
1
]
==
0
&&
arg
.
input_right_pads_
[
0
]
==
0
&&
arg
.
input_right_pads_
[
1
]
==
0
))
{
return
false
;
}
}
// vector load A/B matrix from global memory
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
BBlockTransferSrcVectorDim
==
1
&&
arg
.
Conv_K_
%
ABlockTransferSrcScalarPerVector
==
0
&&
arg
.
Conv_C_
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
// vector store C matrix into global memory
if
(
!
(
arg
.
Conv_C_
%
CThreadTransferDstScalarPerVector
==
0
))
{
return
false
;
}
// Gridwise GEMM size
for
(
int
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m_n_container_
[
i
],
arg
.
M01_
,
arg
.
N01_
))
{
return
false
;
}
}
return
true
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
InDataType
*
p_in_grid
,
const
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
{
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_out_grid
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
wei_element_op
,
out_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
p_in_grid
,
const
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
InDataType
*>
(
p_in_grid
),
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
wei_element_op
,
out_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
device
_operation/
includ
e/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
9dce6851
...
@@ -549,8 +549,11 @@ struct
...
@@ -549,8 +549,11 @@ struct
Conv_N_
{
N
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_K_
{
K
},
Conv_C_
{
C
},
Conv_C_
{
C
},
input_spatial_lengths_
{
input_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
output_spatial_lengths_
{
output_spatial_lengths
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
input_right_pads_
{
input_right_pads
}
{
{
...
@@ -625,8 +628,11 @@ struct
...
@@ -625,8 +628,11 @@ struct
index_t
Conv_N_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
index_t
Conv_C_
;
std
::
vector
<
index_t
>
input_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
conv_filter_dilations_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
};
};
...
@@ -638,6 +644,28 @@ struct
...
@@ -638,6 +644,28 @@ struct
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
#if 0
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
<< "K " << arg.Conv_K_ << ", "
<< "C " << arg.Conv_C_ << ", " << std::endl;
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
<< arg.conv_filter_strides_[1] << ", " << std::endl;
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...
@@ -656,6 +684,7 @@ struct
...
@@ -656,6 +684,7 @@ struct
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
device
_operation/
includ
e/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
View file @
9dce6851
...
@@ -526,8 +526,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -526,8 +526,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
Conv_N_
{
N
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_K_
{
K
},
Conv_C_
{
C
},
Conv_C_
{
C
},
input_spatial_lengths_
{
input_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
output_spatial_lengths_
{
output_spatial_lengths
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
input_right_pads_
{
input_right_pads
}
{
{
...
@@ -590,8 +593,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -590,8 +593,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
index_t
Conv_N_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
index_t
Conv_C_
;
std
::
vector
<
index_t
>
input_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
conv_filter_dilations_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
};
};
...
@@ -603,6 +609,28 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -603,6 +609,28 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
#if 0
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
<< "K " << arg.Conv_K_ << ", "
<< "C " << arg.Conv_C_ << ", " << std::endl;
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
<< arg.conv_filter_strides_[1] << ", " << std::endl;
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...
@@ -618,6 +646,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -618,6 +646,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
device
_operation/
includ
e/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
9dce6851
...
@@ -422,6 +422,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -422,6 +422,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
BlockSize
,
BlockSize
,
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
// TODO: Add ShuffleType for DeviceConv2d
CDataType
,
CDataType
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
...
@@ -497,8 +498,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -497,8 +498,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
Conv_N_
{
N
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_K_
{
K
},
Conv_C_
{
C
},
Conv_C_
{
C
},
input_spatial_lengths_
{
input_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
output_spatial_lengths_
{
output_spatial_lengths
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
input_right_pads_
{
input_right_pads
}
{
{
...
@@ -550,8 +554,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -550,8 +554,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
index_t
Conv_N_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
index_t
Conv_C_
;
std
::
vector
<
index_t
>
input_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
conv_filter_dilations_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
};
};
...
@@ -563,6 +570,28 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -563,6 +570,28 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
#if 0
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
<< "K " << arg.Conv_K_ << ", "
<< "C " << arg.Conv_C_ << ", " << std::endl;
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
<< arg.conv_filter_strides_[1] << ", " << std::endl;
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...
@@ -597,6 +626,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -597,6 +626,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
.GetLength(I5)
.GetLength(I5)
<< "}" << std::endl;
<< "}" << std::endl;
}
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
device
_operation/
includ
e/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
9dce6851
...
@@ -452,6 +452,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -452,6 +452,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
#if 0
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...
@@ -464,6 +465,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -464,6 +465,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
device
_operation/
includ
e/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp
View file @
9dce6851
File moved
device
_operation/
includ
e/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
9dce6851
File moved
include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp
0 → 100644
View file @
9dce6851
#ifndef DEVICE_CONV_WRW_HPP
#define DEVICE_CONV_WRW_HPP
#include <iostream>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
struct
DeviceConvWrw
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in
,
void
*
p_wei
,
const
void
*
p_out
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
using
DeviceConvWrwPtr
=
std
::
unique_ptr
<
DeviceConvWrw
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp
0 → 100644
View file @
9dce6851
#ifndef DEVICE_CONV_BWD_DATA_HPP
#define DEVICE_CONV_BWD_DATA_HPP
#include <iostream>
#include "device_base.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
struct
DeviceConvBwdData
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
p_in
,
const
void
*
p_wei
,
const
void
*
p_out
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
using
DeviceConvBwdDataPtr
=
std
::
unique_ptr
<
DeviceConvBwdData
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
device
_operation/
includ
e/device_conv_fwd.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_conv_fwd.hpp
View file @
9dce6851
File moved
device
_operation/
includ
e/device_conv_fwd_bias_activation.hpp
→
include/ck/tensor
_operation/
gpu/devic
e/device_conv_fwd_bias_activation.hpp
View file @
9dce6851
File moved
Prev
1
…
3
4
5
6
7
8
9
10
11
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment