Commit fbef5744 authored by Khalique's avatar Khalique
Browse files

add workaround for scalar

parent 71276f4d
...@@ -137,36 +137,21 @@ struct onnx_parser ...@@ -137,36 +137,21 @@ struct onnx_parser
const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens(); const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens();
const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens(); const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens();
bool swapped = false;
// Make sure s0 is the smaller size // Make sure s0 is the smaller size
if(s0->size() > s1->size()) if(s0->size() > s1->size())
{
std::swap(s0, s1); std::swap(s0, s1);
swapped = true;
}
std::vector<std::size_t> output_lens(s1->size()); std::vector<std::size_t> output_lens(s1->size());
// if (s0->size() == 0) auto offset = s1->size() - s0->size();
// { std::transform(s0->begin(),
// shape s = swapped ? args[0]->get_shape() : args[1]->get_shape(); s0->end(),
// auto l0 = prog.add_instruction(migraphx::op::scalar{s}, 1.0f); s1->begin() + offset,
// return prog.add_instruction(x, l0, args[1]); output_lens.begin() + offset,
// } [](auto a, auto b) { return std::max(a, b); });
// else
// { auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
// Copy the larger vector to output_lens auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
auto offset = s1->size() - s0->size(); return prog.add_instruction(x, l0, l1);
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
return prog.add_instruction(x, l0, l1);
// }
} }
else else
{ {
...@@ -602,6 +587,11 @@ struct onnx_parser ...@@ -602,6 +587,11 @@ struct onnx_parser
static literal parse_tensor(const onnx::TensorProto& t) static literal parse_tensor(const onnx::TensorProto& t)
{ {
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end()); std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.size() == 0)
{
dims = {1};
}
if(t.has_raw_data()) if(t.has_raw_data())
{ {
const std::string& s = t.raw_data(); const std::string& s = t.raw_data();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment