Commit 980ed33a authored by rocking

Refine deviceop

parent bb314592
@@ -22,58 +22,58 @@ template <typename XDataType,
           typename OutDataType,
           typename ComputeDataType,
           typename OutElementwiseFunctor,
-          index_t Dim,
-          index_t M0PerThread,
-          index_t XScalarPerVector = M0PerThread,
-          index_t MeanScalarPerVector = M0PerThread,
-          index_t MeanSquareScalarPerVector = M0PerThread,
-          index_t GammaScalarPerVector = M0PerThread,
-          index_t BetaScalarPerVector = M0PerThread>
+          index_t NDim,
+          index_t MPerThread,
+          index_t XScalarPerVector,
+          index_t MeanScalarPerVector,
+          index_t MeanSquareScalarPerVector,
+          index_t GammaScalarPerVector,
+          index_t BetaScalarPerVector>
 struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
 {
     static constexpr auto I0 = Number<0>{};
 
-    template <typename Desc_M0>
-    static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
+    template <typename Desc_M>
+    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
     {
-        const auto m0 = desc_m0.GetLength(I0);
-        const index_t loop_step = gridSize * blockSize * M0PerThread;
-        const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
-        const auto desc_m0_pad =
-            transform_tensor_descriptor(desc_m0,
-                                        make_tuple(make_right_pad_transform(m0, pad)),
+        const auto m = desc_m.GetLength(I0);
+        const index_t loop_step = gridSize * blockSize * MPerThread;
+        const auto pad = math::integer_least_multiple(m, loop_step) - m;
+        const auto desc_m_pad =
+            transform_tensor_descriptor(desc_m,
+                                        make_tuple(make_right_pad_transform(m, pad)),
                                         make_tuple(Sequence<0>{}),
                                         make_tuple(Sequence<0>{}));
-        return desc_m0_pad;
+        return desc_m_pad;
     }
 
-    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
-                                  const std::vector<index_t>& stride,
-                                  index_t gridSize,
-                                  index_t blockSize)
+    static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
+                                 const std::vector<index_t>& stride,
+                                 index_t gridSize,
+                                 index_t blockSize)
     {
-        auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
-        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
+        auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
+        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NDim>{});
 
         // nd desc - [s0, s1, s2, ...]
         const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
 
         // merge nd to 1d desc - [s0 * s1 * ...]
-        if constexpr(Dim > 1)
+        if constexpr(NDim > 1)
         {
-            const auto desc_m0 = transform_tensor_descriptor(
+            const auto desc_m = transform_tensor_descriptor(
                 desc,
                 make_tuple(make_merge_transform(tupleOfShape)),
-                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
+                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
                 make_tuple(Sequence<0>{}));
 
-            return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
+            return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
         }
         else
-            return PadDescriptor_M0_1d(desc, gridSize, blockSize);
+            return PadDescriptor_M_1d(desc, gridSize, blockSize);
     }
 
-    using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
+    using GridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
 
     struct Argument : public BaseArgument
     {
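For reference, the right-pad in PadDescriptor_M_1d rounds the flattened length m up to a multiple of gridSize * blockSize * MPerThread, so every thread executes whole MPerThread-wide iterations of the grid-stride loop. A minimal standalone sketch of that arithmetic; integer_least_multiple is re-implemented here only for illustration, MPerThread = 8 is a hypothetical template argument, and gridSize = 120 / blockSize = 256 mirror the constructor defaults below:

#include <cstdint>
#include <iostream>

using index_t = int32_t;

// Stand-in for math::integer_least_multiple: smallest multiple of
// `align` that is >= x (assumes x >= 0, align > 0).
index_t integer_least_multiple(index_t x, index_t align)
{
    return ((x + align - 1) / align) * align;
}

int main()
{
    const index_t gridSize   = 120; // hard-coded in the Argument constructor
    const index_t blockSize  = 256;
    const index_t MPerThread = 8;   // hypothetical template argument

    const index_t m         = 1000000; // flattened tensor length
    const index_t loop_step = gridSize * blockSize * MPerThread; // 245760
    const index_t pad       = integer_least_multiple(m, loop_step) - m;

    // 1000000 rounds up to 5 * 245760 = 1228800, so pad == 228800
    std::cout << "loop_step = " << loop_step << ", pad = " << pad << '\n';
}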
@@ -83,7 +83,7 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
                  OutDataType* p_output,
-                 const std::vector<index_t>& shape,
+                 const std::vector<index_t>& lengths,
                  const std::vector<index_t>& stride_x,
                  const std::vector<index_t>& stride_mean,
                  const std::vector<index_t>& stride_mean_square,
@@ -97,7 +97,7 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_output_(p_output),
-              shape_(shape),
+              lengths_(lengths),
               stride_x_(stride_x),
               stride_mean_(stride_mean),
               stride_mean_square_(stride_mean_square),
@@ -107,13 +107,13 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
               blockSize_(256),
               gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
         {
-            x_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_x, gridSize_, blockSize_);
-            mean_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_mean, gridSize_, blockSize_);
-            mean_square_grid_desc_m0_ =
-                MakeDescriptor_M0(shape, stride_mean_square, gridSize_, blockSize_);
-            gamma_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_gamma, gridSize_, blockSize_);
-            beta_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_beta, gridSize_, blockSize_);
-            output_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_output, gridSize_, blockSize_);
+            x_grid_desc_m_ = MakeDescriptor_M(lengths, stride_x, gridSize_, blockSize_);
+            mean_grid_desc_m_ = MakeDescriptor_M(lengths, stride_mean, gridSize_, blockSize_);
+            mean_square_grid_desc_m_ =
+                MakeDescriptor_M(lengths, stride_mean_square, gridSize_, blockSize_);
+            gamma_grid_desc_m_ = MakeDescriptor_M(lengths, stride_gamma, gridSize_, blockSize_);
+            beta_grid_desc_m_ = MakeDescriptor_M(lengths, stride_beta, gridSize_, blockSize_);
+            output_grid_desc_m_ = MakeDescriptor_M(lengths, stride_output, gridSize_, blockSize_);
         }
 
         const XDataType* p_x_;
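The constructor takes a single lengths vector plus one stride vector per tensor, so broadcasting is encoded entirely in strides: a dimension with stride 0 re-reads the same element for every index along that dimension. A small self-contained sketch of that offset convention; the offset helper and the [N=2, C=4] values are hypothetical illustrations, not part of the device op:

#include <cstdint>
#include <iostream>
#include <vector>

using index_t = int32_t;

// Offset of a multi-dimensional index under explicit strides -- the
// same lengths/strides convention the Argument constructor consumes.
index_t offset(const std::vector<index_t>& idx, const std::vector<index_t>& stride)
{
    index_t off = 0;
    for(std::size_t d = 0; d < idx.size(); ++d)
        off += idx[d] * stride[d];
    return off;
}

int main()
{
    // Hypothetical [N=2, C=4] problem: row-major x, gamma broadcast over N.
    std::vector<index_t> stride_x     = {4, 1}; // distinct element per (n, c)
    std::vector<index_t> stride_gamma = {0, 1}; // stride 0: same gamma row for every n

    std::cout << offset({0, 2}, stride_x) << ' '
              << offset({1, 2}, stride_x) << '\n';     // 2 6
    std::cout << offset({0, 2}, stride_gamma) << ' '
              << offset({1, 2}, stride_gamma) << '\n'; // 2 2
}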
@@ -122,13 +122,13 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
         const GammaDataType* p_gamma_;
         const BetaDataType* p_beta_;
         OutDataType* p_output_;
-        std::vector<index_t> shape_;
-        GridDesc_M0 x_grid_desc_m0_;
-        GridDesc_M0 mean_grid_desc_m0_;
-        GridDesc_M0 mean_square_grid_desc_m0_;
-        GridDesc_M0 gamma_grid_desc_m0_;
-        GridDesc_M0 beta_grid_desc_m0_;
-        GridDesc_M0 output_grid_desc_m0_;
+        std::vector<index_t> lengths_;
+        GridDesc_M x_grid_desc_m_;
+        GridDesc_M mean_grid_desc_m_;
+        GridDesc_M mean_square_grid_desc_m_;
+        GridDesc_M gamma_grid_desc_m_;
+        GridDesc_M beta_grid_desc_m_;
+        GridDesc_M output_grid_desc_m_;
         std::vector<index_t> stride_x_;
         std::vector<index_t> stride_mean_;
         std::vector<index_t> stride_mean_square_;
@@ -157,18 +157,6 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
         }
     };
 
-    bool IsScalarPerVectorValid(bool broadcastOnFastest, int scalarPerVector)
-    {
-        bool ret = true;
-
-        if(broadcastOnFastest)
-            ret = scalarPerVector == 1;
-        else
-            ret = M0PerThread % scalarPerVector == 0;
-
-        return ret;
-    }
-
     bool IsSupportedArgument(const BaseArgument* p_arg) override
     {
         const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
@@ -176,26 +164,37 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
         if(pArg == nullptr)
            return false;
 
-        if(pArg->shape_.size() != Dim)
+        if(pArg->lengths_.size() != NDim)
            return false;
 
-        if(pArg->shape_.back() % M0PerThread != 0)
+        if(pArg->lengths_.back() % MPerThread != 0)
            return false;
 
-        if(!IsScalarPerVectorValid(pArg->stride_x_.back() == 0, XScalarPerVector))
+        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
+            bool ret = true;
+
+            if(!isLastDimensionCoalesced)
+                ret = scalarPerVector == 1;
+            else
+                ret = MPerThread % scalarPerVector == 0;
+
+            return ret;
+        };
+
+        if(!IsScalarPerVectorValid(pArg->stride_x_.back() == 1, XScalarPerVector))
            return false;
 
-        if(!IsScalarPerVectorValid(pArg->stride_mean_.back() == 0, MeanScalarPerVector))
+        if(!IsScalarPerVectorValid(pArg->stride_mean_.back() == 1, MeanScalarPerVector))
            return false;
 
-        if(!IsScalarPerVectorValid(pArg->stride_mean_square_.back() == 0,
-                                   MeanSquareScalarPerVector))
+        if(!IsScalarPerVectorValid(pArg->stride_mean_square_.back() == 1,
+                                   MeanSquareScalarPerVector))
            return false;
 
-        if(!IsScalarPerVectorValid(pArg->stride_gamma_.back() == 0, GammaScalarPerVector))
+        if(!IsScalarPerVectorValid(pArg->stride_gamma_.back() == 1, GammaScalarPerVector))
            return false;
 
-        if(!IsScalarPerVectorValid(pArg->stride_beta_.back() == 0, BetaScalarPerVector))
+        if(!IsScalarPerVectorValid(pArg->stride_beta_.back() == 1, BetaScalarPerVector))
            return false;
         };
         };
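The reworked check reads stride.back() == 1 as "last dimension is coalesced": only then may a tensor be loaded scalarPerVector elements at a time, and the vector width must divide MPerThread; any non-contiguous (including broadcast, stride-0) fastest dimension forces scalar access. A standalone restatement of the predicate with a few spot checks; MPerThread = 8 is a hypothetical template argument:

#include <cassert>

constexpr int MPerThread = 8; // hypothetical template argument

// Mirrors the lambda in IsSupportedArgument: vectorized access is only
// legal when the fastest dimension is contiguous, and the vector width
// must evenly divide the per-thread tile.
constexpr bool IsScalarPerVectorValid(bool isLastDimensionCoalesced, int scalarPerVector)
{
    if(!isLastDimensionCoalesced)
        return scalarPerVector == 1;
    return MPerThread % scalarPerVector == 0;
}

int main()
{
    static_assert(IsScalarPerVectorValid(true, 4));   // stride 1, and 4 divides 8
    static_assert(!IsScalarPerVectorValid(true, 3));  // 3 does not divide 8
    static_assert(IsScalarPerVectorValid(false, 1));  // broadcast/strided: scalar only
    static_assert(!IsScalarPerVectorValid(false, 4)); // vector load on non-contiguous dim
}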