Commit 1681e49a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 5b4bb22c
...@@ -834,7 +834,7 @@ struct dot ...@@ -834,7 +834,7 @@ struct dot
MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B"); MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B");
} }
if (inputs.size() == 3) if(inputs.size() == 3)
{ {
check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type(); check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
const shape& c = inputs.at(2); const shape& c = inputs.at(2);
...@@ -849,19 +849,21 @@ struct dot ...@@ -849,19 +849,21 @@ struct dot
if(a.lens()[dim_1] != b.lens()[dim_0]) if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("DOT : inner dimensions do not match: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT : inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}"); "} x {" + to_string_range(b.lens()) + "}");
if (inputs.size() == 3) if(inputs.size() == 3)
{ {
const shape& c = inputs.at(2); const shape& c = inputs.at(2);
if (a.lens()[dim_0] != c.lens()[dim_0]) if(a.lens()[dim_0] != c.lens()[dim_0])
{ {
MIGRAPHX_THROW("DOT : matrix size does not match: A: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT : matrix size does not match: A: {" +
"}, C: {" + to_string_range(c.lens()) + "}"); to_string_range(a.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
} }
if (b.lens()[dim_1] != c.lens()[dim_1]) if(b.lens()[dim_1] != c.lens()[dim_1])
{ {
MIGRAPHX_THROW("DOT : matrix size does not match: B: {" + to_string_range(b.lens()) + MIGRAPHX_THROW("DOT : matrix size does not match: B: {" +
"}, C: {" + to_string_range(c.lens()) + "}"); to_string_range(b.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
} }
} }
......
...@@ -55,12 +55,12 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -55,12 +55,12 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(amat, [&](const auto& a) { visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) { visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat); auto c = make_mat(cmat);
if (beta != 0.0) if(beta != 0.0)
{ {
c = beta * c; c = beta * c;
} }
if (alpha != 0.0) if(alpha != 0.0)
{ {
c = c + alpha * a * b; c = c + alpha * a * b;
} }
......
...@@ -375,19 +375,16 @@ struct cpu_gemm ...@@ -375,19 +375,16 @@ struct cpu_gemm
{ {
argument result{output_shape}; argument result{output_shape};
// if there is a C input // if there is a C input
if (args.size() == 3) if(args.size() == 3)
{ {
result.visit([&](auto output) { result.visit([&](auto output) {
args[2].visit([&](auto input) { args[2].visit(
std::copy(input.begin(), input.end(), output.begin()); [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
}); });
} }
else else
{ {
result.visit([&](auto output) { result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
std::fill(output.begin(), output.end(), 0);
});
} }
migemm(result, args[0], args[1], op.alpha, op.beta); migemm(result, args[0], args[1], op.alpha, op.beta);
......
...@@ -114,8 +114,11 @@ argument miopen_gemm::compute(context& ctx, ...@@ -114,8 +114,11 @@ argument miopen_gemm::compute(context& ctx,
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if (is_3inputs) if(is_3inputs)
hipMemcpy(to_pointer(args[3]), to_pointer(args[2]), output_shape.bytes(), hipMemcpyDeviceToDevice); hipMemcpy(to_pointer(args[3]),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice);
else else
hipMemset(to_pointer(args[2]), 0, output_shape.bytes()); hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
}); });
......
...@@ -1158,7 +1158,6 @@ TEST_CASE(gemm_mutli_3args) ...@@ -1158,7 +1158,6 @@ TEST_CASE(gemm_mutli_3args)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(maxpool_test) TEST_CASE(maxpool_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment