Commit 766c600f authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Update parse_splitToSequence

Per the spec. Perform split based on splits. If Split is a scalar use the value
in the split and split equally.

If not scalar, split along desired dimensions. Check dims accordingly so they
sum to the length of the target axes

For keepdims, only split down to new sequences if keepdims=0 otherwise use the
axes length as the output of keepdims=1
parent d260f0e8
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
#include <migraphx/onnx/checks.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,46 +42,76 @@ struct parse_splitToSequence : op_parser<parse_splitToSequence> ...@@ -41,46 +42,76 @@ struct parse_splitToSequence : op_parser<parse_splitToSequence>
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
int64_t axis = 0; int64_t axis = 0;
int64_t keep_dims = 1;
if(contains(info.attributes, "axis")) if(contains(info.attributes, "axis"))
{ {
axis = parser.parse_value(info.attributes.at("axis")).at<int>(); axis = parser.parse_value(info.attributes.at("axis")).at<int>();
} }
if(contains(info.attributes, "keepdims"))
{
keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>();
}
auto lens = args[0]->get_shape().lens(); auto lens = args[0]->get_shape().lens();
int64_t n_rank = lens.size(); int64_t n_rank = lens.size();
int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name); int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name);
std::vector<int64_t> vec_splits; std::vector<int64_t> vec_splits;
if(contains(info.attributes, "split")) if(args.size() == 2)
{
auto s = args[1]->eval();
check_arg_empty(s, "SplitToSequence: dynamic shape is not supported");
const auto split_shape = s.get_shape();
// check all split args > 1
for(const auto split_arg : split_shape.lens())
{
assert(split_arg > 0);
}
if(split_shape.scalar())
{
// Split equally along one axis based on desired split
auto split_output = lens[tuned_axis] / split_shape.lens().at(0);
auto dl = lens[tuned_axis] / info.num_outputs;
vec_splits.resize(split_output, dl);
}
else
{ {
literal s = parser.parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); }); s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) != if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis])) static_cast<int64_t>(lens[tuned_axis]))
{ {
MIGRAPHX_THROW("PARSE_SPLIT_TO_SEQ: sum of split attribute unequal to dim size of axis!"); MIGRAPHX_THROW("PARSE_SPLIT_TO_SEQ: sum of split attribute unequal to dim size of "
"axis! Axis " +
std::to_string(lens[tuned_axis]) + " Split " +
to_string_range(vec_splits));
} }
} }
// no split attribute, use other paremters to determine splits
else else
{ {
int64_t keepdims = 1; if(keep_dims == 0)
if(contains(info.attributes, "keepdims"))
{ {
keepdims = parser.parse_value(info.attributes.at("keepdims")).at<int>();
}
if((lens[tuned_axis] % info.num_outputs) != 0) if((lens[tuned_axis] % info.num_outputs) != 0)
{ {
MIGRAPHX_THROW("PARSE_SPLIT_TO_SEQ: input cannot be equally divided into " + MIGRAPHX_THROW("PARSE_SPLIT_TO_SEQ: input cannot be equally divided into " +
std::to_string(info.num_outputs) + " splits!"); std::to_string(info.num_outputs) + " splits!");
} }
auto dl = lens[tuned_axis] / info.num_outputs; auto dl = lens[tuned_axis] / info.num_outputs;
vec_splits.resize(info.num_outputs, dl); vec_splits.resize(info.num_outputs, dl);
} }
else
{
vec_splits.resize(info.num_outputs, lens[tuned_axis]);
}
}
std::vector<instruction_ref> ret_ins; std::vector<instruction_ref> ret_ins;
int64_t start = 0; int64_t start = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment