Commit 1ff459b4 authored by Shucai Xiao

Change the ONNX parser for MatMul

parent f140f4c6
@@ -582,8 +582,17 @@ struct squeeze
                 }
             }
         }
-        return shape{type, new_lens};
+        // squeezing a single element generates a scalar
+        if(new_lens.empty())
+        {
+            return {type};
+        }
+        else
+        {
+            return shape{type, new_lens};
+        }
     }
     argument compute(shape output_shape, std::vector<argument> args) const
     {
         return {std::move(output_shape), std::move(args.front().data)};
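Note: with this change, squeezing away every axis of a single-element tensor yields a rank-0 scalar shape. A minimal standalone sketch of the shape rule (plain C++, MIGraphX types omitted):

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<std::size_t> lens{1, 1}; // a 1x1 tensor
    std::vector<std::size_t> new_lens;
    for(auto d : lens)
        if(d != 1)
            new_lens.push_back(d); // keep only non-singleton axes
    // empty new_lens now means "scalar" rather than an empty tensor shape
    std::cout << (new_lens.empty() ? "scalar" : "tensor") << "\n"; // prints "scalar"
}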
@@ -831,18 +840,17 @@ struct dot
     std::string name() const { return "dot"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2).same_type();
+        check_shapes{inputs, *this}.same_type();
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t         = a.type();
-        // according to the specification of the numpy.matmul()
-        // inputs with the shape dims more than 2 are acceptable
-        // as long as dim values are the same in the two inputs
+        // only handle the case that the batch size of a and b are the same
         if(!std::equal(
               a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
         {
-            MIGRAPHX_THROW("DOT: dim values mismatch");
+            MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
+                           "} x {" + to_string_range(b.lens()) + "}");
         }
         std::size_t dim_0 = a.lens().size() - 2;
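The reverse-iterator std::equal above is how compute_shape verifies that every leading (batch) dimension of A and B matches while skipping the two trailing matrix dimensions. A self-contained sketch of the same check:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<std::size_t> a{3, 2, 4, 5}; // batches {3, 2}, matrix 4x5
    std::vector<std::size_t> b{3, 2, 5, 7}; // batches {3, 2}, matrix 5x7
    // rbegin() + 2 skips the matrix dims; the remaining dims must match exactly
    bool same_batch = std::equal(a.rbegin() + 2, a.rend(), b.rbegin() + 2, b.rend());
    std::cout << std::boolalpha << same_batch << "\n"; // prints "true"
}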
@@ -36,7 +36,6 @@ struct onnx_parser
     onnx_parser()
     {
-        add_generic_op("MatMul", op::dot{});
         add_generic_op("Relu", op::relu{});
         add_generic_op("Sigmoid", op::sigmoid{});
         add_generic_op("Abs", op::abs{});
@@ -77,6 +76,7 @@ struct onnx_parser
         add_mem_op("Reshape", &onnx_parser::parse_reshape);
         add_mem_op("Flatten", &onnx_parser::parse_flatten);
         add_mem_op("Gemm", &onnx_parser::parse_gemm);
+        add_mem_op("MatMul", &onnx_parser::parse_matmul);
         add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
         add_mem_op("Softmax", &onnx_parser::parse_softmax);
         add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax);
@@ -154,10 +154,7 @@ struct onnx_parser
         });
     }
-    template <class T>
-    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
-    {
-        if(arg0->get_shape().lens() != arg1->get_shape().lens())
+    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> s1)
     {
         // Example:
         // s0 = (3,2,4,5) and s1 = (2,1,1)
@@ -171,25 +168,30 @@ struct onnx_parser
         // In this case we need to broadcast the (:,:,1:,:) axis
         // of s0 plus the 1st dimension of s1 giving
         // output_lens = (3,2,7,5)
-        //
-        // Get lengths for both arguments
-        const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
-        const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
-        // Make sure s0 is the smaller size
-        if(s0->size() > s1->size())
-            std::swap(s0, s1);
-        std::vector<std::size_t> output_lens(*s1);
-        auto offset = s1->size() - s0->size();
-        std::transform(s0->begin(),
-                       s0->end(),
-                       s1->begin() + offset,
-                       output_lens.begin() + offset,
+        if(s0.size() > s1.size())
+        {
+            s0.swap(s1);
+        }
+        std::vector<std::size_t> out_lens(s1);
+        auto offset = s1.size() - s0.size();
+        std::transform(s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset,
                        [](auto a, auto b) { return std::max(a, b); });
-        auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
-        auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
+        return out_lens;
+    }
+    template <class T>
+    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
+    {
+        if(arg0->get_shape().lens() != arg1->get_shape().lens())
+        {
+            // Get lengths for both arguments
+            auto s0       = arg0->get_shape().lens();
+            auto s1       = arg1->get_shape().lens();
+            auto out_lens = compute_broadcasted_lens(s0, s1);
+            auto l0       = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
+            auto l1       = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
             return prog.add_instruction(x, l0, l1);
         }
         else
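The extracted compute_broadcasted_lens applies numpy-style broadcasting: left-align the shorter shape with implicit leading 1s, then take the per-axis maximum. A self-contained version reproducing the second example from the comment; note that, like the helper above, it does not reject mismatched axes where neither length is 1:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> s0, std::vector<std::size_t> s1)
{
    if(s0.size() > s1.size())
        s0.swap(s1); // make s0 the shorter shape
    std::vector<std::size_t> out(s1);
    auto offset = s1.size() - s0.size(); // s0 gets implicit leading 1s
    std::transform(s0.begin(), s0.end(), s1.begin() + offset, out.begin() + offset,
                   [](auto a, auto b) { return std::max(a, b); });
    return out;
}

int main()
{
    for(auto d : broadcast_lens({3, 2, 1, 5}, {2, 7, 5}))
        std::cout << d << " "; // prints "3 2 7 5"
    std::cout << "\n";
}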
@@ -495,25 +497,84 @@ struct onnx_parser
         auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
         if(args.size() == 3)
         {
-            if(beta != 0.f)
+            if(beta != 0.f && args[2]->get_shape().elements() > 0)
             {
-                auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
-                auto l4 = args[2];
-                if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
-                    return l3;
-                if(beta != 1.f)
+                auto out_lens   = l1->get_shape().lens();
+                out_lens.back() = l2->get_shape().lens().back();
+                auto l3         = args[2];
+                auto l3_lens    = l3->get_shape().lens();
+                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                 {
-                    auto beta_val = prog.add_literal(beta);
-                    auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
-                    l4 = prog.add_instruction(op::mul{}, args[2], l5);
+                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                 }
-                return add_broadcastable_binary_op(l3, l4, op::add{});
+                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
             }
         }
         return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
     }
+    instruction_ref
+    parse_matmul(const std::string&, attribute_map, std::vector<instruction_ref> args)
+    {
+        auto l0      = args[0];
+        auto l1      = args[1];
+        auto l0_lens = l0->get_shape().lens();
+        auto l1_lens = l1->get_shape().lens();
+        // args[0] is a vector, prepend 1 to the shape
+        bool is_a_prepended = false;
+        if(l0_lens.size() == 1)
+        {
+            is_a_prepended = true;
+            l0_lens.insert(l0_lens.begin(), 1);
+            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
+        }
+        // args[1] is a vector, append 1 to the shape
+        bool is_b_appended = false;
+        if(l1_lens.size() == 1)
+        {
+            is_b_appended = true;
+            l1_lens.push_back(1);
+            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
+        }
+        instruction_ref bl0 = l0;
+        instruction_ref bl1 = l1;
+        // broadcast the batch dimensions when the two inputs disagree
+        if(!std::equal(
+               l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
+        {
+            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
+            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
+            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
+            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
+            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
+            // start from the broadcast batch dims, then append each input's matrix dims
+            l0_broadcasted_lens = output_lens;
+            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
+            l1_broadcasted_lens = output_lens;
+            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
+            if(l0_lens != l0_broadcasted_lens)
+            {
+                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
+            }
+            if(l1_lens != l1_broadcasted_lens)
+            {
+                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
+            }
+        }
+        auto dot_res     = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
+        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
+        // squeeze away the axes added for 1-D inputs
+        if(is_a_prepended)
+        {
+            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
+            --num_axis;
+        }
+        if(is_b_appended)
+        {
+            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
+        }
+        return dot_res;
+    }
     instruction_ref
     parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
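parse_matmul follows numpy.matmul semantics as required by the ONNX MatMul spec: a 1-D A is promoted to a row vector (unsqueeze axis 0), a 1-D B to a column vector (unsqueeze the trailing axis), batch dimensions are broadcast, and the padded axes are squeezed back out of the result. A shape-level sketch of the 1-D round trip:

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<std::size_t> a{4};    // 1-D left operand
    std::vector<std::size_t> b{4, 5}; // 2-D right operand
    bool a_prepended = false;
    if(a.size() == 1)
    {
        a.insert(a.begin(), 1); // promote to a 1x4 row vector
        a_prepended = true;
    }
    std::vector<std::size_t> out{a.front(), b.back()}; // dot result shape: 1x5
    if(a_prepended)
        out.erase(out.end() - 2); // squeeze the padded axis away
    for(auto d : out)
        std::cout << d << " "; // prints "5"
    std::cout << "\n";
}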
@@ -566,7 +566,8 @@ TEST_CASE(gemm_test)
     auto t0    = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
     auto t1    = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
     auto alpha = 2.f;
-    p.add_instruction(migraphx::op::dot{alpha}, t0, t1);
+    auto beta  = 2.0f;
+    p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1);
     auto prog = migraphx::parse_onnx("gemm_test.onnx");
     EXPECT(p == prog);
@@ -580,14 +581,8 @@ TEST_CASE(gemm_ex)
     auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
     auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
     auto alpha = 0.5f;
-    auto res_ab = p.add_instruction(migraphx::op::dot{alpha}, t0, l1);
     auto beta  = 0.8f;
-    auto l_beta = p.add_literal(beta);
-    auto brcst_beta = p.add_instruction(migraphx::op::scalar{l2->get_shape()}, l_beta);
-    auto res_c = p.add_instruction(migraphx::op::mul{}, l2, brcst_beta);
-    p.add_instruction(migraphx::op::add{}, res_ab, res_c);
+    p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2);
     auto prog = migraphx::parse_onnx("gemm_test_ex.onnx");
     EXPECT(p == prog);
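Both tests now expect a single dot instruction carrying alpha and beta, i.e. Gemm's defining computation Y = alpha * (A x B) + beta * C folded into one operator. A tiny scalar-level reference computation of that formula (2x2 operands with illustrative values, not taken from the tests):

#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    std::array<std::array<float, 2>, 2> a{{{1, 2}, {3, 4}}};
    std::array<std::array<float, 2>, 2> b{{{5, 6}, {7, 8}}};
    std::array<std::array<float, 2>, 2> c{{{1, 1}, {1, 1}}};
    float alpha = 0.5f;
    float beta  = 0.8f;
    for(std::size_t i = 0; i < 2; ++i)
    {
        for(std::size_t j = 0; j < 2; ++j)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < 2; ++k)
                acc += a[i][k] * b[k][j]; // (A x B)[i][j]
            std::cout << alpha * acc + beta * c[i][j] << " ";
        }
        std::cout << "\n";
    }
}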