Commit b0a36aa5 authored by rocking
Browse files

Refine naming and IsSupportedArgument()

parent 6b94f903
...@@ -25,13 +25,13 @@ struct DevicePoolFwd : public BaseOperator ...@@ -25,13 +25,13 @@ struct DevicePoolFwd : public BaseOperator
MakeArgumentPointer(const void* p_in_dev, MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev, void* p_out_dev,
void* p_out_indices_dev, void* p_out_indices_dev,
std::vector<ck::index_t> input_lengths, std::vector<ck::index_t> input_n_c_wis_lengths,
std::vector<ck::index_t> window_lengths, std::vector<ck::index_t> window_xs_lengths,
std::vector<ck::index_t> output_lengths, std::vector<ck::index_t> output_n_c_wos_lengths,
std::vector<ck::index_t> input_stride, std::vector<ck::index_t> input_n_c_wis_stride,
std::vector<ck::index_t> output_stride, std::vector<ck::index_t> output_n_c_wis_stride,
std::vector<ck::index_t> indices_stride, std::vector<ck::index_t> indices_n_c_wis_stride,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_xs_strides,
std::vector<ck::index_t> window_dilations, std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
......
...@@ -56,45 +56,45 @@ struct DevicePool3dFwdImpl ...@@ -56,45 +56,45 @@ struct DevicePool3dFwdImpl
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_lengths, static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_ncdhw_lengths,
std::vector<ck::index_t> output_lengths, std::vector<ck::index_t> output_ncdhw_lengths,
std::vector<ck::index_t> input_stride, std::vector<ck::index_t> input_ncdhw_stride,
std::vector<ck::index_t> output_stride, std::vector<ck::index_t> output_ncdhw_stride,
std::vector<ck::index_t> window_spatial_lengths, std::vector<ck::index_t> window_spatial_zyx_lengths,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_zyx_strides,
std::vector<ck::index_t> window_dilations, std::vector<ck::index_t> window_zyx_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_dhw_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::index_t> input_right_dhw_pads)
{ {
const index_t N = input_lengths[0]; const index_t N = input_ncdhw_lengths[0];
const index_t C = input_lengths[1]; const index_t C = input_ncdhw_lengths[1];
const index_t Di = input_lengths[2]; const index_t Di = input_ncdhw_lengths[2];
const index_t Hi = input_lengths[3]; const index_t Hi = input_ncdhw_lengths[3];
const index_t Wi = input_lengths[4]; const index_t Wi = input_ncdhw_lengths[4];
const index_t Do = output_lengths[2]; const index_t Do = output_ncdhw_lengths[2];
const index_t Ho = output_lengths[3]; const index_t Ho = output_ncdhw_lengths[3];
const index_t Wo = output_lengths[4]; const index_t Wo = output_ncdhw_lengths[4];
const index_t Z = window_spatial_lengths[0]; const index_t Z = window_spatial_zyx_lengths[0];
const index_t Y = window_spatial_lengths[1]; const index_t Y = window_spatial_zyx_lengths[1];
const index_t X = window_spatial_lengths[2]; const index_t X = window_spatial_zyx_lengths[2];
const index_t WindowStrideD = window_strides[0]; const index_t WindowStrideD = window_zyx_strides[0];
const index_t WindowStrideH = window_strides[1]; const index_t WindowStrideH = window_zyx_strides[1];
const index_t WindowStrideW = window_strides[2]; const index_t WindowStrideW = window_zyx_strides[2];
const index_t WindowDilationD = window_dilations[0]; const index_t WindowDilationD = window_zyx_dilations[0];
const index_t WindowDilationH = window_dilations[1]; const index_t WindowDilationH = window_zyx_dilations[1];
const index_t WindowDilationW = window_dilations[2]; const index_t WindowDilationW = window_zyx_dilations[2];
const index_t InLeftPadD = input_left_pads[0]; const index_t InLeftPadD = input_left_dhw_pads[0];
const index_t InLeftPadH = input_left_pads[1]; const index_t InLeftPadH = input_left_dhw_pads[1];
const index_t InLeftPadW = input_left_pads[2]; const index_t InLeftPadW = input_left_dhw_pads[2];
const index_t InRightPadD = input_right_pads[0]; const index_t InRightPadD = input_right_dhw_pads[0];
const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadH = input_right_dhw_pads[1];
const index_t InRightPadW = input_right_pads[2]; const index_t InRightPadW = input_right_dhw_pads[2];
const index_t MRaw = N * Do * Ho * Wo * C; const index_t MRaw = N * Do * Ho * Wo * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
...@@ -103,11 +103,11 @@ struct DevicePool3dFwdImpl ...@@ -103,11 +103,11 @@ struct DevicePool3dFwdImpl
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
// A[ReduceM, ReduceK] // A[ReduceM, ReduceK]
const index_t Ni_stride = input_stride[0]; const index_t Ni_stride = input_ncdhw_stride[0];
const index_t Ci_stride = input_stride[1]; const index_t Ci_stride = input_ncdhw_stride[1];
const index_t Di_stride = input_stride[2]; const index_t Di_stride = input_ncdhw_stride[2];
const index_t Hi_stride = input_stride[3]; const index_t Hi_stride = input_ncdhw_stride[3];
const index_t Wi_stride = input_stride[4]; const index_t Wi_stride = input_ncdhw_stride[4];
const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor( const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C), make_tuple(N, Di, Hi, Wi, C),
...@@ -152,11 +152,11 @@ struct DevicePool3dFwdImpl ...@@ -152,11 +152,11 @@ struct DevicePool3dFwdImpl
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// B[ReduceM] // B[ReduceM]
const index_t No_stride = output_stride[0]; const index_t No_stride = output_ncdhw_stride[0];
const index_t Co_stride = output_stride[1]; const index_t Co_stride = output_ncdhw_stride[1];
const index_t Do_stride = output_stride[2]; const index_t Do_stride = output_ncdhw_stride[2];
const index_t Ho_stride = output_stride[3]; const index_t Ho_stride = output_ncdhw_stride[3];
const index_t Wo_stride = output_stride[4]; const index_t Wo_stride = output_ncdhw_stride[4];
const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor( const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C), make_tuple(N, Di, Hi, Wi, C),
...@@ -188,40 +188,41 @@ struct DevicePool3dFwdImpl ...@@ -188,40 +188,41 @@ struct DevicePool3dFwdImpl
Argument(const InDataType* p_in_dev, Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev, OutDataType* p_out_dev,
IndexDataType* p_out_indices_dev, IndexDataType* p_out_indices_dev,
std::vector<ck::index_t>& input_lengths, std::vector<ck::index_t>& input_ncdhw_lengths,
std::vector<ck::index_t>& output_lengths, std::vector<ck::index_t>& output_ncdhw_lengths,
std::vector<ck::index_t>& input_stride, std::vector<ck::index_t>& input_ncdhw_stride,
std::vector<ck::index_t>& output_stride, std::vector<ck::index_t>& output_ncdhw_stride,
std::vector<ck::index_t>&, // indices_stride std::vector<ck::index_t>&, // indices_ncdhw_stride
std::vector<ck::index_t>& window_spatial_lengths, std::vector<ck::index_t>& window_spatial_zyx_lengths,
std::vector<ck::index_t>& window_strides, std::vector<ck::index_t>& window_zyx_strides,
std::vector<ck::index_t>& window_dilations, std::vector<ck::index_t>& window_zyx_dilations,
std::vector<ck::index_t>& input_left_pads, std::vector<ck::index_t>& input_left_dhw_pads,
std::vector<ck::index_t>& input_right_pads) std::vector<ck::index_t>& input_right_dhw_pads)
: p_in_dev_{p_in_dev}, : p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev}, p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev}, p_out_indices_dev_{p_out_indices_dev},
a_grid_desc_m_k_{}, a_grid_desc_m_k_{},
b_grid_desc_m_{} b_grid_desc_m_{},
input_ncdhw_lengths_{input_ncdhw_lengths},
output_ncdhw_lengths_{output_ncdhw_lengths},
input_ncdhw_stride_{input_ncdhw_stride},
output_ncdhw_stride_{output_ncdhw_stride}
{ {
const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_lengths, const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths,
output_lengths, output_ncdhw_lengths,
input_stride, input_ncdhw_stride,
output_stride, output_ncdhw_stride,
window_spatial_lengths, window_spatial_zyx_lengths,
window_strides, window_zyx_strides,
window_dilations, window_zyx_dilations,
input_left_pads, input_left_dhw_pads,
input_right_pads); input_right_dhw_pads);
a_grid_desc_m_k_ = descs[I0]; a_grid_desc_m_k_ = descs[I0];
b_grid_desc_m_ = descs[I1]; b_grid_desc_m_ = descs[I1];
// C int32_t reduceLength = window_spatial_zyx_lengths[0] * window_spatial_zyx_lengths[1] *
invariant_lowest_length_ = input_lengths[1]; window_spatial_zyx_lengths[2];
int32_t reduceLength =
window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2];
std::tie(in_element_op_, acc_element_op_) = std::tie(in_element_op_, acc_element_op_) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength); reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
...@@ -232,11 +233,15 @@ struct DevicePool3dFwdImpl ...@@ -232,11 +233,15 @@ struct DevicePool3dFwdImpl
IndexDataType* p_out_indices_dev_; IndexDataType* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_; BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
AccElementwiseOperation acc_element_op_; AccElementwiseOperation acc_element_op_;
// for checking vector load/store // for checking vector load/store
ck::index_t invariant_lowest_length_; std::vector<ck::index_t> input_ncdhw_lengths_;
std::vector<ck::index_t> output_ncdhw_lengths_;
std::vector<ck::index_t> input_ncdhw_stride_;
std::vector<ck::index_t> output_ncdhw_stride_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
...@@ -313,9 +318,30 @@ struct DevicePool3dFwdImpl ...@@ -313,9 +318,30 @@ struct DevicePool3dFwdImpl
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0) // Reduced dimension = [D, H, W]
if constexpr(IsFastestDimReduced)
{
// One of [D, H, W] should be fastest dimension
if(pArg->input_ncdhw_stride_[2] != 1 && pArg->input_ncdhw_stride_[3] != 1 &&
pArg->input_ncdhw_stride_[4] != 1)
return false;
}
else
{
// One of [N, C] should be fastest dimension
if(pArg->input_ncdhw_stride_[0] != 1 && pArg->input_ncdhw_stride_[1] != 1)
return false;
}
for(int i = 0; i < InOutRank; ++i)
{ {
return false; if(pArg->input_ncdhw_stride_[i] == 1 &&
pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
return false;
if(pArg->output_ncdhw_stride_[i] == 1 &&
pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
return false;
} }
return true; return true;
...@@ -325,43 +351,44 @@ struct DevicePool3dFwdImpl ...@@ -325,43 +351,44 @@ struct DevicePool3dFwdImpl
MakeArgumentPointer(const void* p_in_dev, MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev, void* p_out_dev,
void* p_out_indices_dev, void* p_out_indices_dev,
std::vector<ck::index_t> input_lengths, std::vector<ck::index_t> input_ncdhw_lengths,
std::vector<ck::index_t> window_lengths, std::vector<ck::index_t> window_zyx_lengths,
std::vector<ck::index_t> output_lengths, std::vector<ck::index_t> output_ncdhw_lengths,
std::vector<ck::index_t> input_stride, std::vector<ck::index_t> input_ncdhw_stride,
std::vector<ck::index_t> output_stride, std::vector<ck::index_t> output_ncdhw_stride,
std::vector<ck::index_t> indices_stride, std::vector<ck::index_t> indices_ncdhw_stride,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_zyx_strides,
std::vector<ck::index_t> window_dilations, std::vector<ck::index_t> window_zyx_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_dhw_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_dhw_pads,
std::vector<ck::index_t> pooling_dims) override std::vector<ck::index_t> pooling_dims) override
{ {
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank ||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank || input_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank ||
window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank || window_zyx_dilations.size() != WindowRank || input_left_dhw_pads.size() != WindowRank ||
input_right_pads.size() != WindowRank) input_right_dhw_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect"); throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3, 4}) if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far"); throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far");
if(output_stride != indices_stride) if(output_ncdhw_stride != indices_ncdhw_stride)
throw std::runtime_error("output_stride need to be equal to indices_stride for now"); throw std::runtime_error(
"output_ncdhw_stride need to be equal to indices_ncdhw_stride for now");
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev), static_cast<OutDataType*>(p_out_dev),
static_cast<IndexDataType*>(p_out_indices_dev), static_cast<IndexDataType*>(p_out_indices_dev),
input_lengths, input_ncdhw_lengths,
output_lengths, output_ncdhw_lengths,
input_stride, input_ncdhw_stride,
output_stride, output_ncdhw_stride,
indices_stride, indices_ncdhw_stride,
window_lengths, window_zyx_lengths,
window_strides, window_zyx_strides,
window_dilations, window_zyx_dilations,
input_left_pads, input_left_dhw_pads,
input_right_pads); input_right_dhw_pads);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment