Commit 0f840256 authored by rocking's avatar rocking
Browse files

Extract pad

parent ecdfe960
...@@ -21,16 +21,9 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -21,16 +21,9 @@ struct DeviceBinaryElementwise : public BaseOperator
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static auto MakeDescriptor_M0_1d(const std::vector<int>& shape, template <typename Desc_M0>
const std::vector<int>& stride, static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t threadPerBlock)
index_t gridSize,
index_t threadPerBlock)
{ {
// 1d desc - [m]
const auto desc_m0 =
make_naive_tensor_descriptor(make_tuple(shape[0]), make_tuple(stride[0]));
// pad
const auto m0 = desc_m0.GetLength(I0); const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector; const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0; const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
...@@ -42,6 +35,17 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -42,6 +35,17 @@ struct DeviceBinaryElementwise : public BaseOperator
return desc_m0_pad; return desc_m0_pad;
} }
static auto MakeDescriptor_M0_1d(const std::vector<int>& shape,
const std::vector<int>& stride,
index_t gridSize,
index_t threadPerBlock)
{
const auto desc_m0 =
make_naive_tensor_descriptor(make_tuple(shape[0]), make_tuple(stride[0]));
return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock);
}
static auto MakeDescriptor_M0_2d(const std::vector<int>& shape, static auto MakeDescriptor_M0_2d(const std::vector<int>& shape,
const std::vector<int>& stride, const std::vector<int>& stride,
index_t gridSize, index_t gridSize,
...@@ -61,16 +65,7 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -61,16 +65,7 @@ struct DeviceBinaryElementwise : public BaseOperator
make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
// pad return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock);
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
} }
static auto MakeDescriptor_M0(const std::vector<int>& shape, static auto MakeDescriptor_M0(const std::vector<int>& shape,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment