Commit 07a673c6 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents c0f698d5 ac0d8066
#ifndef CONVOLUTION_UTILITY_HPP
#define CONVOLUTION_UTILITY_HPP
#include <vector>
namespace ck {
namespace tensor_operation {
struct ConvolutionUtility
{
static std::vector<ck::index_t>
ComputeOutputSpatialLengths(std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_strides,
std::vector<ck::index_t> conv_dilations,
std::vector<ck::index_t> in_left_pads,
std::vector<ck::index_t> in_right_pads)
{
if(input_spatial_lengths.size() == 2)
{
assert(filter_spatial_lengths.size() == 2);
assert(conv_strides.size() == 2);
assert(conv_dilations.size() == 2);
assert(in_left_pads.size() == 2);
assert(in_right_pads.size() == 2);
const index_t YEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t XEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho =
(Hi + in_left_pads[0] + in_right_pads[0] - YEff) / conv_strides[0] + 1;
const index_t Wo =
(Wi + in_left_pads[1] + in_right_pads[1] - XEff) / conv_strides[1] + 1;
return {Ho, Wo};
}
else if(input_spatial_lengths.size() == 3)
{
assert(filter_spatial_lengths.size() == 3);
assert(conv_strides.size() == 3);
assert(conv_dilations.size() == 3);
assert(in_left_pads.size() == 3);
assert(in_right_pads.size() == 3);
const index_t ZEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1;
const index_t YEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1;
const index_t XEff = (filter_spatial_lengths[2] - 1) * conv_dilations[2] + 1;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do =
(Di + in_left_pads[0] + in_right_pads[0] - ZEff) / conv_strides[0] + 1;
const index_t Ho =
(Hi + in_left_pads[1] + in_right_pads[1] - YEff) / conv_strides[1] + 1;
const index_t Wo =
(Wi + in_left_pads[2] + in_right_pads[2] - XEff) / conv_strides[2] + 1;
return {Do, Ho, Wo};
}
else
{
return {};
}
}
};
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -105,7 +105,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
GemmSpecialization_t GemmSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -171,8 +171,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding)
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
......@@ -195,8 +195,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding)
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
......@@ -212,8 +212,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
GemmSpecialization == GemmSpecialization_t::NKPadding)
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
......@@ -274,8 +274,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding)
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
......@@ -298,8 +298,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding)
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
......@@ -315,8 +315,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding)
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
......@@ -377,8 +377,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding)
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
......@@ -387,8 +387,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding)
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
......@@ -397,8 +397,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
GemmSpecialization == GemmSpecialization_t::NKPadding)
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
......@@ -422,10 +422,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
GemmSpecialization == GemmSpecialization_t::MNPadding ||
GemmSpecialization == GemmSpecialization_t::MKPadding ||
GemmSpecialization == GemmSpecialization_t::MNKPadding)
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(d_grid_desc_mraw,
......@@ -544,8 +544,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum_t::AtomicAdd,
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
......
......@@ -129,7 +129,7 @@ struct DeviceBatchedGemmXdl
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, M));
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
......@@ -158,7 +158,7 @@ struct DeviceBatchedGemmXdl
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, K));
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
......@@ -183,7 +183,7 @@ struct DeviceBatchedGemmXdl
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, M));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
......@@ -277,7 +277,7 @@ struct DeviceBatchedGemmXdl
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......
......@@ -52,10 +52,13 @@ template <typename InDataType,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvBwdWeight<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using DeviceOp =
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using ADataType = OutDataType;
using BDataType = InDataType;
......@@ -68,8 +71,6 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = 2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -209,7 +210,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......@@ -250,7 +251,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum_t::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......@@ -691,7 +692,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto str = std::stringstream();
// clang-format off
str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -25,7 +25,7 @@ template <typename InDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionBackwardDataSpecialization_t ConvBackwardDataSpecialization,
ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -131,7 +131,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
......@@ -368,7 +368,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......@@ -671,7 +671,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 pad = 0 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
......@@ -27,7 +27,7 @@ template <
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -125,7 +125,7 @@ struct
const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
......@@ -179,7 +179,7 @@ struct
resi_grid_desc_gemmm_gemmn);
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
......@@ -249,7 +249,7 @@ struct
bias_grid_desc_gemmm_gemmn,
resi_grid_desc_gemmm_gemmn);
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value
const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
......@@ -466,7 +466,7 @@ struct
ABDataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......@@ -811,7 +811,7 @@ struct
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......@@ -823,7 +823,7 @@ struct
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
......@@ -27,8 +27,8 @@ template <
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
InMemoryDataOperationEnum_t OutGlobalMemoryDataOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
InMemoryDataOperationEnum OutGlobalMemoryDataOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -124,7 +124,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
......@@ -174,7 +174,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
bias_grid_desc_gemmm_gemmn);
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
......@@ -240,7 +240,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
out_gemmm_gemmn_grid_desc,
bias_grid_desc_gemmm_gemmn);
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value
const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
......@@ -763,7 +763,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......@@ -775,7 +775,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
......@@ -26,7 +26,7 @@ template <
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -120,7 +120,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
const auto GemmMPad = GemmM - GemmMRaw;
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ // 1x1, stride=1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
......@@ -165,7 +165,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
out_gemmm_gemmn_grid_desc);
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization::Filter1x1Pad0)
{ // 1x1, pad=0
const index_t GemmK = Y * X * C;
assert(GemmK % GemmK1Number == 0);
......@@ -226,7 +226,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
{ // C = odd value
const index_t GemmKRaw = Y * X * C;
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
......@@ -424,7 +424,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
AccDataType,
CDataType, // TODO: Add ShuffleType for DeviceConv2d
CDataType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......@@ -733,7 +733,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......@@ -745,7 +745,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
......
......@@ -83,7 +83,7 @@ template <typename InDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -207,7 +207,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Default,
static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default,
"Wrong! This specialization not implemented!");
const auto in_desc_n_di_hi_wi_c =
......@@ -287,7 +287,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
InDataType,
AccDataType,
OutDataType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment