Commit 07a673c6 authored by carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents c0f698d5 ac0d8066
#ifndef CONVOLUTION_UTILITY_HPP
#define CONVOLUTION_UTILITY_HPP
#include <cassert>
#include <cstddef>
#include <vector>
namespace ck {
namespace tensor_operation {
struct ConvolutionUtility
{
static std::vector<ck::index_t>
ComputeOutputSpatialLengths(std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_strides,
std::vector<ck::index_t> conv_dilations,
std::vector<ck::index_t> in_left_pads,
std::vector<ck::index_t> in_right_pads)
{
if(input_spatial_lengths.size() == 2)
{
assert(filter_spatial_lengths.size() == 2);
assert(conv_strides.size() == 2);
assert(conv_dilations.size() == 2);
assert(in_left_pads.size() == 2);
assert(in_right_pads.size() == 2);
const index_t YEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t XEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho =
(Hi + in_left_pads[0] + in_right_pads[0] - YEff) / conv_strides[0] + 1;
const index_t Wo =
(Wi + in_left_pads[1] + in_right_pads[1] - XEff) / conv_strides[1] + 1;
return {Ho, Wo};
}
else if(input_spatial_lengths.size() == 3)
{
assert(filter_spatial_lengths.size() == 3);
assert(conv_strides.size() == 3);
assert(conv_dilations.size() == 3);
assert(in_left_pads.size() == 3);
assert(in_right_pads.size() == 3);
const index_t ZEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t YEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t XEff = (filter_spatial_lengths[2] - 1) * conv_dilations[2] + 1;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do =
(Di + in_left_pads[0] + in_right_pads[0] - ZEff) / conv_strides[0] + 1;
const index_t Ho =
(Hi + in_left_pads[1] + in_right_pads[1] - YEff) / conv_strides[1] + 1;
const index_t Wo =
(Wi + in_left_pads[2] + in_right_pads[2] - XEff) / conv_strides[2] + 1;
return {Do, Ho, Wo};
}
else
{
return {};
}
}
};
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -105,7 +105,7 @@ template <typename ALayout, ...@@ -105,7 +105,7 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D0ReduceOperation,
typename D1ReduceOperation, typename D1ReduceOperation,
GemmSpecialization_t GemmSpecialization, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -171,8 +171,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -171,8 +171,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding || if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad both M and K // pad both M and K
assert(K % AK1 == 0); assert(K % AK1 == 0);
...@@ -195,8 +195,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -195,8 +195,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad M, but not K // pad M, but not K
assert(KRaw % AK1 == 0); assert(KRaw % AK1 == 0);
...@@ -212,8 +212,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -212,8 +212,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpecialization == GemmSpecialization_t::NKPadding) GemmSpec == GemmSpecialization::NKPadding)
{ {
// pad K, but not M // pad K, but not M
assert(K % AK1 == 0); assert(K % AK1 == 0);
...@@ -274,8 +274,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -274,8 +274,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto NPad = N - NRaw; const auto NPad = N - NRaw;
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding || if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad both N and K // pad both N and K
assert(K % BK1 == 0); assert(K % BK1 == 0);
...@@ -298,8 +298,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -298,8 +298,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad N, but not K // pad N, but not K
assert(KRaw % BK1 == 0); assert(KRaw % BK1 == 0);
...@@ -315,8 +315,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -315,8 +315,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding) GemmSpec == GemmSpecialization::MKPadding)
{ {
// pad K, but not N // pad K, but not N
assert(K % BK1 == 0); assert(K % BK1 == 0);
...@@ -377,8 +377,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -377,8 +377,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
const auto NPad = N - NRaw; const auto NPad = N - NRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding || if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad M and N // pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw, return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
...@@ -387,8 +387,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -387,8 +387,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding) GemmSpec == GemmSpecialization::MKPadding)
{ {
// pad M, but not N // pad M, but not N
return transform_tensor_descriptor( return transform_tensor_descriptor(
...@@ -397,8 +397,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -397,8 +397,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpecialization == GemmSpecialization_t::NKPadding) GemmSpec == GemmSpecialization::NKPadding)
{ {
// pad N, but not M // pad N, but not M
return transform_tensor_descriptor( return transform_tensor_descriptor(
...@@ -422,10 +422,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -422,10 +422,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw; const auto MPad = M - MRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding || GemmSpec == GemmSpecialization::MNPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding || GemmSpec == GemmSpecialization::MKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad M // pad M
return transform_tensor_descriptor(d_grid_desc_mraw, return transform_tensor_descriptor(d_grid_desc_mraw,
...@@ -544,8 +544,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -544,8 +544,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D0ReduceOperation,
D1ReduceOperation, D1ReduceOperation,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum_t::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -129,7 +129,7 @@ struct DeviceBatchedGemmXdl ...@@ -129,7 +129,7 @@ struct DeviceBatchedGemmXdl
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, M)); return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
} }
}(); }();
...@@ -158,7 +158,7 @@ struct DeviceBatchedGemmXdl ...@@ -158,7 +158,7 @@ struct DeviceBatchedGemmXdl
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, K)); return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
} }
}(); }();
...@@ -183,7 +183,7 @@ struct DeviceBatchedGemmXdl ...@@ -183,7 +183,7 @@ struct DeviceBatchedGemmXdl
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, M)); return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
} }
}(); }();
...@@ -277,7 +277,7 @@ struct DeviceBatchedGemmXdl ...@@ -277,7 +277,7 @@ struct DeviceBatchedGemmXdl
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
......
...@@ -52,10 +52,13 @@ template <typename InDataType, ...@@ -52,10 +52,13 @@ template <typename InDataType,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> : public DeviceConvBwdWeight<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{ {
using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; using DeviceOp =
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using ADataType = OutDataType; using ADataType = OutDataType;
using BDataType = InDataType; using BDataType = InDataType;
...@@ -68,8 +71,6 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -68,8 +71,6 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
// TODO make A/B datatype different // TODO make A/B datatype different
using ABDataType = InDataType; using ABDataType = InDataType;
static constexpr index_t NDimSpatial = 2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -209,7 +210,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -209,7 +210,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -250,7 +251,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -250,7 +251,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -691,7 +692,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -691,7 +692,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -25,7 +25,7 @@ template <typename InDataType, ...@@ -25,7 +25,7 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionBackwardDataSpecialization_t ConvBackwardDataSpecialization, ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -131,7 +131,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -131,7 +131,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
...@@ -368,7 +368,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -368,7 +368,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType, // TODO: distinguish A/B datatype ABDataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -671,7 +671,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -671,7 +671,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 pad = 0 conv // check if it's 1x1, stride=1 pad = 0 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
...@@ -27,7 +27,7 @@ template < ...@@ -27,7 +27,7 @@ template <
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -125,7 +125,7 @@ struct ...@@ -125,7 +125,7 @@ struct
const auto GemmMPad = GemmM - GemmMRaw; const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0 { // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -179,7 +179,7 @@ struct ...@@ -179,7 +179,7 @@ struct
resi_grid_desc_gemmm_gemmn); resi_grid_desc_gemmm_gemmn);
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0 { // 1x1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -249,7 +249,7 @@ struct ...@@ -249,7 +249,7 @@ struct
bias_grid_desc_gemmm_gemmn, bias_grid_desc_gemmm_gemmn,
resi_grid_desc_gemmm_gemmn); resi_grid_desc_gemmm_gemmn);
} }
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC) else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value { // C = odd value
const index_t GemmKRaw = Y * X * C; const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
...@@ -466,7 +466,7 @@ struct ...@@ -466,7 +466,7 @@ struct
ABDataType, // TODO: distinguish A/B datatype ABDataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -811,7 +811,7 @@ struct ...@@ -811,7 +811,7 @@ struct
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
...@@ -823,7 +823,7 @@ struct ...@@ -823,7 +823,7 @@ struct
} }
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// check if it's 1x1 conv // check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
...@@ -27,8 +27,8 @@ template < ...@@ -27,8 +27,8 @@ template <
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
InMemoryDataOperationEnum_t OutGlobalMemoryDataOperation, InMemoryDataOperationEnum OutGlobalMemoryDataOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -124,7 +124,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -124,7 +124,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
const auto GemmMPad = GemmM - GemmMRaw; const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0 { // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -174,7 +174,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -174,7 +174,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
bias_grid_desc_gemmm_gemmn); bias_grid_desc_gemmm_gemmn);
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0 { // 1x1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -240,7 +240,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -240,7 +240,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
out_gemmm_gemmn_grid_desc, out_gemmm_gemmn_grid_desc,
bias_grid_desc_gemmm_gemmn); bias_grid_desc_gemmm_gemmn);
} }
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC) else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value { // C = odd value
const index_t GemmKRaw = Y * X * C; const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
...@@ -763,7 +763,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -763,7 +763,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
...@@ -775,7 +775,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -775,7 +775,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
} }
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// check if it's 1x1 conv // check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
...@@ -26,7 +26,7 @@ template < ...@@ -26,7 +26,7 @@ template <
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -120,7 +120,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -120,7 +120,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
const auto GemmMPad = GemmM - GemmMRaw; const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0 { // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -165,7 +165,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -165,7 +165,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
out_gemmm_gemmn_grid_desc); out_gemmm_gemmn_grid_desc);
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0 { // 1x1, pad=0
const index_t GemmK = Y * X * C; const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0); assert(GemmK % GemmK1Number == 0);
...@@ -226,7 +226,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -226,7 +226,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
wei_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc); out_gemmm_gemmn_grid_desc);
} }
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC) else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value { // C = odd value
const index_t GemmKRaw = Y * X * C; const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
...@@ -424,7 +424,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -424,7 +424,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
AccDataType, AccDataType,
CDataType, // TODO: Add ShuffleType for DeviceConv2d CDataType, // TODO: Add ShuffleType for DeviceConv2d
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
...@@ -733,7 +733,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -733,7 +733,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
...@@ -745,7 +745,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -745,7 +745,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
} }
} }
else if constexpr(ConvForwardSpecialization == else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// check if it's 1x1 conv // check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
...@@ -83,7 +83,7 @@ template <typename InDataType, ...@@ -83,7 +83,7 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -207,7 +207,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -207,7 +207,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
const index_t Ho = output_spatial_lengths[1]; const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2]; const index_t Wo = output_spatial_lengths[2];
static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Default, static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default,
"Wrong! This specialization not implemented!"); "Wrong! This specialization not implemented!");
const auto in_desc_n_di_hi_wi_c = const auto in_desc_n_di_hi_wi_c =
...@@ -287,7 +287,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -287,7 +287,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
InDataType, InDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment