"src/vscode:/vscode.git/clone" did not exist on "9a63376e550ee25dbe7473657942516e49e45850"
Commit dc945bf8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 1ff459b4
......@@ -584,14 +584,14 @@ struct squeeze
}
// squeezing a single element generates a scalar
if (new_lens.empty())
if(new_lens.empty())
{
return {type};
}
else
{
return shape{type, new_lens};
}
return shape{type, new_lens};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
{
......
......@@ -154,7 +154,8 @@ struct onnx_parser
});
}
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> s1)
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
......@@ -168,15 +169,18 @@ struct onnx_parser
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if (s0.size() > s1.size())
if(s0.size() > s1.size())
{
s0.swap(s1);
}
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
std::transform(s0.begin(),
s0.end(),
s1.begin() + offset,
out_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
return out_lens;
}
......@@ -187,11 +191,11 @@ struct onnx_parser
if(arg0->get_shape().lens() != arg1->get_shape().lens())
{
// Get lengths for both arguments
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
return prog.add_instruction(x, l0, l1);
}
else
......@@ -499,11 +503,11 @@ struct onnx_parser
{
if(beta != 0.f && args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if (!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
}
......@@ -517,14 +521,14 @@ struct onnx_parser
instruction_ref
parse_matmul(const std::string&, attribute_map, std::vector<instruction_ref> args)
{
auto l0 = args[0];
auto l1 = args[1];
auto l0 = args[0];
auto l1 = args[1];
auto l0_lens = l0->get_shape().lens();
auto l1_lens = l1->get_shape().lens();
// args[0] is a vector, prepend 1 to the shape
bool is_a_prepended = false;
if (l0_lens.size() == 1)
if(l0_lens.size() == 1)
{
is_a_prepended = true;
l0_lens.insert(l0_lens.begin(), 1);
......@@ -532,7 +536,7 @@ struct onnx_parser
}
bool is_b_appended = false;
if (l1_lens.size() == 1)
if(l1_lens.size() == 1)
{
is_b_appended = true;
l1_lens.push_back(1);
......@@ -541,7 +545,7 @@ struct onnx_parser
instruction_ref bl0 = l0;
instruction_ref bl1 = l1;
if (!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
{
auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
......@@ -550,28 +554,28 @@ struct onnx_parser
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if (l0_lens != l0_broadcasted_lens)
if(l0_lens != l0_broadcasted_lens)
{
bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
}
if (l1_lens != l1_broadcasted_lens)
if(l1_lens != l1_broadcasted_lens)
{
bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
}
}
auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if (is_a_prepended)
if(is_a_prepended)
{
dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
--num_axis;
}
if (is_b_appended)
if(is_b_appended)
{
dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
}
return dot_res;
}
......
......@@ -566,7 +566,7 @@ TEST_CASE(gemm_test)
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto alpha = 2.f;
auto beta = 2.0f;
auto beta = 2.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1);
auto prog = migraphx::parse_onnx("gemm_test.onnx");
......@@ -576,12 +576,12 @@ TEST_CASE(gemm_test)
TEST_CASE(gemm_ex)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f;
auto beta = 0.8f;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f;
auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2);
auto prog = migraphx::parse_onnx("gemm_test_ex.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment