Commit e72c0c43 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents d714fa15 313bbea5
@@ -459,6 +459,16 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            {
                for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
                {
+                    // check slice is valid
+                    const index_t Y = filter_spatial_lengths_[0];
+                    const index_t X = filter_spatial_lengths_[1];
+                    const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
+                    const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
+                    if(YDotSlice * XDotSlice <= 0)
+                    {
+                        continue;
+                    }
+
                    const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
                        N,
                        K,
......
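Note on the new guard: backward-data splits the Y x X filter into YTilda x XTilda phases, and a phase that owns no filter taps would otherwise build descriptors for a degenerate GEMM. A minimal sketch of the arithmetic, assuming `math::integer_divide_ceil(a, b)` computes ceil(a/b):

```cpp
#include <cassert>

// Stand-in for ck::math::integer_divide_ceil (assumed to round up).
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    // Y = 3, YTilda = 2: phase 0 covers taps {0, 2}, phase 1 covers tap {1}.
    assert(integer_divide_ceil(3 - 0, 2) == 2); // YDotSlice at i_ytilda = 0
    assert(integer_divide_ceil(3 - 1, 2) == 1); // YDotSlice at i_ytilda = 1

    // Y = 1, YTilda = 2: phase 1 owns no taps, so YDotSlice == 0 and the
    // new `continue` skips that phase entirely.
    assert(integer_divide_ceil(1 - 1, 2) == 0);
    return 0;
}
```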
@@ -875,7 +875,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
-            << K0PerBlock
+            << K0PerBlock << ", "
+            << getConvFwdSpecializationStr(ConvForwardSpecialization)
            << ">";
        // clang-format on
......
@@ -466,7 +466,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                      << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
        }
#endif
-
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                        arg.b_grid_desc_k0_n_k1_,
                                        arg.c_grid_desc_m_n_,
@@ -708,7 +707,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
-            << K0PerBlock
+            << K0PerBlock << ", "
+            << getConvFwdSpecializationStr(ConvForwardSpecialization)
            << ">";
        // clang-format on
......
@@ -207,41 +207,28 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
        const index_t Ho = output_spatial_lengths[1];
        const index_t Wo = output_spatial_lengths[2];

-        if constexpr(ConvForwardSpecialization ==
-                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
-        {
-            static_assert(ConvForwardSpecialization == -1, "Not implemented!");
-        }
-        else if constexpr(ConvForwardSpecialization ==
-                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
-        {
-            static_assert(ConvForwardSpecialization == -1, "Not implemented!");
-        }
-        else
-        {
-            const auto in_desc_n_di_hi_wi_c =
-                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
-            const auto wei_desc_k_z_y_x_c =
-                make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
-            const auto out_desc_n_do_ho_wo_k =
-                make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
-
-            const auto descs =
-                transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad(
-                    in_desc_n_di_hi_wi_c,
-                    wei_desc_k_z_y_x_c,
-                    out_desc_n_do_ho_wo_k,
-                    make_tuple(
-                        conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]),
-                    make_tuple(conv_filter_dilations[0],
-                               conv_filter_dilations[1],
-                               conv_filter_dilations[2]),
-                    make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]),
-                    make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]),
-                    Number<K1>{});
-
-            return descs;
-        }
+        static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Default,
+                      "Wrong! This specialization not implemented!");
+
+        const auto in_desc_n_di_hi_wi_c =
+            make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
+        const auto wei_desc_k_z_y_x_c =
+            make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
+        const auto out_desc_n_do_ho_wo_k =
+            make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
+
+        const auto descs = transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad(
+            in_desc_n_di_hi_wi_c,
+            wei_desc_k_z_y_x_c,
+            out_desc_n_do_ho_wo_k,
+            make_tuple(conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]),
+            make_tuple(
+                conv_filter_dilations[0], conv_filter_dilations[1], conv_filter_dilations[2]),
+            make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]),
+            make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]),
+            Number<K1>{});
+
+        return descs;
    }

    using ABCGridDescs = remove_cvref_t<decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
......
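For reference, under the now-required Default specialization the NDHWC forward convolution maps onto an implicit GEMM whose sizes follow from the three descriptors above; a back-of-the-envelope sketch (my reading of the transform, with illustrative sizes, not code from this commit):

```cpp
#include <cstdint>
#include <iostream>

int main()
{
    // Example NDHWC problem (values are illustrative only).
    const int64_t N = 4, C = 64, K = 256;
    const int64_t Z = 3, Y = 3, X = 3;
    const int64_t Do = 14, Ho = 14, Wo = 14;

    // transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad
    // appears to view the problem as an implicit GEMM with:
    const int64_t GemmM = N * Do * Ho * Wo; // one row per output pixel
    const int64_t GemmN = K;                // one column per output channel
    const int64_t GemmK = Z * Y * X * C;    // reduce over filter taps x input channels

    std::cout << "GemmM=" << GemmM << " GemmN=" << GemmN << " GemmK=" << GemmK << '\n';
    return 0;
}
```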
@@ -367,6 +367,155 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
        }
    }
+
+    template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
+    static auto GetInputTensorDescriptor(ck::index_t N,
+                                         ck::index_t C,
+                                         ck::index_t gemm_m,
+                                         ck::index_t gemm_k,
+                                         ck::index_t gemm_m_pad,
+                                         const std::vector<ck::index_t>& input_spatial_lengths,
+                                         const std::vector<ck::index_t>& filter_spatial_lengths,
+                                         const std::vector<ck::index_t>& output_spatial_lengths,
+                                         const std::vector<ck::index_t>& conv_filter_strides,
+                                         const std::vector<ck::index_t>& conv_filter_dilations,
+                                         const std::vector<ck::index_t>& input_left_pads,
+                                         const std::vector<ck::index_t>& input_right_pads)
+    {
+        const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
+
+        const index_t Di = input_spatial_lengths[0];
+        const index_t Hi = input_spatial_lengths[1];
+        const index_t Wi = input_spatial_lengths[2];
+
+        const index_t Do = output_spatial_lengths[0];
+        const index_t Ho = output_spatial_lengths[1];
+        const index_t Wo = output_spatial_lengths[2];
+
+        const index_t ConvStrideD = conv_filter_strides[0];
+        const index_t ConvStrideH = conv_filter_strides[1];
+        const index_t ConvStrideW = conv_filter_strides[2];
+
+        if constexpr(ConvForwardSpecialization ==
+                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        {
+            const auto in_gemmmraw_gemmk_grid_desc =
+                make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
+
+            // in_gemmk0_gemmm_gemmk1_grid_desc
+            return transform_tensor_descriptor(
+                in_gemmmraw_gemmk_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
+                           make_right_pad_transform(gemm_m, gemm_m_pad)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+        else if constexpr(ConvForwardSpecialization ==
+                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+        {
+            const auto in_n_di_hi_wi_c_grid_desc =
+                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
+
+            const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
+                in_n_di_hi_wi_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
+                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
+                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
+                           make_pass_through_transform(C)),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+
+            const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
+                in_n_do_ho_wo_c_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
+                           make_merge_transform(make_tuple(N, Do, Ho, Wo))),
+                make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            // in_gemmk0_gemmm_gemmk1_grid_desc
+            return transform_tensor_descriptor(
+                in_gemmk0_gemmmraw_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(gemm_k0),
+                           make_right_pad_transform(gemm_m, gemm_m_pad),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+        else
+        {
+            const index_t Z = filter_spatial_lengths[0];
+            const index_t Y = filter_spatial_lengths[1];
+            const index_t X = filter_spatial_lengths[2];
+
+            const index_t ConvDilationD = conv_filter_dilations[0];
+            const index_t ConvDilationH = conv_filter_dilations[1];
+            const index_t ConvDilationW = conv_filter_dilations[2];
+
+            const index_t InLeftPadD = input_left_pads[0];
+            const index_t InLeftPadH = input_left_pads[1];
+            const index_t InLeftPadW = input_left_pads[2];
+
+            const index_t InRightPadD = input_right_pads[0];
+            const index_t InRightPadH = input_right_pads[1];
+            const index_t InRightPadW = input_right_pads[2];
+
+            const auto in_n_di_hi_wi_c_grid_desc =
+                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
+
+            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
+                in_n_di_hi_wi_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_pad_transform(Di, InLeftPadD, InRightPadD),
+                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
+                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                           make_pass_through_transform(C)),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+
+            const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
+                in_n_hip_wip_c_grid_desc,
+                make_tuple(
+                    make_pass_through_transform(N),
+                    make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
+                    make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
+                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
+                    make_pass_through_transform(C)),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                make_tuple(Sequence<0>{},
+                           Sequence<1, 2>{},
+                           Sequence<3, 4>{},
+                           Sequence<5, 6>{},
+                           Sequence<7>{}));
+
+            const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
+                in_n_z_do_y_ho_x_wo_c_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
+                           make_merge_transform(make_tuple(N, Do, Ho, Wo))),
+                make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
+                in_gemmk_gemmmraw_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
+                           make_pass_through_transform(gemm_m)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            // in_gemmk0_gemmm_gemmk1_grid_desc
+            return transform_tensor_descriptor(
+                in_gemmk0_gemmmraw_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(gemm_k0),
+                           make_right_pad_transform(gemm_m, gemm_m_pad),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+    }
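Note on the three branches above: Filter1x1Stride1Pad0 can treat the input as an already-flat [gemm_m, gemm_k] matrix, Filter1x1Pad0 only needs the stride embedding, and the general path does pad + embed + merge (the classic implicit im2col). The final right-pad extends gemm_m by gemm_m_pad to a block multiple; a small sketch of that arithmetic, with a hypothetical MPerBlock:

```cpp
#include <cassert>

int main()
{
    // Assumption: gemm_m_pad rounds GemmM up to the next multiple of MPerBlock,
    // mirroring make_right_pad_transform(gemm_m, gemm_m_pad) above.
    const int MPerBlock = 128;           // hypothetical tile size
    const int gemm_m = 4 * 14 * 14 * 14; // N * Do * Ho * Wo = 10976

    const int gemm_m_padded = ((gemm_m + MPerBlock - 1) / MPerBlock) * MPerBlock;
    const int gemm_m_pad = gemm_m_padded - gemm_m;

    assert(gemm_m_padded == 11008);
    assert(gemm_m_pad == 32);
    return 0;
}
```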
    static index_t GetGemmMRaw(ck::index_t N,
                               const std::vector<ck::index_t>& output_spatial_lengths)
    {
@@ -445,6 +594,13 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
    }

+    template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
+    static auto GetABCGridDesc()
+    {
+        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+            1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
+    }
+
    using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());

    using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
@@ -593,6 +749,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
    float Run(const Argument& arg, int nrepeat = 1)
    {
+#if 0
        {
            std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                      << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
@@ -605,7 +762,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                      << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
        }
-
+#endif
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                        arg.b_grid_desc_k0_n_k1_,
                                        arg.c_grid_desc_m_n_,
@@ -704,6 +861,22 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
    static bool IsSupportedArgument(const Argument& arg)
    {
+        // Input tensors can't be bigger than 2GB each.
+        constexpr std::size_t GB2 = 2 * 1e9;
+
+        if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() > GB2)
+        {
+            return false;
+        }
+        if(arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() > GB2)
+        {
+            return false;
+        }
+        if(arg.c_grid_desc_m_n_.GetElementSpaceSize() > GB2)
+        {
+            return false;
+        }
+
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
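A plausible reading of the new early-out: descriptor offsets inside the kernels are computed in 32-bit integers, so element spaces beyond ~2e9 would overflow, and the guard rejects such problems before launch. Standalone sketch (the helper name is mine, not from this commit):

```cpp
#include <cstdint>

// Hypothetical helper mirroring the guard above; GetElementSpaceSize() is
// assumed to return the descriptor's addressable extent in elements.
bool fits_in_32bit_offset(std::size_t element_space_size)
{
    constexpr std::size_t GB2 = 2 * 1e9; // same constant as the diff; below INT32_MAX
    return element_space_size <= GB2;
}

int main()
{
    static_assert(2000000000LL < INT32_MAX, "guard value stays inside int32 range");
    return fits_in_32bit_offset(std::size_t{1} << 30) ? 0 : 1; // 2^30 elements fit
}
```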
@@ -851,7 +1024,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
-            << K0PerBlock
+            << K0PerBlock << ", "
+            << getConvFwdSpecializationStr(ConvForwardSpecialization)
            << ">";
        // clang-format on
......
-#ifndef DEVICE_GEMM_HPP
-#define DEVICE_GEMM_HPP
+#pragma once

#include <iostream>
#include "device_base.hpp"

@@ -8,35 +6,12 @@ namespace ck {
namespace tensor_operation {
namespace device {

-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-struct DeviceGemmBias : public BaseOperator
+struct GemmShape
{
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b,
-                        const void* p_bias,
-                        void* p_c,
-                        ck::index_t M,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t StrideA,
-                        ck::index_t StrideB,
-                        ck::index_t StrideC,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op) = 0;
-
-    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+    ck::index_t M, N, K;
+    ck::index_t StrideA, StrideB, StrideC;
};

-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-using DeviceGemmBiasPtr = std::unique_ptr<
-    DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
-
template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
@@ -65,7 +40,29 @@ template <typename AElementwiseOperation,
using DeviceGemmPtr = std::unique_ptr<
    DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceGroupedGemm : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
+                                                              std::vector<const void*>& p_b,
+                                                              std::vector<void*>& p_c,
+                                                              std::vector<GemmShape>& gemm_shapes,
+                                                              AElementwiseOperation a_element_op,
+                                                              BElementwiseOperation b_element_op,
+                                                              CElementwiseOperation c_element_op,
+                                                              ck::index_t KBatch = 1) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceGroupedGemmPtr = std::unique_ptr<
+    DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
} // namespace device
} // namespace tensor_operation
} // namespace ck
-
-#endif
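The new `DeviceGroupedGemm` interface launches several independent GEMMs in one shot, each described by a `GemmShape`. A hypothetical call-site sketch (the concrete `gemm_op` instance, the device pointers, and CK's `PassThrough` element-wise op are assumptions, and `BaseInvoker`'s exact `Run` signature may differ):

```cpp
#include <vector>

// Sketch only: GemmShape and DeviceGroupedGemmPtr come from this header; the
// rest (pointers, functors) is assumed for illustration.
template <typename GroupedGemmPtr, typename GemmShape, typename PassThrough>
void run_grouped(GroupedGemmPtr& gemm_op,
                 std::vector<const void*>& p_a, // one A pointer per group
                 std::vector<const void*>& p_b, // one B pointer per group
                 std::vector<void*>& p_c)       // one C pointer per group
{
    // Two independent GEMMs, each with its own M/N/K and strides.
    std::vector<GemmShape> shapes = {
        {256, 128, 64, /*StrideA*/ 64, /*StrideB*/ 128, /*StrideC*/ 128},
        {512, 256, 32, /*StrideA*/ 32, /*StrideB*/ 256, /*StrideC*/ 256}};

    auto arg = gemm_op->MakeArgumentPointer(
        p_a, p_b, p_c, shapes, PassThrough{}, PassThrough{}, PassThrough{});
    auto invoker = gemm_op->MakeInvokerPointer();
    invoker->Run(arg.get()); // assumption: the invoker consumes a BaseArgument*
}
```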
#pragma once

#include <iostream>
#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemmBias : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        const void* p_bias,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceGemmBiasPtr = std::unique_ptr<
    DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma once

#include <iostream>
#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename D0ReduceOperation,
          typename D1ReduceOperation>
struct DeviceGemmReduce : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              void* p_d0,
                                                              void* p_d1,
                                                              ck::index_t M,
                                                              ck::index_t N,
                                                              ck::index_t K,
                                                              ck::index_t StrideA,
                                                              ck::index_t StrideB,
                                                              ck::index_t StrideC,
                                                              AElementwiseOperation a_element_op,
                                                              BElementwiseOperation b_element_op,
                                                              CElementwiseOperation c_element_op,
                                                              D0ReduceOperation d0_reduce_op,
                                                              D1ReduceOperation d1_reduce_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename D0ReduceOperation,
          typename D1ReduceOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
                                                             BElementwiseOperation,
                                                             CElementwiseOperation,
                                                             D0ReduceOperation,
                                                             D1ReduceOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
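`DeviceGemmReduce` extends the plain GEMM interface with two extra outputs, `p_d0`/`p_d1`, each produced by its own reduction over C (for example a per-row max and sum feeding a later softmax; that use case is my guess, not stated here). Hypothetical call sketch in the same style as above:

```cpp
// Hypothetical call-site; the concrete op and the element/reduce functors are
// assumptions, only the MakeArgumentPointer parameter order follows the header.
template <typename GemmReducePtr, typename ElemOp, typename D0Op, typename D1Op>
void run_gemm_reduce(GemmReducePtr& op,
                     const void* p_a, const void* p_b,
                     void* p_c, void* p_d0, void* p_d1)
{
    const int M = 1024, N = 1024, K = 512; // row-major strides below

    auto arg = op->MakeArgumentPointer(p_a, p_b, p_c, p_d0, p_d1,
                                       M, N, K,
                                       /*StrideA*/ K, /*StrideB*/ N, /*StrideC*/ N,
                                       ElemOp{}, ElemOp{}, ElemOp{}, D0Op{}, D1Op{});
    auto invoker = op->MakeInvokerPointer();
    invoker->Run(arg.get()); // assumption: the invoker consumes a BaseArgument*
}
```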
@@ -4,9 +4,7 @@

#include <iostream>
#include <sstream>
#include "device.hpp"
-#include "device_base.hpp"
-#include "device_gemm.hpp"
-#include "device_gemm_xdl.hpp"
+#include "device_gemm_bias.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
......
@@ -16,9 +16,11 @@ namespace device {

template <typename InElementwiseOperation, typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator
{
-    virtual size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths)
+    virtual long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
+                                                 const std::vector<int> reduceDims)
    {
        (void)inLengths;
+        (void)reduceDims;
        return (0);
    };

@@ -32,19 +34,19 @@ struct DeviceReduce : public BaseOperator
    };

    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
-                        const std::vector<int>& reduceDims,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                        float alpha,
                        float beta,
                        const void* in_dev,
                        void* out_dev,
                        void* out_indices_dev,
                        void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) = 0;
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
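Two interface changes here: `GetWorkspaceSizeInBytes` now returns `long_index_t` (a 64-bit size cannot overflow for large tensors) and also receives `reduceDims`, since multi-pass reductions size their scratch from the reduce extent. Hypothetical caller sketch (any concrete DeviceReduce instance; standard HIP allocation):

```cpp
#include <hip/hip_runtime.h>
#include <vector>

// Hypothetical caller; `reduce_op` is any DeviceReduce implementation.
template <typename DeviceReducePtr>
void* alloc_workspace(DeviceReducePtr& reduce_op,
                      const std::vector<int>& inLengths,
                      const std::vector<int>& reduceDims)
{
    const auto bytes = reduce_op->GetWorkspaceSizeInBytes(inLengths, reduceDims);

    void* workspace = nullptr;
    if(bytes > 0)
        (void)hipMalloc(&workspace, bytes); // error handling elided
    return workspace;
}
```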
@@ -36,20 +36,20 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

+    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
    using IndexDataType = int32_t;

    static constexpr bool BetaIsZero = NeedIndices;

    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

-    using InvariantDims =
-        typename conditional<NumInvariantDim == 0,
-                             Sequence<>,
-                             typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
-    using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
-
-    static constexpr index_t srcDims = Rank;
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
-    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
+    static constexpr index_t numSrcDim = Rank;
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);

    static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

@@ -57,18 +57,18 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
    static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
                                    const std::vector<int>& inStrides)
    {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

        const auto in_grid_desc_m_k = [&]() {
-            if constexpr(reduceAllDims)
+            if constexpr(reduceAllDim)
            {
                const auto one_dim_inDesc = transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                    make_tuple(Sequence<0>{}));

                return transform_tensor_descriptor(one_dim_inDesc,
@@ -79,6 +79,9 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
            }
            else
            {
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
                const auto reduceDimLengths =
                    make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                const auto invariantDimLengths =
@@ -93,18 +96,20 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
            }
        }();

-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K =
+            math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

-        auto in_grid_desc_m_k_padded =
-            transform_tensor_descriptor(in_grid_desc_m_k,
-                                        make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                                                   make_right_pad_transform(innerLen, inPad_K)),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (in_grid_desc_m_k_padded);
    };

@@ -112,44 +117,45 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
    static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                    const std::vector<int>& outStrides)
    {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

        auto out_grid_desc_m = transform_tensor_descriptor(
            outDesc,
            make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
-        const auto inPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
+        const auto inPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

-        auto out_grid_desc_m_padded =
-            transform_tensor_descriptor(out_grid_desc_m,
-                                        make_tuple(make_right_pad_transform(outerLen, inPad)),
-                                        make_tuple(Sequence<0>{}),
-                                        make_tuple(Sequence<0>{}));
+        auto out_grid_desc_m_padded = transform_tensor_descriptor(
+            out_grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, inPad)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));

        return (out_grid_desc_m_padded);
    };

    struct Argument : public BaseArgument
    {
-        Argument(const std::vector<int>& inLengths,
-                 const std::vector<int>& inStrides,
-                 const std::vector<int>& outLengths,
-                 const std::vector<int>& outStrides,
-                 const std::vector<int>& reduceDims,
+        Argument(const std::vector<int> inLengths,
+                 const std::vector<int> inStrides,
+                 const std::vector<int> outLengths,
+                 const std::vector<int> outStrides,
+                 const std::vector<int> reduceDims,
                 float alpha,
                 float beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 IndexDataType* out_indices_dev,
                 AccDataType* workspace_dev,
-                 const InElementwiseOperation& in_elementwise_op,
-                 const AccElementwiseOperation& acc_elementwise_op)
+                 const InElementwiseOperation in_elementwise_op,
+                 const AccElementwiseOperation acc_elementwise_op)
            : outLengths_{outLengths},
              outStrides_{outStrides},
              in_dev_{in_dev},

@@ -160,21 +166,21 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
        {
            (void)workspace_dev;

-            std::tie(inLengths_, inStrides_) =
-                shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_ = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_ = type_convert<AccDataType>(beta);

            std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, ReduceDims>(inLengths_);
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
                invariant_lowest_length = 1;
            else
-                invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
+                invariant_lowest_length = inLengths_[NumInvariantDim - 1];

-            reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
+            reduce_lowest_length = inLengths_[Rank - 1];

            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                       M_BlockTileSize;

@@ -186,7 +192,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
        std::vector<int> outStrides_;

        AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;

        const InDataType* in_dev_;
        OutDataType* out_dev_;

@@ -278,18 +284,22 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
        if constexpr(InSrcVectorDim == 0)
        {
-            if constexpr(InvariantDims::Size() == 0)
-                return (false);
-
-            if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
-                return (false);
+            if constexpr(NumInvariantDim == 0)
+            {
+                return (false);
+            }
+            else
+            {
+                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
+                    return (false);

-            if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
-                return (false);
+                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
+                    return (false);
+            };
        }
        else
        {
-            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
+            if(pArg->inStrides_[Rank - 1] != 1)
                return (false);

            if(pArg->reduce_lowest_length % InSrcVectorSize != 0)

@@ -308,19 +318,19 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
    };

    std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
-                        const std::vector<int>& reduceDims,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                        float alpha,
                        float beta,
                        const void* in_dev,
                        void* out_dev,
                        void* out_indices_dev,
                        void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
......
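The new static_assert ties the vector load width to the thread slice: along the chosen `InSrcVectorDim`, each thread's slice must split into whole `InSrcVectorSize`-wide loads, and the M slice must split into `OutDstVectorSize`-wide stores. A tiny compile-time illustration (the numeric configurations are made up):

```cpp
// Mirror of the new constraint, as a standalone constexpr check.
template <int InSrcVectorDim, int MThreadSliceSize, int KThreadSliceSize,
          int InSrcVectorSize, int OutDstVectorSize>
constexpr bool valid_slice_config()
{
    return ((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
            (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
           (MThreadSliceSize % OutDstVectorSize == 0);
}

// K-contiguous input, 8-element K slice, 4-wide vector loads: accepted.
static_assert(valid_slice_config<1, 4, 8, 4, 2>(), "expected valid");
// An 8-element K slice cannot be covered by 3-wide vector loads: rejected.
static_assert(!valid_slice_config<1, 4, 8, 3, 2>(), "expected invalid");
```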
@@ -37,6 +37,10 @@ struct DeviceReduceBlockWiseSecondCall
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

+    static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
    using IndexDataType = int32_t;

    static constexpr bool BetaIsZero = NeedIndices;

@@ -46,12 +50,8 @@ struct DeviceReduceBlockWiseSecondCall
        "InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");

    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

-    using InvariantDims =
-        typename conditional<NumInvariantDim == 0,
-                             Sequence<>,
-                             typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
-
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;

    static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

@@ -65,18 +65,20 @@ struct DeviceReduceBlockWiseSecondCall
        const auto in_grid_desc_m_k =
            make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K =
+            math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

-        auto in_grid_desc_m_k_padded =
-            transform_tensor_descriptor(in_grid_desc_m_k,
-                                        make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                                                   make_right_pad_transform(innerLen, inPad_K)),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (in_grid_desc_m_k_padded);
    };

@@ -84,26 +86,27 @@ struct DeviceReduceBlockWiseSecondCall
    static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                    const std::vector<int>& outStrides)
    {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

        auto out_grid_desc_m = transform_tensor_descriptor(
            outDesc,
            make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
-        const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
+        const auto outPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

-        auto out_grid_desc_m_padded =
-            transform_tensor_descriptor(out_grid_desc_m,
-                                        make_tuple(make_right_pad_transform(outerLen, outPad)),
-                                        make_tuple(Sequence<0>{}),
-                                        make_tuple(Sequence<0>{}));
+        auto out_grid_desc_m_padded = transform_tensor_descriptor(
+            out_grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, outPad)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));

        return (out_grid_desc_m_padded);
    };

@@ -131,8 +134,8 @@ struct DeviceReduceBlockWiseSecondCall
              in_elementwise_op_(in_elementwise_op),
              acc_elementwise_op_(acc_elementwise_op)
        {
-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_ = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_ = type_convert<AccDataType>(beta);

            invariant_total_length = inLengths[0];
            reduce_total_length = inLengths[1];

@@ -159,7 +162,7 @@ struct DeviceReduceBlockWiseSecondCall
        std::vector<int> outStrides_;

        AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;

        const InDataType* in_dev_;
        OutDataType* out_dev_;

@@ -268,19 +271,19 @@ struct DeviceReduceBlockWiseSecondCall
    };

    std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
-                        const std::vector<int>& reduceDims,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                        float alpha,
                        float beta,
                        const void* in_dev,
                        void* out_dev,
                        void* out_indices_dev,
                        void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
    {
        (void)reduceDims;
......
@@ -12,38 +12,30 @@ namespace ck {
namespace tensor_operation {
namespace device {

-// template <typename preUnaryOpType, typename posUnaryOpType>
-// using DeviceReducePtr = std::unique_ptr<DeviceReduce<preUnaryOpType, posUnaryOpType>>;
-
-template <int Rank, typename ReduceDims>
+// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
+// of reduce dims
+template <int Rank, int NumReduceDim>
std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
{
    static_assert(Rank <= 6, "bigger Rank size not supported!");

-    size_t tensor_total_length = 1;
+    size_t invariant_total_length = 1;
    size_t reduce_total_length = 1;

-    static_for<0, ReduceDims::Size(), 1>{}(
-        [&](auto i) { reduce_total_length *= inLengths[ReduceDims::At(i)]; });
+    constexpr int NumInvariantDim = Rank - NumReduceDim;

-    static_for<0, Rank, 1>{}([&](auto i) { tensor_total_length *= inLengths[i.value]; });
+    for(int i = NumInvariantDim; i < Rank; i++)
+        reduce_total_length *= inLengths[i];

-    return std::make_pair(tensor_total_length / reduce_total_length, reduce_total_length);
-};
-
-template <int x, typename Seq>
-constexpr bool belong()
-{
-    bool inside = false;
-
-    static_for<0, Seq::Size(), 1>{}([&](auto i) { inside = (inside || (x == Seq::At(i))); });
+    for(int i = 0; i < NumInvariantDim; i++)
+        invariant_total_length *= inLengths[i];

-    return (inside);
+    return std::make_pair(invariant_total_length, reduce_total_length);
};

// helper functions using variadic template arguments
template <index_t... Ns>
-static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
+auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
{
    return make_tuple(static_cast<index_t>(lengths[Ns])...);
};

@@ -59,16 +51,12 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
};

template <index_t Rank, index_t NumReduceDim>
-static inline std::pair<std::vector<int>, std::vector<int>>
-shuffle_tensor_dimensions(const std::vector<int>& dimLengths,
-                          const std::vector<int>& dimStrides,
-                          const std::vector<int>& reduceDims)
+std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
+                                           const std::vector<int>& reduceDims)
{
-    std::vector<int> newDimLengths;
-    std::vector<int> newDimStrides;
+    std::vector<int> newLengthsStrides;

-    assert(Rank == dimLengths.size() && Rank == dimStrides.size() &&
-           NumReduceDim == reduceDims.size());
+    assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());

    int reduceFlag = 0;

@@ -82,19 +70,17 @@ shuffle_tensor_dimensions(const std::vector<int>& dimLengths,
    for(int i = 0; i < Rank; i++)
        if((reduceFlag & (1 << i)) == 0)
        {
-            newDimLengths.push_back(dimLengths[i]);
-            newDimStrides.push_back(dimStrides[i]);
+            newLengthsStrides.push_back(origLengthsStrides[i]);
        };

    // collect reduce dimensions
    for(int i = 0; i < Rank; i++)
        if((reduceFlag & (1 << i)) > 0)
        {
-            newDimLengths.push_back(dimLengths[i]);
-            newDimStrides.push_back(dimStrides[i]);
+            newLengthsStrides.push_back(origLengthsStrides[i]);
        };

-    return std::make_pair(newDimLengths, newDimStrides);
+    return newLengthsStrides;
};

} // namespace device
......
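`shuffle_tensor_dimensions` now takes one array at a time (lengths or strides) and returns the reordered copy, so callers invoke it twice instead of unpacking a pair. A behavior sketch of the contract (reference re-implementation for illustration, not the library code):

```cpp
#include <cassert>
#include <vector>

// Reference model: invariant dimensions first (original order), then reduce dims.
std::vector<int> shuffle_ref(const std::vector<int>& v, const std::vector<int>& reduceDims)
{
    auto is_reduce = [&](int i) {
        for(int d : reduceDims)
            if(d == i)
                return true;
        return false;
    };

    std::vector<int> out;
    for(int i = 0; i < static_cast<int>(v.size()); i++)
        if(!is_reduce(i))
            out.push_back(v[i]);
    for(int i = 0; i < static_cast<int>(v.size()); i++)
        if(is_reduce(i))
            out.push_back(v[i]);
    return out;
}

int main()
{
    // Rank = 4, reduce dims {1, 3}: lengths [8, 16, 32, 64] -> [8, 32, 16, 64].
    assert((shuffle_ref({8, 16, 32, 64}, {1, 3}) == std::vector<int>{8, 32, 16, 64}));
    // get_2d_lengths<4, 2> on the shuffled lengths then yields
    // invariant_total_length = 8 * 32 = 256 and reduce_total_length = 16 * 64 = 1024.
    return 0;
}
```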
@@ -39,18 +39,18 @@ struct DeviceReduceMultiBlockAtomicAdd
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

+    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
    using IndexDataType = int32_t;

    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

-    using InvariantDims =
-        typename conditional<NumInvariantDim == 0,
-                             Sequence<>,
-                             typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
-    using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
-
-    static constexpr index_t srcDims = Rank;
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
-    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
+    static constexpr index_t numSrcDim = Rank;
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);

    static constexpr bool support_AtomicAdd =
        std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;

@@ -67,18 +67,18 @@ struct DeviceReduceMultiBlockAtomicAdd
                                    int blkGroupSize,
                                    int kBlockTileIterations)
    {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

        const auto in_grid_desc_m_k = [&]() {
-            if constexpr(reduceAllDims)
+            if constexpr(reduceAllDim)
            {
                const auto one_dim_inDesc = transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                    make_tuple(Sequence<0>{}));

                return transform_tensor_descriptor(one_dim_inDesc,
@@ -89,6 +89,9 @@ struct DeviceReduceMultiBlockAtomicAdd
            }
            else
            {
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
                const auto reduceDimLengths =
                    make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                const auto invariantDimLengths =
@@ -103,19 +106,20 @@ struct DeviceReduceMultiBlockAtomicAdd
            }
        }();

-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;

-        auto in_grid_desc_m_k_padded =
-            transform_tensor_descriptor(in_grid_desc_m_k,
-                                        make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                                                   make_right_pad_transform(innerLen, inPad_K)),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (in_grid_desc_m_k_padded);
    };

@@ -123,44 +127,45 @@ struct DeviceReduceMultiBlockAtomicAdd
    static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                    const std::vector<int>& outStrides)
    {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

        auto out_grid_desc_m = transform_tensor_descriptor(
            outDesc,
            make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
-        const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
+        const auto outPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

-        auto out_grid_desc_m_padded =
-            transform_tensor_descriptor(out_grid_desc_m,
-                                        make_tuple(make_right_pad_transform(outerLen, outPad)),
-                                        make_tuple(Sequence<0>{}),
-                                        make_tuple(Sequence<0>{}));
+        auto out_grid_desc_m_padded = transform_tensor_descriptor(
+            out_grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, outPad)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));

        return (out_grid_desc_m_padded);
    };

    struct Argument : public BaseArgument
    {
-        Argument(const std::vector<int>& inLengths,
-                 const std::vector<int>& inStrides,
-                 const std::vector<int>& outLengths,
-                 const std::vector<int>& outStrides,
-                 const std::vector<int>& reduceDims,
+        Argument(const std::vector<int> inLengths,
+                 const std::vector<int> inStrides,
+                 const std::vector<int> outLengths,
+                 const std::vector<int> outStrides,
+                 const std::vector<int> reduceDims,
                 float alpha,
                 float beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 IndexDataType* out_indices_dev,
                 AccDataType* workspace_dev,
-                 const InElementwiseOperation& in_elementwise_op,
-                 const AccElementwiseOperation& acc_elementwise_op)
+                 const InElementwiseOperation in_elementwise_op,
+                 const AccElementwiseOperation acc_elementwise_op)
            : outLengths_{outLengths},
              outStrides_{outStrides},
              in_dev_{in_dev},

@@ -171,21 +176,21 @@ struct DeviceReduceMultiBlockAtomicAdd
            (void)out_indices_dev;
            (void)workspace_dev;

-            std::tie(inLengths_, inStrides_) =
-                shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_ = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_ = type_convert<AccDataType>(beta);

            std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, ReduceDims>(inLengths_);
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
                invariant_lowest_length = 1;
            else
-                invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
+                invariant_lowest_length = inLengths_[NumInvariantDim - 1];

-            reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
+            reduce_lowest_length = inLengths_[Rank - 1];

            int iterations = 1;
            while(true)

@@ -218,7 +223,7 @@ struct DeviceReduceMultiBlockAtomicAdd
        std::vector<int> outStrides_;

        AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;

        const InDataType* in_dev_;
        OutDataType* out_dev_;

@@ -334,18 +339,22 @@ struct DeviceReduceMultiBlockAtomicAdd
        if constexpr(InSrcVectorDim == 0)
        {
-            if constexpr(InvariantDims::Size() == 0)
-                return (false);
-
-            if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
-                return (false);
+            if constexpr(NumInvariantDim == 0)
+            {
+                return (false);
+            }
+            else
+            {
+                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
+                    return (false);

-            if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
-                return (false);
+                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
+                    return (false);
+            };
        }
        else
        {
-            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
+            if(pArg->inStrides_[Rank - 1] != 1)
                return (false);

            if(pArg->reduce_lowest_length % InSrcVectorSize != 0)

@@ -371,19 +380,19 @@ struct DeviceReduceMultiBlockAtomicAdd
    };

    std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
-                        const std::vector<int>& reduceDims,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                        float alpha,
                        float beta,
                        const void* in_dev,
                        void* out_dev,
                        void* out_indices_dev,
                        void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
......
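On the `support_AtomicAdd` flag retained in the context above: the multi-block variant accumulates partial results straight into the output buffer with atomic adds, which is presumably why OutDataType is limited to float/double, the types with native global atomic-add support on the targeted GPUs. Compile-time mirror of the gate:

```cpp
#include <type_traits>

// Mirror of the support_AtomicAdd condition, as a reusable trait.
template <typename OutDataType>
constexpr bool out_supports_atomic_add =
    std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;

static_assert(out_supports_atomic_add<float>, "float output can accumulate atomically");
static_assert(!out_supports_atomic_add<int>, "other output types are rejected");
```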