"cacheflow/model_executor/models/opt.py" did not exist on "39161c98a054756cb39edb9f634c6a466410c92d"
Commit 0be1cf14 authored by Chao Liu's avatar Chao Liu
Browse files

update conv bwd weight

parent b054669b
...@@ -15,6 +15,19 @@ enum struct ConvolutionBackwardWeightSpecialization ...@@ -15,6 +15,19 @@ enum struct ConvolutionBackwardWeightSpecialization
OddC, OddC,
}; };
inline std::string
getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization& s)
{
switch(s)
{
case ConvolutionBackwardWeightSpecialization::Default: return "Default";
case ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0:
return "Filter1x1Stride1Pad0";
case ConvolutionBackwardWeightSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
case ConvolutionBackwardWeightSpecialization::OddC: return "OddC";
default: return "Unrecognized specialization!";
}
}
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION #pragma once
#define CONVOLUTION_FORWARD_SPECIALIZATION
#include <string> #include <string>
...@@ -33,4 +32,3 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp ...@@ -33,4 +32,3 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" #include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
...@@ -57,7 +57,14 @@ template <typename InDataType, ...@@ -57,7 +57,14 @@ template <typename InDataType,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvBwdWeight<InElementwiseOperation, : public DeviceConvBwdWeight<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation>
{ {
......
...@@ -55,7 +55,14 @@ template <typename InDataType, ...@@ -55,7 +55,14 @@ template <typename InDataType,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvBwdData<InElementwiseOperation, : public DeviceConvBwdData<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation>
{ {
......
...@@ -4,16 +4,21 @@ ...@@ -4,16 +4,21 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename InElementwiseOperation, template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
struct DeviceConvBwdData : public BaseOperator struct DeviceConvBwdData : public BaseOperator
...@@ -39,12 +44,6 @@ struct DeviceConvBwdData : public BaseOperator ...@@ -39,12 +44,6 @@ struct DeviceConvBwdData : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvBwdDataPtr = std::unique_ptr<
DeviceConvBwdData<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
...@@ -12,7 +11,14 @@ namespace ck { ...@@ -12,7 +11,14 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename InElementwiseOperation, template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
struct DeviceConvBwdWeight : public BaseOperator struct DeviceConvBwdWeight : public BaseOperator
...@@ -39,12 +45,6 @@ struct DeviceConvBwdWeight : public BaseOperator ...@@ -39,12 +45,6 @@ struct DeviceConvBwdWeight : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvBwdWeightPtr = std::unique_ptr<
DeviceConvBwdWeight<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#pragma once #pragma once
#include <iostream>
#include <vector> #include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
...@@ -12,7 +11,7 @@ namespace ck { ...@@ -12,7 +11,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <ck::index_t NumDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
typename OutLayout, typename OutLayout,
......
...@@ -21,7 +21,8 @@ namespace tensor_operation { ...@@ -21,7 +21,8 @@ namespace tensor_operation {
namespace device { namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <typename InDataType, template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename AccDataType,
...@@ -29,7 +30,6 @@ template <typename InDataType, ...@@ -29,7 +30,6 @@ template <typename InDataType,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization, ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t NumDimSpatial,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -55,12 +55,29 @@ template <typename InDataType, ...@@ -55,12 +55,29 @@ template <typename InDataType,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
: public DeviceConvBwdData<InElementwiseOperation, : public DeviceConvBwdData<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation>
{ {
using DeviceOp = DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Xdl;
using ADataType = OutDataType; using ADataType = OutDataType;
using BDataType = WeiDataType; using BDataType = WeiDataType;
...@@ -950,7 +967,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -950,7 +967,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
{0, 0, 0}); {0, 0, 0});
} }
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>()); using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
...@@ -1037,7 +1054,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1037,7 +1054,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
CreateABCDesc<NumDimSpatial>(); CreateABCDesc<NDimSpatial>();
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
...@@ -1060,7 +1077,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1060,7 +1077,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
} }
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NumDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_, Conv_N_,
Conv_K_, Conv_K_,
Conv_C_, Conv_C_,
...@@ -1118,7 +1135,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1118,7 +1135,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
} }
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NumDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_, Conv_N_,
Conv_K_, Conv_K_,
Conv_C_, Conv_C_,
...@@ -1186,8 +1203,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1186,8 +1203,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
} }
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
NumDimSpatial>(Conv_N_, Conv_N_,
Conv_K_, Conv_K_,
Conv_C_, Conv_C_,
input_spatial_lengths_, input_spatial_lengths_,
...@@ -1398,7 +1415,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1398,7 +1415,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 pad = 0 conv // check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NumDimSpatial; i++) for(int i = 0; i < NDimSpatial; i++)
{ {
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
...@@ -1528,7 +1545,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1528,7 +1545,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K" str << "DeviceConvNdBwdDataNwcKxcNwk_Xdl"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -39,7 +39,7 @@ namespace device { ...@@ -39,7 +39,7 @@ namespace device {
// 3D: // 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] // out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
// //
template <ck::index_t NumDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
...@@ -74,16 +74,16 @@ template <ck::index_t NumDimSpatial, ...@@ -74,16 +74,16 @@ template <ck::index_t NumDimSpatial,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceConvNdFwdNwcKxcNwk_Xdl struct DeviceConvNdFwdNwcKxcNwk_Xdl
: public DeviceConvFwd<NumDimSpatial, : public DeviceConvFwd<NDimSpatial,
ck::tuple_element_t<NumDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC, ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>, ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NumDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC, ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>, ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NumDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK, ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK, ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>, ck::tensor_layout::convolution::NDHWK>>,
...@@ -94,27 +94,6 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl ...@@ -94,27 +94,6 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation>
{ {
using Base =
DeviceConvFwd<NumDimSpatial,
ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>;
using DeviceOp = DeviceConvNdFwdNwcKxcNwk_Xdl; using DeviceOp = DeviceConvNdFwdNwcKxcNwk_Xdl;
using ADataType = InDataType; using ADataType = InDataType;
...@@ -124,8 +103,6 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl ...@@ -124,8 +103,6 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
// TODO make A/B datatype different // TODO make A/B datatype different
using ABDataType = InDataType; using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -599,7 +576,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl ...@@ -599,7 +576,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
// C = A^T*B // C = A^T*B
// A: // A:
const auto in_gemmk0_gemmm_gemmk1_grid_desc = const auto in_gemmk0_gemmm_gemmk1_grid_desc =
GetInputTensorDescriptor<NumDimSpatial>(N, GetInputTensorDescriptor<NDimSpatial>(N,
C, C,
GemmMRaw, GemmMRaw,
GemmK, GemmK,
...@@ -642,7 +619,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl ...@@ -642,7 +619,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
} }
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>()); using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
...@@ -934,7 +911,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl ...@@ -934,7 +911,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
for(ck::index_t i = 0; i < NumDimSpatial; ++i) for(ck::index_t i = 0; i < NDimSpatial; ++i)
{ {
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
...@@ -947,7 +924,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl ...@@ -947,7 +924,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
ConvolutionForwardSpecialization::Filter1x1Pad0) ConvolutionForwardSpecialization::Filter1x1Pad0)
{ {
// check if it's 1x1 conv // check if it's 1x1 conv
for(ck::index_t i = 0; i < NumDimSpatial; ++i) for(ck::index_t i = 0; i < NDimSpatial; ++i)
{ {
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 && if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 &&
arg.input_right_pads_[i] == 0)) arg.input_right_pads_[i] == 0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment