Commit 1681e49a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 5b4bb22c
...@@ -834,7 +834,7 @@ struct dot ...@@ -834,7 +834,7 @@ struct dot
MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B"); MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B");
} }
if (inputs.size() == 3) if(inputs.size() == 3)
{ {
check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type(); check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
const shape& c = inputs.at(2); const shape& c = inputs.at(2);
...@@ -849,19 +849,21 @@ struct dot ...@@ -849,19 +849,21 @@ struct dot
if(a.lens()[dim_1] != b.lens()[dim_0]) if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("DOT : inner dimensions do not match: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT : inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}"); "} x {" + to_string_range(b.lens()) + "}");
if (inputs.size() == 3) if(inputs.size() == 3)
{ {
const shape& c = inputs.at(2); const shape& c = inputs.at(2);
if (a.lens()[dim_0] != c.lens()[dim_0]) if(a.lens()[dim_0] != c.lens()[dim_0])
{ {
MIGRAPHX_THROW("DOT : matrix size does not match: A: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT : matrix size does not match: A: {" +
"}, C: {" + to_string_range(c.lens()) + "}"); to_string_range(a.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
} }
if (b.lens()[dim_1] != c.lens()[dim_1]) if(b.lens()[dim_1] != c.lens()[dim_1])
{ {
MIGRAPHX_THROW("DOT : matrix size does not match: B: {" + to_string_range(b.lens()) + MIGRAPHX_THROW("DOT : matrix size does not match: B: {" +
"}, C: {" + to_string_range(c.lens()) + "}"); to_string_range(b.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
} }
} }
......
...@@ -55,12 +55,12 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -55,12 +55,12 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(amat, [&](const auto& a) { visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) { visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat); auto c = make_mat(cmat);
if (beta != 0.0) if(beta != 0.0)
{ {
c = beta * c; c = beta * c;
} }
if (alpha != 0.0) if(alpha != 0.0)
{ {
c = c + alpha * a * b; c = c + alpha * a * b;
} }
......
...@@ -375,21 +375,18 @@ struct cpu_gemm ...@@ -375,21 +375,18 @@ struct cpu_gemm
{ {
argument result{output_shape}; argument result{output_shape};
// if there is a C input // if there is a C input
if (args.size() == 3) if(args.size() == 3)
{ {
result.visit([&](auto output) { result.visit([&](auto output) {
args[2].visit([&](auto input) { args[2].visit(
std::copy(input.begin(), input.end(), output.begin()); [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
}); });
} }
else else
{ {
result.visit([&](auto output) { result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
std::fill(output.begin(), output.end(), 0);
});
} }
migemm(result, args[0], args[1], op.alpha, op.beta); migemm(result, args[0], args[1], op.alpha, op.beta);
return result; return result;
} }
......
...@@ -114,8 +114,11 @@ argument miopen_gemm::compute(context& ctx, ...@@ -114,8 +114,11 @@ argument miopen_gemm::compute(context& ctx,
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if (is_3inputs) if(is_3inputs)
hipMemcpy(to_pointer(args[3]), to_pointer(args[2]), output_shape.bytes(), hipMemcpyDeviceToDevice); hipMemcpy(to_pointer(args[3]),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice);
else else
hipMemset(to_pointer(args[2]), 0, output_shape.bytes()); hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
}); });
......
...@@ -1138,11 +1138,11 @@ TEST_CASE(gemm_mutli_3args) ...@@ -1138,11 +1138,11 @@ TEST_CASE(gemm_mutli_3args)
0.49759611, 0.10021662, 0.00592602, 0.90862000}; 0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}}; migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1}); auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2}); auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3}); auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -1158,7 +1158,6 @@ TEST_CASE(gemm_mutli_3args) ...@@ -1158,7 +1158,6 @@ TEST_CASE(gemm_mutli_3args)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(maxpool_test) TEST_CASE(maxpool_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -891,11 +891,11 @@ struct gemm_mutli_3args ...@@ -891,11 +891,11 @@ struct gemm_mutli_3args
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}}; migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape); auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment