Commit 665b73ff authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add more helper methods in 'DeviceElementwise'

parent 8a1ccdd4
...@@ -222,14 +222,9 @@ struct DeviceElementwise ...@@ -222,14 +222,9 @@ struct DeviceElementwise
} }
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override static bool IsSupportedArgument(const Argument& arg)
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); if(arg.lengths_.back() % MPerThread != 0)
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false; return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths, auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
...@@ -247,19 +242,34 @@ struct DeviceElementwise ...@@ -247,19 +242,34 @@ struct DeviceElementwise
bool valid = true; bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) { static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
return valid; return valid;
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
return Argument{lengths, inStridesArray, outStridesArray, in_dev_buffers, out_dev_buffers, elementwise_op};
}
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths, MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray, const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment