Commit 1681e49a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 5b4bb22c
......@@ -834,7 +834,7 @@ struct dot
MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B");
}
if (inputs.size() == 3)
if(inputs.size() == 3)
{
check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
const shape& c = inputs.at(2);
......@@ -849,19 +849,21 @@ struct dot
if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("DOT : inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
if (inputs.size() == 3)
if(inputs.size() == 3)
{
const shape& c = inputs.at(2);
if (a.lens()[dim_0] != c.lens()[dim_0])
if(a.lens()[dim_0] != c.lens()[dim_0])
{
MIGRAPHX_THROW("DOT : matrix size does not match: A: {" + to_string_range(a.lens()) +
"}, C: {" + to_string_range(c.lens()) + "}");
MIGRAPHX_THROW("DOT : matrix size does not match: A: {" +
to_string_range(a.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
}
if (b.lens()[dim_1] != c.lens()[dim_1])
if(b.lens()[dim_1] != c.lens()[dim_1])
{
MIGRAPHX_THROW("DOT : matrix size does not match: B: {" + to_string_range(b.lens()) +
"}, C: {" + to_string_range(c.lens()) + "}");
MIGRAPHX_THROW("DOT : matrix size does not match: B: {" +
to_string_range(b.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
}
}
......
......@@ -55,12 +55,12 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat);
if (beta != 0.0)
if(beta != 0.0)
{
c = beta * c;
}
if (alpha != 0.0)
if(alpha != 0.0)
{
c = c + alpha * a * b;
}
......
......@@ -375,21 +375,18 @@ struct cpu_gemm
{
argument result{output_shape};
// if there is a C input
if (args.size() == 3)
if(args.size() == 3)
{
result.visit([&](auto output) {
args[2].visit([&](auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
args[2].visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
}
else
{
result.visit([&](auto output) {
std::fill(output.begin(), output.end(), 0);
});
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
}
migemm(result, args[0], args[1], op.alpha, op.beta);
return result;
}
......
......@@ -114,8 +114,11 @@ argument miopen_gemm::compute(context& ctx,
bool is_3inputs = (args.size() == 4);
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if (is_3inputs)
hipMemcpy(to_pointer(args[3]), to_pointer(args[2]), output_shape.bytes(), hipMemcpyDeviceToDevice);
if(is_3inputs)
hipMemcpy(to_pointer(args[3]),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice);
else
hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
});
......
......@@ -1138,11 +1138,11 @@ TEST_CASE(gemm_mutli_3args)
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......@@ -1158,7 +1158,6 @@ TEST_CASE(gemm_mutli_3args)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(maxpool_test)
{
migraphx::program p;
......
......@@ -891,11 +891,11 @@ struct gemm_mutli_3args
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment