Unverified Commit f91579aa authored by Adam Osewski, committed by GitHub

Unified conv3D API + support for all data types. (#133)



* Convolution ND

* Code unification across dimensions for generating tensor descriptors.
* Example
* Instances

* Move convnd f32 instance file to comply with repo structure.

* Conv 1D tensor layouts.

* Formatting and use ReferenceConv

* Reference ConvFwd supporting 1D and 2D convolution.

* Debug printing TensorLayout name.

* Conv fwd 1D instance f32

* Refactor conv ND example.

Needed to support various conv dimensions.

* Rename conv nd example directory to prevent conflicts.

* Refactor some common utilities into a single file.

Plus some tests.

* Refactor GetHostTensorDescriptor + UT.

* Add 1D test case.

* Test reference convolution 1d/2d

* Remove some leftovers.

* Fix convolution example error for 1D

* Refactor test check errors utility function.

* Test Conv2D Fwd XDL

* More UT for 1D case.

* Parameterize input & weight initializers.

* Rename example to prevent conflicts.

* Split convnd instance into separate files for 1d/2d

* Address review comments.

* Fix data type for flops/gbytes calculations.

* Assign example number 11.

* 3D cases for convolution utility functions.

* 3D reference convolution.

* Add support for 3D convolution.

* Check for inputs bigger than 2GB.

* Formatting

* Support for bf16/f16/f32/i8 - conv instances + UT.

* Use check_err from test_util.hpp.

* Split convnd test into separate files for each dim.

* Fix data generation and use proper instances.

* Formatting

* Skip tensor initialization if not necessary.

* Fix CMakefiles.

* Remove redundant conv2d_fwd test.

* Lower problem size for conv3D UT.

* 3D case for convnd example.

* Remove leftovers after merge.

* Add Conv Specialization string to GetTypeString

* Skip instance causing numerical errors.

* Small fixes.

* Remove redundant includes.

* Fix namespace name error.

* Script for automatic testing and logging convolution fwd UTs

* Comment out numactl cmd.

* Refine weights initialization and relax rtol for fp16

* Fix weights initialization for int8.

* Add type_convert when storing output in ref conv 1D.

* Restore old conv2d_fwd_xdl operation.

* Silence conv debug print.

* format

* clean

* clean

* Fix merge.

* Fix namespace for check_err
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 22061366
@@ -84,6 +84,9 @@ DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 3: {
return std::make_unique<DeviceConvNDFwdInstance<3>>();
}
case 2: {
return std::make_unique<DeviceConvNDFwdInstance<2>>();
}
@@ -173,6 +176,9 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t
switch(num_dim_spatial)
{
case 3: {
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWK{});
}
case 2: {
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWK{});
}
@@ -192,6 +198,9 @@ HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_
switch(num_dim_spatial)
{
case 3: {
return ck::conv_util::GetHostTensorDescriptor(dims, tl::KZYXC{});
}
case 2: {
return ck::conv_util::GetHostTensorDescriptor(dims, tl::KYXC{});
}
@@ -211,6 +220,9 @@ HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>
switch(num_dim_spatial)
{
case 3: {
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{});
}
case 2: {
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{});
}
@@ -360,6 +372,11 @@ int main(int argc, char* argv[])
switch(num_dim_spatial)
{
case 3: {
auto ref_conv = ReferenceConvNDFwdInstance<3>();
verify_f(ref_conv);
break;
}
case 2: {
auto ref_conv = ReferenceConvNDFwdInstance<2>();
verify_f(ref_conv);
...
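The example extends the same rank switch in four places (the instance factory and the three descriptor helpers). A minimal self-contained sketch of that dispatch pattern, with the base class and instance template stubbed out as hypothetical stand-ins:

```cpp
#include <memory>
#include <stdexcept>

// Hypothetical stand-ins for the example's real types, for illustration only.
struct DeviceConvFwdBase
{
    virtual ~DeviceConvFwdBase() = default;
};
using DeviceConvFwdBasePtr = std::unique_ptr<DeviceConvFwdBase>;

template <int NDimSpatial>
struct DeviceConvNDFwdInstance : DeviceConvFwdBase
{
};

// Same shape as the example's GetConvInstance(): a runtime rank selects a
// compile-time instantiation of the templated device instance.
DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial)
{
    switch(num_dim_spatial)
    {
    case 3: return std::make_unique<DeviceConvNDFwdInstance<3>>();
    case 2: return std::make_unique<DeviceConvNDFwdInstance<2>>();
    case 1: return std::make_unique<DeviceConvNDFwdInstance<1>>();
    default: throw std::runtime_error("Unsupported number of spatial dimensions!");
    }
}

int main()
{
    auto conv = GetConvInstance(3); // picks DeviceConvNDFwdInstance<3>
    return conv ? 0 : 1;
}
```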
@@ -157,6 +157,12 @@
#define CK_WORKAROUND_SWDEV_325164 1
#endif
// workaround for verification failure ConvNd forward
// https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135
#ifndef CK_WORKAROUND_GITHUB_135
#define CK_WORKAROUND_GITHUB_135 1
#endif
namespace ck {
enum InMemoryDataOperationEnum_t
...
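Because the new macro is wrapped in `#ifndef`, the workaround can be switched off from the build line (e.g. `-DCK_WORKAROUND_GITHUB_135=0`) without touching the header. A minimal sketch of that guard pattern:

```cpp
#include <iostream>

// Default the workaround to 'on' only when the build didn't set it already;
// compiling with -DCK_WORKAROUND_GITHUB_135=0 opts out.
#ifndef CK_WORKAROUND_GITHUB_135
#define CK_WORKAROUND_GITHUB_135 1
#endif

int main()
{
#if CK_WORKAROUND_GITHUB_135
    std::cout << "workaround for issue #135 enabled\n";
#else
    std::cout << "workaround for issue #135 disabled\n";
#endif
}
```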
@@ -186,6 +186,28 @@ HostTensorDescriptor GetHostTensorDescriptor(const std::vector<std::size_t>& dim
return HostTensorDescriptor(
dims, std::vector<std::size_t>{C * dims[2] * dims[3], 1, dims[3] * C, C});
}
// 3D
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCZYX>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKDHW>::value)
{
return HostTensorDescriptor(dims,
std::vector<std::size_t>{C * dims[2] * dims[3] * dims[4],
dims[2] * dims[3] * dims[4],
dims[3] * dims[4],
dims[4],
1});
}
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KZYXC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWK>::value)
{
return HostTensorDescriptor(
dims,
std::vector<std::size_t>{
C * dims[2] * dims[3] * dims[4], 1, C * dims[3] * dims[4], C * dims[4], C});
}
std::stringstream err_msg;
err_msg << "Unsupported data layout provided: " << layout << "!";
...
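The new 3D branches follow the same convention as the 2D ones: `dims` always arrives in NCDHW order, and only the strides encode the physical layout. A small worked example of the stride arithmetic (plain vectors stand in for `HostTensorDescriptor`):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // dims in NCDHW order: N=2, C=16, D=4, H=8, W=8.
    const std::vector<std::size_t> dims{2, 16, 4, 8, 8};
    const std::size_t C = dims[1], D = dims[2], H = dims[3], W = dims[4];

    // NCDHW / KCZYX / NKDHW: W is innermost, classic row-major strides.
    const std::vector<std::size_t> ncdhw{C * D * H * W, D * H * W, H * W, W, 1};

    // NDHWC / KZYXC / NDHWK: C is innermost, so every spatial stride
    // picks up a factor of C and the channel stride collapses to 1.
    const std::vector<std::size_t> ndhwc{C * D * H * W, 1, C * H * W, C * W, C};

    for(std::size_t s : ncdhw) std::cout << s << ' '; // 4096 256 64 8 1
    std::cout << '\n';
    for(std::size_t s : ndhwc) std::cout << s << ' '; // 4096 1 1024 128 16
    std::cout << '\n';
}
```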
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION
#include <string>
namespace ck {
namespace tensor_operation {
namespace device {
@@ -13,6 +15,18 @@ enum ConvolutionForwardSpecialization_t
OddC,
};
inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization_t& s)
{
switch(s)
{
case Default: return "Default";
case Filter1x1Pad0: return "Filter1x1Pad0";
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case OddC: return "OddC";
default: return "Unrecognized specialization!";
}
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
...
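The helper exists so the `GetTypeString()` methods patched below can append the specialization name to the kernel's type string. A self-contained sketch of that usage, with the enum and helper mirrored locally and the surrounding name and block parameters purely illustrative:

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Local mirror of the enum and helper above, for illustration.
enum ConvolutionForwardSpecialization_t
{
    Default,
    Filter1x1Pad0,
    Filter1x1Stride1Pad0,
    OddC,
};

std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization_t& s)
{
    switch(s)
    {
    case Default: return "Default";
    case Filter1x1Pad0: return "Filter1x1Pad0";
    case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
    case OddC: return "OddC";
    default: return "Unrecognized specialization!";
    }
}

int main()
{
    // Mirrors how the patched GetTypeString() composes the kernel name.
    std::ostringstream str;
    str << "DeviceConvNDFwdXdl"
        << "<" << 256 << ", " << 128 << ", " << 4 << ", "
        << getConvFwdSpecializationStr(Filter1x1Pad0) << ">";
    std::cout << str.str() << '\n'; // DeviceConvNDFwdXdl<256, 128, 4, Filter1x1Pad0>
}
```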
@@ -875,7 +875,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvFwdSpecializationStr(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
@@ -466,7 +466,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
@@ -708,7 +707,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvFwdSpecializationStr(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
@@ -367,6 +367,155 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmmraw_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_right_pad_transform(gemm_m, gemm_m_pad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(gemm_k0),
make_right_pad_transform(gemm_m, gemm_m_pad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
else
{
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
make_pass_through_transform(gemm_m)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// in_gemmk0_gemmm_gemmk1_grid_desc
return transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(gemm_k0),
make_right_pad_transform(gemm_m, gemm_m_pad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
}
static index_t GetGemmMRaw(ck::index_t N,
const std::vector<ck::index_t>& output_spatial_lengths)
{
@@ -445,6 +594,13 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
@@ -593,6 +749,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float Run(const Argument& arg, int nrepeat = 1)
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
@@ -605,7 +762,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
@@ -704,6 +861,22 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg)
{
// Input tensors can't be bigger than 2GB each.
constexpr std::size_t GB2 = 2 * 1e9;
if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() > GB2)
{
return false;
}
if(arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() > GB2)
{
return false;
}
if(arg.c_grid_desc_m_n_.GetElementSpaceSize() > GB2)
{
return false;
}
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
@@ -851,7 +1024,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvFwdSpecializationStr(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
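All three branches of the new 3D `GetInputTensorDescriptor` build the same implicit-GEMM view of the input: the merge transforms collapse N, Do, Ho, Wo into GemmM and Z, Y, X, C into GemmK, which is then unmerged into GemmK0 × GemmK1. A sketch of just that arithmetic (names are illustrative, not the device API):

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

// GEMM problem sizes implied by the descriptor transforms above.
struct GemmShape
{
    int64_t m;  // N * Do * Ho * Wo
    int64_t k;  // Z * Y * X * C
    int64_t k0; // k / GemmK1Number
};

GemmShape Conv3dImplicitGemmShape(int64_t N, int64_t C,
                                  int64_t Do, int64_t Ho, int64_t Wo,
                                  int64_t Z, int64_t Y, int64_t X,
                                  int64_t gemm_k1)
{
    const int64_t m = N * Do * Ho * Wo;
    const int64_t k = Z * Y * X * C;
    assert(k % gemm_k1 == 0); // GemmK must split evenly into K0 * K1
    return {m, k, k / gemm_k1};
}

int main()
{
    const GemmShape g = Conv3dImplicitGemmShape(2, 64, 8, 14, 14, 3, 3, 3, 8);
    std::cout << "GemmM=" << g.m << " GemmK=" << g.k << " GemmK0=" << g.k0 << '\n';
    // GemmM=3136 GemmK=1728 GemmK0=216
}
```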
@@ -85,6 +85,7 @@ struct NKHW : public BaseTensorLayout
static constexpr const char* name = "NKHW";
};
// 3D Conv
struct NDHWC : public BaseTensorLayout
{
static constexpr const char* name = "NDHWC";
@@ -100,6 +101,21 @@ struct NDHWK : public BaseTensorLayout
static constexpr const char* name = "NDHWK";
};
struct NCDHW : public BaseTensorLayout
{
static constexpr const char* name = "NCDHW";
};
struct KCZYX : public BaseTensorLayout
{
static constexpr const char* name = "KCZYX";
};
struct NKDHW : public BaseTensorLayout
{
static constexpr const char* name = "NKDHW";
};
} // namespace convolution
template <
...
@@ -14,9 +14,9 @@ namespace host {
//
// @brief Reference implementation for forward convolution.
//
// @paragraph Supports both NCHW and NHWC formats (and their respective
// counterparts for weights and output) as long as the tensor descriptor
// lengths are in NCHW order.
//
// @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type.
@@ -100,9 +100,9 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_wei;
arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, wi)));
arg.wei_element_op_(v_wei,
ck::type_convert<float>(arg.weight_(k, c, x)));
v_acc += v_in * v_wei;
}
@@ -112,7 +112,7 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_ncw,
@@ -169,6 +169,61 @@ struct ReferenceConvFwd : public device::BaseOperator
return 0;
}
else if constexpr(NumDimSpatial == 3)
{
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0;
for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
{
for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
{
int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
{
int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
arg.in_left_pads_[1];
for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
{
int wi = wo * arg.conv_strides_[2] +
x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] &&
hi >= 0 && hi < arg.input_.mDesc.GetLengths()[3] &&
wi >= 0 && wi < arg.input_.mDesc.GetLengths()[4])
{
float v_in;
float v_wei;
arg.in_element_op_(
v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(arg.weight_(k, c, z, y, x)));
v_acc += v_in * v_wei;
}
}
}
}
}
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_nchw,
arg.output_.mDesc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3],
arg.output_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
}
float Run(const device::BaseArgument* p_arg, int) override
...
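The 3D reference loop derives every input coordinate the same way on each spatial axis: output index times stride, plus filter tap times dilation, minus left pad; taps that land outside the input are skipped, which is what implements zero padding. A one-line helper mirroring that (hypothetical, not part of the patch):

```cpp
// Input coordinate read by the reference convolution for one spatial axis.
// A tap contributes only when the result lies in [0, input_length).
inline int ConvInputCoord(int out_idx, int tap_idx, int stride, int dilation, int left_pad)
{
    return out_idx * stride + tap_idx * dilation - left_pad;
}
```

For example, with stride 2, dilation 1, and left pad 1, output index 0 and tap 0 map to input index -1, so that tap contributes nothing.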
@@ -23,6 +23,7 @@ add_subdirectory(gemm_bias_relu_add)
add_subdirectory(batched_gemm)
add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd)
add_subdirectory(conv3d_fwd)
add_subdirectory(conv2d_fwd_bias_relu)
add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory(conv2d_fwd_bias_relu_atomic_add)
...
# device_conv1d_fwd_instance
set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE
device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp;
device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp;
device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp;
device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp;
)
add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE})
...
# device_conv3d_fwd_instance
set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp;
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp;
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp;
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp;
)
add_library(device_conv3d_fwd_instance SHARED ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE})
target_compile_features(device_conv3d_fwd_instance PUBLIC)
set_target_properties(device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_conv3d_fwd_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv3d_fwd_instance)