Commit 07a673c6 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents c0f698d5 ac0d8066
#ifndef CONVOLUTION_UTILITY_HPP
#define CONVOLUTION_UTILITY_HPP
#include <vector>
namespace ck {
namespace tensor_operation {
struct ConvolutionUtility
{
static std::vector<ck::index_t>
ComputeOutputSpatialLengths(std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_strides,
std::vector<ck::index_t> conv_dilations,
std::vector<ck::index_t> in_left_pads,
std::vector<ck::index_t> in_right_pads)
{
if(input_spatial_lengths.size() == 2)
{
assert(filter_spatial_lengths.size() == 2);
assert(conv_strides.size() == 2);
assert(conv_dilations.size() == 2);
assert(in_left_pads.size() == 2);
assert(in_right_pads.size() == 2);
const index_t YEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t XEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho =
(Hi + in_left_pads[0] + in_right_pads[0] - YEff) / conv_strides[0] + 1;
const index_t Wo =
(Wi + in_left_pads[1] + in_right_pads[1] - XEff) / conv_strides[1] + 1;
return {Ho, Wo};
}
else if(input_spatial_lengths.size() == 3)
{
assert(filter_spatial_lengths.size() == 3);
assert(conv_strides.size() == 3);
assert(conv_dilations.size() == 3);
assert(in_left_pads.size() == 3);
assert(in_right_pads.size() == 3);
const index_t ZEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t YEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t XEff = (filter_spatial_lengths[2] - 1) * conv_dilations[2] + 1;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do =
(Di + in_left_pads[0] + in_right_pads[0] - ZEff) / conv_strides[0] + 1;
const index_t Ho =
(Hi + in_left_pads[1] + in_right_pads[1] - YEff) / conv_strides[1] + 1;
const index_t Wo =
(Wi + in_left_pads[2] + in_right_pads[2] - XEff) / conv_strides[2] + 1;
return {Do, Ho, Wo};
}
else
{
return {};
}
}
};
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -105,7 +105,7 @@ template <typename ALayout, ...@@ -105,7 +105,7 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D0ReduceOperation,
typename D1ReduceOperation, typename D1ReduceOperation,
GemmSpecialization_t GemmSpecialization, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -171,8 +171,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -171,8 +171,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding || if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad both M and K // pad both M and K
assert(K % AK1 == 0); assert(K % AK1 == 0);
...@@ -195,8 +195,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -195,8 +195,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad M, but not K // pad M, but not K
assert(KRaw % AK1 == 0); assert(KRaw % AK1 == 0);
...@@ -212,8 +212,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -212,8 +212,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpecialization == GemmSpecialization_t::NKPadding) GemmSpec == GemmSpecialization::NKPadding)
{ {
// pad K, but not M // pad K, but not M
assert(K % AK1 == 0); assert(K % AK1 == 0);
...@@ -274,8 +274,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -274,8 +274,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto NPad = N - NRaw; const auto NPad = N - NRaw;
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding || if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad both N and K // pad both N and K
assert(K % BK1 == 0); assert(K % BK1 == 0);
...@@ -298,8 +298,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -298,8 +298,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad N, but not K // pad N, but not K
assert(KRaw % BK1 == 0); assert(KRaw % BK1 == 0);
...@@ -315,8 +315,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -315,8 +315,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding) GemmSpec == GemmSpecialization::MKPadding)
{ {
// pad K, but not N // pad K, but not N
assert(K % BK1 == 0); assert(K % BK1 == 0);
...@@ -377,8 +377,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -377,8 +377,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
const auto NPad = N - NRaw; const auto NPad = N - NRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding || if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad M and N // pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw, return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
...@@ -387,8 +387,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -387,8 +387,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding) GemmSpec == GemmSpecialization::MKPadding)
{ {
// pad M, but not N // pad M, but not N
return transform_tensor_descriptor( return transform_tensor_descriptor(
...@@ -397,8 +397,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -397,8 +397,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpecialization == GemmSpecialization_t::NKPadding) GemmSpec == GemmSpecialization::NKPadding)
{ {
// pad N, but not M // pad N, but not M
return transform_tensor_descriptor( return transform_tensor_descriptor(
...@@ -422,10 +422,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -422,10 +422,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding || GemmSpec == GemmSpecialization::MNPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding || GemmSpec == GemmSpecialization::MKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad M // pad M
return transform_tensor_descriptor(d_grid_desc_mraw, return transform_tensor_descriptor(d_grid_desc_mraw,
...@@ -544,8 +544,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -544,8 +544,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D0ReduceOperation,
D1ReduceOperation, D1ReduceOperation,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum_t::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -129,7 +129,7 @@ struct DeviceBatchedGemmXdl ...@@ -129,7 +129,7 @@ struct DeviceBatchedGemmXdl
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, M)); return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
} }
}(); }();
...@@ -158,7 +158,7 @@ struct DeviceBatchedGemmXdl ...@@ -158,7 +158,7 @@ struct DeviceBatchedGemmXdl
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, K)); return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
} }
}(); }();
...@@ -183,7 +183,7 @@ struct DeviceBatchedGemmXdl ...@@ -183,7 +183,7 @@ struct DeviceBatchedGemmXdl
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, M)); return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
} }
}(); }();
...@@ -277,7 +277,7 @@ struct DeviceBatchedGemmXdl ...@@ -277,7 +277,7 @@ struct DeviceBatchedGemmXdl
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -52,10 +52,13 @@ template <typename InDataType, ...@@ -52,10 +52,13 @@ template <typename InDataType,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> : public DeviceConvBwdWeight<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{ {
using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; using DeviceOp =
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using ADataType = OutDataType; using ADataType = OutDataType;
using BDataType = InDataType; using BDataType = InDataType;
...@@ -68,8 +71,6 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -68,8 +71,6 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
// TODO make A/B datatype different // TODO make A/B datatype different
using ABDataType = InDataType; using ABDataType = InDataType;
static constexpr index_t NDimSpatial = 2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -209,7 +210,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -209,7 +210,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -250,7 +251,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -250,7 +251,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -691,7 +692,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -691,7 +692,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -25,7 +25,7 @@ template <typename InDataType, ...@@ -25,7 +25,7 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionBackwardDataSpecialization_t ConvBackwardDataSpecialization, ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -131,7 +131,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -131,7 +131,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
...@@ -368,7 +368,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -368,7 +368,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType, // TODO: distinguish A/B datatype ABDataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -671,7 +671,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -671,7 +671,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 pad = 0 conv // check if it's 1x1, stride=1 pad = 0 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
...@@ -27,7 +27,7 @@ template < ...@@ -27,7 +27,7 @@ template <
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -125,7 +125,7 @@ struct ...@@ -125,7 +125,7 @@ struct
const auto GemmMPad = GemmM - GemmMRaw; const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0 { // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -179,7 +179,7 @@ struct ...@@ -179,7 +179,7 @@ struct
resi_grid_desc_gemmm_gemmn); resi_grid_desc_gemmm_gemmn);
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0 { // 1x1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -249,7 +249,7 @@ struct ...@@ -249,7 +249,7 @@ struct
bias_grid_desc_gemmm_gemmn, bias_grid_desc_gemmm_gemmn,
resi_grid_desc_gemmm_gemmn); resi_grid_desc_gemmm_gemmn);
} }
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC) else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value { // C = odd value
const index_t GemmKRaw = Y * X * C; const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
...@@ -466,7 +466,7 @@ struct ...@@ -466,7 +466,7 @@ struct
ABDataType, // TODO: distinguish A/B datatype ABDataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -811,7 +811,7 @@ struct ...@@ -811,7 +811,7 @@ struct
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
...@@ -823,7 +823,7 @@ struct ...@@ -823,7 +823,7 @@ struct
} }
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// check if it's 1x1 conv // check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
...@@ -27,8 +27,8 @@ template < ...@@ -27,8 +27,8 @@ template <
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
InMemoryDataOperationEnum_t OutGlobalMemoryDataOperation, InMemoryDataOperationEnum OutGlobalMemoryDataOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -124,7 +124,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -124,7 +124,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
const auto GemmMPad = GemmM - GemmMRaw; const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0 { // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -174,7 +174,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -174,7 +174,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
bias_grid_desc_gemmm_gemmn); bias_grid_desc_gemmm_gemmn);
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0 { // 1x1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -240,7 +240,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -240,7 +240,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
out_gemmm_gemmn_grid_desc, out_gemmm_gemmn_grid_desc,
bias_grid_desc_gemmm_gemmn); bias_grid_desc_gemmm_gemmn);
} }
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC) else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value { // C = odd value
const index_t GemmKRaw = Y * X * C; const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
...@@ -763,7 +763,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -763,7 +763,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
...@@ -775,7 +775,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -775,7 +775,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
} }
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// check if it's 1x1 conv // check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
...@@ -26,7 +26,7 @@ template < ...@@ -26,7 +26,7 @@ template <
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -120,7 +120,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -120,7 +120,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
const auto GemmMPad = GemmM - GemmMRaw; const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0 { // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -165,7 +165,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -165,7 +165,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
out_gemmm_gemmn_grid_desc); out_gemmm_gemmn_grid_desc);
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0 { // 1x1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -226,7 +226,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -226,7 +226,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
wei_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc); out_gemmm_gemmn_grid_desc);
} }
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC) else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value { // C = odd value
const index_t GemmKRaw = Y * X * C; const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
...@@ -424,7 +424,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -424,7 +424,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
AccDataType, AccDataType,
CDataType, // TODO: Add ShuffleType for DeviceConv2d CDataType, // TODO: Add ShuffleType for DeviceConv2d
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -733,7 +733,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -733,7 +733,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
...@@ -745,7 +745,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -745,7 +745,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
} }
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// check if it's 1x1 conv // check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
...@@ -25,7 +25,7 @@ template <typename InDataType, ...@@ -25,7 +25,7 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -119,7 +119,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -119,7 +119,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t GemmK0 = GemmK / GemmK1Number; const index_t GemmK0 = GemmK / GemmK1Number;
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// A: input tensor // A: input tensor
const auto in_gemmmraw_gemmk_grid_desc = const auto in_gemmmraw_gemmk_grid_desc =
...@@ -159,7 +159,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -159,7 +159,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
out_gemmm_gemmn_grid_desc); out_gemmm_gemmn_grid_desc);
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// A: input tensor // A: input tensor
const auto in_n_hi_wi_c_grid_desc = const auto in_n_hi_wi_c_grid_desc =
...@@ -316,7 +316,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -316,7 +316,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType, // TODO: distinguish A/B datatype ABDataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -565,7 +565,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -565,7 +565,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
...@@ -577,7 +577,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -577,7 +577,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} }
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// check if it's 1x1 conv // check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include "convolution_utility.hpp" #include "conv_fwd_util.hpp"
#include "device.hpp" #include "device.hpp"
#include "device_conv_fwd.hpp" #include "device_conv_fwd.hpp"
#include "common_header.hpp" #include "common_header.hpp"
...@@ -53,36 +53,30 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W ...@@ -53,36 +53,30 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) OutElementwiseOperation out_element_op)
: N_{N}, : params_{3,
K_{K}, N,
C_{C}, K,
in_spatial_lengths_{input_spatial_lengths}, C,
filter_spatial_lengths_{filter_spatial_lengths}, filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
out_spatial_lengths_{output_spatial_lengths}, out_spatial_lengths_{output_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
in_right_pads_{input_right_pads},
p_in_{p_in}, p_in_{p_in},
p_wei_{p_wei}, p_wei_{p_wei},
p_out_{p_out}, p_out_{p_out},
in_element_op_{in_element_op}, in_element_op_{in_element_op},
wei_element_op_{wei_element_op}, wei_element_op_{wei_element_op},
out_element_op_{out_element_op} out_element_op_{out_element_op}
{ {
} }
// private: // private:
index_t N_; utils::conv::ConvParams params_;
index_t K_;
index_t C_;
std::vector<index_t> in_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> out_spatial_lengths_; std::vector<index_t> out_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> conv_filter_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
const InDataType* p_in_; const InDataType* p_in_;
const WeiDataType* p_wei_; const WeiDataType* p_wei_;
...@@ -157,13 +151,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W ...@@ -157,13 +151,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
std::vector<index_t> out_spatial_lengths = std::vector<index_t> out_spatial_lengths = arg.params_.GetOutputSpatialLengths();
ConvolutionUtility::ComputeOutputSpatialLengths(arg.in_spatial_lengths_,
arg.filter_spatial_lengths_,
arg.conv_filter_strides_,
arg.conv_filter_dilations_,
arg.in_left_pads_,
arg.in_right_pads_);
bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] && bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] &&
out_spatial_lengths[1] == arg.out_spatial_lengths_[1] && out_spatial_lengths[1] == arg.out_spatial_lengths_[1] &&
......
...@@ -83,7 +83,7 @@ template <typename InDataType, ...@@ -83,7 +83,7 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -207,7 +207,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -207,7 +207,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
const index_t Ho = output_spatial_lengths[1]; const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2]; const index_t Wo = output_spatial_lengths[2];
static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Default, static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default,
"Wrong! This specialization not implemented!"); "Wrong! This specialization not implemented!");
const auto in_desc_n_di_hi_wi_c = const auto in_desc_n_di_hi_wi_c =
...@@ -287,7 +287,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -287,7 +287,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
InDataType, InDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -11,7 +11,7 @@ namespace device { ...@@ -11,7 +11,7 @@ namespace device {
template <typename InElementwiseOperation, template <typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
struct DeviceConvWrw : public BaseOperator struct DeviceConvBwdWeight : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, MakeArgumentPointer(const void* p_in,
...@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator ...@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator
template <typename InElementwiseOperation, template <typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
using DeviceConvWrwPtr = std::unique_ptr< using DeviceConvBwdWeightPtr = std::unique_ptr<
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>; DeviceConvBwdWeight<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -25,7 +25,7 @@ template <typename InDataType, ...@@ -25,7 +25,7 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionBackwardDataSpecialization_t ConvBackwardDataSpecialization, ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
...@@ -116,7 +116,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -116,7 +116,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
...@@ -336,7 +336,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -336,7 +336,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
...@@ -618,7 +618,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -618,7 +618,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
...@@ -917,21 +917,21 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -917,21 +917,21 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
} // function end } // function end
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}); 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
} }
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}); 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
} }
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
...@@ -959,7 +959,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -959,7 +959,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
ABDataType, // TODO: distinguish A/B datatype ABDataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -1385,7 +1385,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1385,7 +1385,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 pad = 0 conv // check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NumDimSpatial; i++) for(int i = 0; i < NumDimSpatial; i++)
...@@ -1527,7 +1527,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1527,7 +1527,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
<< K0PerBlock << K0PerBlock
<< ">"; << ">";
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0){ ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){
str<< " Filter1x1Stride1Pad0"; str<< " Filter1x1Stride1Pad0";
} }
......
...@@ -44,7 +44,7 @@ template <typename InDataType, ...@@ -44,7 +44,7 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
...@@ -142,7 +142,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -142,7 +142,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t ConvStrideW = conv_filter_strides[0]; const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
const auto in_gemmmraw_gemmk_grid_desc = const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
...@@ -156,7 +156,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -156,7 +156,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
const auto in_n_wi_c_grid_desc = const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
...@@ -262,7 +262,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -262,7 +262,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t ConvStrideW = conv_filter_strides[1]; const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
const auto in_gemmmraw_gemmk_grid_desc = const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
...@@ -276,7 +276,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -276,7 +276,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
const auto in_n_hi_wi_c_grid_desc = const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
...@@ -395,7 +395,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -395,7 +395,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t ConvStrideW = conv_filter_strides[2]; const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
const auto in_gemmmraw_gemmk_grid_desc = const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
...@@ -409,7 +409,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -409,7 +409,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
const auto in_n_di_hi_wi_c_grid_desc = const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
...@@ -613,7 +613,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -613,7 +613,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType, // TODO: distinguish A/B datatype ABDataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -878,7 +878,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -878,7 +878,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} }
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
for(ck::index_t i = 0; i < NumDimSpatial; ++i) for(ck::index_t i = 0; i < NumDimSpatial; ++i)
...@@ -891,7 +891,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -891,7 +891,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} }
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// check if it's 1x1 conv // check if it's 1x1 conv
for(ck::index_t i = 0; i < NumDimSpatial; ++i) for(ck::index_t i = 0; i < NumDimSpatial; ++i)
......
...@@ -29,7 +29,7 @@ template <typename ALayout, ...@@ -29,7 +29,7 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D0ReduceOperation,
typename D1ReduceOperation, typename D1ReduceOperation,
GemmSpecialization_t GemmSpecialization, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -95,8 +95,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -95,8 +95,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding || if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad both M and K // pad both M and K
assert(K % AK1 == 0); assert(K % AK1 == 0);
...@@ -119,8 +119,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -119,8 +119,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad M, but not K // pad M, but not K
assert(KRaw % AK1 == 0); assert(KRaw % AK1 == 0);
...@@ -136,8 +136,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -136,8 +136,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpecialization == GemmSpecialization_t::NKPadding) GemmSpec == GemmSpecialization::NKPadding)
{ {
// pad K, but not M // pad K, but not M
assert(K % AK1 == 0); assert(K % AK1 == 0);
...@@ -198,8 +198,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -198,8 +198,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto NPad = N - NRaw; const auto NPad = N - NRaw;
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding || if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad both N and K // pad both N and K
assert(K % BK1 == 0); assert(K % BK1 == 0);
...@@ -222,8 +222,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -222,8 +222,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad N, but not K // pad N, but not K
assert(KRaw % BK1 == 0); assert(KRaw % BK1 == 0);
...@@ -239,8 +239,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -239,8 +239,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding) GemmSpec == GemmSpecialization::MKPadding)
{ {
// pad K, but not N // pad K, but not N
assert(K % BK1 == 0); assert(K % BK1 == 0);
...@@ -301,8 +301,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -301,8 +301,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
const auto NPad = N - NRaw; const auto NPad = N - NRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding || if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad M and N // pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw, return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
...@@ -311,8 +311,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -311,8 +311,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding) GemmSpec == GemmSpecialization::MKPadding)
{ {
// pad M, but not N // pad M, but not N
return transform_tensor_descriptor( return transform_tensor_descriptor(
...@@ -321,8 +321,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -321,8 +321,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpecialization == GemmSpecialization_t::NKPadding) GemmSpec == GemmSpecialization::NKPadding)
{ {
// pad N, but not M // pad N, but not M
return transform_tensor_descriptor( return transform_tensor_descriptor(
...@@ -346,10 +346,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -346,10 +346,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding || GemmSpec == GemmSpecialization::MNPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding || GemmSpec == GemmSpecialization::MKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad M // pad M
return transform_tensor_descriptor(d_grid_desc_mraw, return transform_tensor_descriptor(d_grid_desc_mraw,
...@@ -382,8 +382,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -382,8 +382,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D0ReduceOperation,
D1ReduceOperation, D1ReduceOperation,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum_t::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -27,7 +27,7 @@ template <typename ADataType, ...@@ -27,7 +27,7 @@ template <typename ADataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization_t GemmSpecialization, GemmSpecialization GemmSpec,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -80,7 +80,7 @@ struct DeviceGemmXdl ...@@ -80,7 +80,7 @@ struct DeviceGemmXdl
} }
}(); }();
if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding) if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{ {
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
...@@ -119,7 +119,7 @@ struct DeviceGemmXdl ...@@ -119,7 +119,7 @@ struct DeviceGemmXdl
} }
}(); }();
if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding) if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{ {
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
...@@ -154,7 +154,7 @@ struct DeviceGemmXdl ...@@ -154,7 +154,7 @@ struct DeviceGemmXdl
} }
}(); }();
if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding) if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{ {
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
...@@ -186,7 +186,7 @@ struct DeviceGemmXdl ...@@ -186,7 +186,7 @@ struct DeviceGemmXdl
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -138,7 +138,7 @@ struct DeviceGemmXdl_C_Shuffle ...@@ -138,7 +138,7 @@ struct DeviceGemmXdl_C_Shuffle
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -139,7 +139,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d ...@@ -139,7 +139,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -147,7 +147,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation ...@@ -147,7 +147,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -169,7 +169,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add ...@@ -169,7 +169,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment