Commit 9b22bdd9 authored by rocking

get original index in max pooling

parent ed552712
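
This commit changes what the returned max-pooling index means: instead of the position of the winner inside the pooling window (e.g. y * window_width + x), the kernels and the host reference now record the flattened offset of the winning element in the input tensor. A minimal host-side sketch of the two conventions for 2-D NCHW pooling, assuming a packed layout and hypothetical names (the actual code asks the tensor descriptor for the offset instead of hard-coding strides):

// Minimal sketch (not the CK sources): the two index conventions for one 2-D
// pooling window, assuming a packed NCHW input of lengths {N, C, H, W}.
#include <cstddef>

struct Nchw { std::size_t N, C, H, W; };

// Old behaviour: index of (y, x) relative to the pooling window.
std::size_t window_local_index(std::size_t y, std::size_t x, std::size_t window_w)
{
    return y * window_w + x;
}

// New behaviour: flattened offset of (n, c, hi, wi) in the whole input tensor.
std::size_t global_input_offset(const Nchw& s, std::size_t n, std::size_t c, std::size_t hi, std::size_t wi)
{
    return ((n * s.C + c) * s.H + hi) * s.W + wi;
}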
@@ -96,11 +96,11 @@ static void pool_host_verify(const Tensor<InDataType>& in,
                 for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
                 {
                     ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
-                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
-                       wi < in.mDesc.GetLengths()[3])
+                    if(hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
+                       wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]))
                     {
                         AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
-                        IndexDataType currIndex = y * window_spatial_lengths[1] + x;
+                        IndexDataType currIndex = in.GetOffsetFromMultiIndex(n, c, hi, wi);
                         in_elementwise_op(currVal, currVal);
...
@@ -110,9 +110,7 @@ static void pool3d_host_verify(const Tensor<InDataType>& in,
                        wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[4]))
                     {
                         AccDataType currVal = static_cast<AccDataType>(in(n, c, di, hi, wi));
-                        IndexDataType currIndex =
-                            z * window_spatial_lengths[1] * window_spatial_lengths[2] +
-                            y * window_spatial_lengths[2] + x;
+                        IndexDataType currIndex = in.GetOffsetFromMultiIndex(n, c, di, hi, wi);
                         in_elementwise_op(currVal, currVal);
...
@@ -228,17 +228,19 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePoolFwd<5
                                                InSrcOutDstVectorSize,
                                                InSrcOutDstVectorSize>;
-        const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
-                                                     OuputIndex,
-                                                     false, // don't have index input
-                                                     InDataType,
-                                                     OutDataType,
-                                                     AccDataType,
-                                                     IndexDataType,
-                                                     AGridDesc_M_K,
-                                                     BGridDesc_M,
-                                                     InElementwiseOperation,
-                                                     AccElementwiseOperation>;
+        const auto kernel =
+            kernel_reduce_threadwise<gridwise_reduce,
+                                     OuputIndex,
+                                     true, // pooling need to return global index
+                                     false, // don't have index input
+                                     InDataType,
+                                     OutDataType,
+                                     AccDataType,
+                                     IndexDataType,
+                                     AGridDesc_M_K,
+                                     BGridDesc_M,
+                                     InElementwiseOperation,
+                                     AccElementwiseOperation>;
         ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0);
...
@@ -234,17 +234,19 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
                                                InSrcOutDstVectorSize,
                                                InSrcOutDstVectorSize>;
-        const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
-                                                     OuputIndex,
-                                                     false, // don't have index input
-                                                     InDataType,
-                                                     OutDataType,
-                                                     AccDataType,
-                                                     IndexDataType,
-                                                     AGridDesc_M_K,
-                                                     BGridDesc_M,
-                                                     InElementwiseOperation,
-                                                     AccElementwiseOperation>;
+        const auto kernel =
+            kernel_reduce_threadwise<gridwise_reduce,
+                                     OuputIndex,
+                                     true, // pooling need to return global index
+                                     false, // don't have index input
+                                     InDataType,
+                                     OutDataType,
+                                     AccDataType,
+                                     IndexDataType,
+                                     AGridDesc_M_K,
+                                     BGridDesc_M,
+                                     InElementwiseOperation,
+                                     AccElementwiseOperation>;
         ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0);
...
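
In both pooling device ops above, the functional change is the extra `true` template argument, which asks the threadwise reduction kernel to translate the winning position along the reduce (window) dimension into a global input offset; the plain reduction instances further below keep passing `false` and are unaffected. A minimal sketch, with hypothetical names, of how such a compile-time flag is typically gated so the non-pooling path pays nothing:

// Minimal sketch, hypothetical names: a boolean non-type template parameter that
// is compiled away with if constexpr when the caller does not need the remapping.
template <bool TransformIndexKtoGlobal, typename RemapFn>
void finalize_argmax_index(int& index, RemapFn k_to_global_offset)
{
    if constexpr(TransformIndexKtoGlobal)
    {
        // pooling: replace the window-local index with a global memory offset
        index = k_to_global_offset(index);
    }
    // plain reduction: leave the index untouched, no code is generated here
}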
@@ -28,6 +28,7 @@ template <typename InDataType,
           typename AccElementwiseOperation,
           bool PropagateNan,
           bool OutputIndex,
+          bool TransformIndexKtoGlobal,
           bool HaveIndexInputIfOutputIndex,
           index_t BlockSize,
           index_t MThreadSliceSize,
@@ -260,6 +261,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
         const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
                                                      OutputIndex,
+                                                     TransformIndexKtoGlobal,
                                                      HaveIndexInput,
                                                      InDataType,
                                                      OutDataType,
...
@@ -15,6 +15,7 @@ namespace ck {
 template <typename GridwiseReduction,
           bool OutputIndex,
+          bool TransformIndexKtoGlobal,
           bool HaveIndexInput,
           typename InDataType,
           typename OutDataType,
@@ -48,16 +49,17 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
     }
     else
     {
-        GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
-                                                                 out_grid_desc_m,
-                                                                 in_elementwise_op,
-                                                                 acc_elementwise_op,
-                                                                 alpha,
-                                                                 p_in_value_global,
-                                                                 p_in_index_global,
-                                                                 beta,
-                                                                 p_out_value_global,
-                                                                 p_out_index_global);
+        GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
+            in_grid_desc_m_k,
+            out_grid_desc_m,
+            in_elementwise_op,
+            acc_elementwise_op,
+            alpha,
+            p_in_value_global,
+            p_in_index_global,
+            beta,
+            p_out_value_global,
+            p_out_index_global);
    };
};
@@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise
             reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
     };

-    template <bool HaveIndexInput>
+    template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
     __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                         const OutGridDesc_M& out_grid_desc_m,
                                         const InElementwiseOperation& in_elementwise_op,
@@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise
             indexStart += KThreadSliceSize;
             reducedLength += KThreadSliceSize;
         } while(reducedLength < toReduceLength);
+
+        if constexpr(TransformIndexKtoGlobal)
+        {
+            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+                const auto coord = make_tensor_coordinate(
+                    in_grid_desc_m_k,
+                    make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
+                                     accu_index_buf(I)));
+                accu_index_buf(I) = coord.GetOffset();
+            });
+        }
     };

     // for indiced operation, acc_elementwise_op shoud do nothing
...
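
The block added to RunWithIndex runs once the reduction loop has found, for each of the thread's MThreadSliceSize rows, the winning position along the K (window) axis. It builds a tensor coordinate (m, k) on the input M-by-K descriptor and stores that coordinate's memory offset in place of the K-local index. A simplified model of the remapping, assuming a plain strided (M, K) view; the real in_grid_desc_m_k is a transformed, im2col-style descriptor, so make_tensor_coordinate handles mappings that are generally not this linear:

// Simplified model (not the CK descriptor machinery): remap each accumulated
// K-local argmax index to a global offset through an assumed strided (M, K) view.
#include <cstddef>
#include <vector>

struct StridedMKView
{
    std::size_t stride_m; // offset step when m increases by one
    std::size_t stride_k; // offset step when k increases by one
    std::size_t Offset(std::size_t m, std::size_t k) const { return m * stride_m + k * stride_k; }
};

void remap_indices(const StridedMKView& view, std::size_t m_begin, std::vector<std::size_t>& accu_index)
{
    for(std::size_t i = 0; i < accu_index.size(); ++i)
    {
        // accu_index[i] holds the winning k for row (m_begin + i);
        // replace it with that element's offset in the underlying buffer.
        accu_index[i] = view.Offset(m_begin + i, accu_index[i]);
    }
}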
@@ -90,6 +90,7 @@ void add_device_reduce_instance_threadwise(
                                               AccElementwiseOp,
                                               PropagateNan,
                                               OutputIndex,
+                                              false,
                                               false, // HaveIndexInputIfOutputIndex
                                               cfg1::BlockSize_,
                                               cfg2::MThreadSliceSize_,
...
@@ -411,6 +411,12 @@ struct Tensor
         }
     }

+    template <typename... Is>
+    std::size_t GetOffsetFromMultiIndex(Is... is) const
+    {
+        return mDesc.GetOffsetFromMultiIndex(is...);
+    }
+
     template <typename... Is>
     T& operator()(Is... is)
     {
...
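
The host Tensor simply forwards GetOffsetFromMultiIndex to its descriptor, so the verification routines above can compute the same flattened offset that the device kernel now returns. A self-contained sketch of the expected contract, assuming a hypothetical packed NCHW descriptor (the real descriptor may carry arbitrary strides):

// Self-contained sketch, hypothetical type: what GetOffsetFromMultiIndex is
// expected to return for a packed NCHW tensor of lengths {N, C, H, W}.
#include <array>
#include <cassert>
#include <cstddef>

struct PackedNchwDesc
{
    std::array<std::size_t, 4> lengths; // {N, C, H, W}

    std::size_t GetOffsetFromMultiIndex(std::size_t n, std::size_t c, std::size_t h, std::size_t w) const
    {
        return ((n * lengths[1] + c) * lengths[2] + h) * lengths[3] + w;
    }
};

int main()
{
    PackedNchwDesc desc{{1, 3, 4, 5}};
    // Element (0, 2, 1, 3) sits 48 values into the packed buffer:
    // ((0 * 3 + 2) * 4 + 1) * 5 + 3 == 48.
    assert(desc.GetOffsetFromMultiIndex(0, 2, 1, 3) == 48);
    return 0;
}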