Commit 9b22bdd9 authored by rocking's avatar rocking
Browse files

get original index in max pooling

parent ed552712
......@@ -96,11 +96,11 @@ static void pool_host_verify(const Tensor<InDataType>& in,
for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
{
ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
if(hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]))
{
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
IndexDataType currIndex = y * window_spatial_lengths[1] + x;
IndexDataType currIndex = in.GetOffsetFromMultiIndex(n, c, hi, wi);
in_elementwise_op(currVal, currVal);
......
......@@ -110,9 +110,7 @@ static void pool3d_host_verify(const Tensor<InDataType>& in,
wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[4]))
{
AccDataType currVal = static_cast<AccDataType>(in(n, c, di, hi, wi));
IndexDataType currIndex =
z * window_spatial_lengths[1] * window_spatial_lengths[2] +
y * window_spatial_lengths[2] + x;
IndexDataType currIndex = in.GetOffsetFromMultiIndex(n, c, di, hi, wi);
in_elementwise_op(currVal, currVal);
......
......@@ -228,17 +228,19 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<5
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
OuputIndex,
false, // don't have index input
InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
const auto kernel =
kernel_reduce_threadwise<gridwise_reduce,
OuputIndex,
true, // pooling need to return global index
false, // don't have index input
InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0);
......
......@@ -234,17 +234,19 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
OuputIndex,
false, // don't have index input
InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
const auto kernel =
kernel_reduce_threadwise<gridwise_reduce,
OuputIndex,
true, // pooling need to return global index
false, // don't have index input
InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0);
......
......@@ -28,6 +28,7 @@ template <typename InDataType,
typename AccElementwiseOperation,
bool PropagateNan,
bool OutputIndex,
bool TransformIndexKtoGlobal,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize,
index_t MThreadSliceSize,
......@@ -260,6 +261,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
OutputIndex,
TransformIndexKtoGlobal,
HaveIndexInput,
InDataType,
OutDataType,
......
......@@ -15,6 +15,7 @@ namespace ck {
template <typename GridwiseReduction,
bool OutputIndex,
bool TransformIndexKtoGlobal,
bool HaveIndexInput,
typename InDataType,
typename OutDataType,
......@@ -48,16 +49,17 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
}
else
{
GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_value_global,
p_in_index_global,
beta,
p_out_value_global,
p_out_index_global);
GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_value_global,
p_in_index_global,
beta,
p_out_value_global,
p_out_index_global);
};
};
......@@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
};
template <bool HaveIndexInput>
template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
......@@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise
indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
if constexpr(TransformIndexKtoGlobal)
{
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
const auto coord = make_tensor_coordinate(
in_grid_desc_m_k,
make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
accu_index_buf(I)));
accu_index_buf(I) = coord.GetOffset();
});
}
};
// for indiced operation, acc_elementwise_op shoud do nothing
......
......@@ -90,6 +90,7 @@ void add_device_reduce_instance_threadwise(
AccElementwiseOp,
PropagateNan,
OutputIndex,
false,
false, // HaveIndexInputIfOutputIndex
cfg1::BlockSize_,
cfg2::MThreadSliceSize_,
......
......@@ -411,6 +411,12 @@ struct Tensor
}
}
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
template <typename... Is>
T& operator()(Is... is)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment