Commit 1e5f7133 authored by Paul's avatar Paul
Browse files

Format

parent 84d5e2b6
...@@ -30,17 +30,17 @@ rocblas_datatype get_type(shape::type_t type) ...@@ -30,17 +30,17 @@ rocblas_datatype get_type(shape::type_t type)
void blas_shape(const shape& s) void blas_shape(const shape& s)
{ {
if (s.lens().size() < 2) if(s.lens().size() < 2)
return; return;
if (std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { if(std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { return i == 1; }))
return i == 1;
}))
MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1"); MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
if (s.lens().size() < 3) if(s.lens().size() < 3)
return; return;
shape batch_shape{s.type(), {s.lens().begin(), s.lens().end() - 2}, {s.strides().begin(), s.strides().end() - 2}}; shape batch_shape{s.type(),
{s.lens().begin(), s.lens().end() - 2},
{s.strides().begin(), s.strides().end() - 2}};
auto batch_shapes = reduce_dims({batch_shape}); auto batch_shapes = reduce_dims({batch_shape});
if (batch_shapes.front().lens().size() != 1) if(batch_shapes.front().lens().size() != 1)
MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
} }
...@@ -55,7 +55,7 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs) ...@@ -55,7 +55,7 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
static bool is_transposed(const shape& s) static bool is_transposed(const shape& s)
{ {
if (not s.transposed()) if(not s.transposed())
return false; return false;
return s.strides().back() != 1; return s.strides().back() != 1;
} }
......
...@@ -113,7 +113,7 @@ TEST_CASE(non_packed_empty1) ...@@ -113,7 +113,7 @@ TEST_CASE(non_packed_empty1)
{ {
std::vector<migraphx::shape> ishapes = {make_shape({1, 12}, {589824, 64})}; std::vector<migraphx::shape> ishapes = {make_shape({1, 12}, {589824, 64})};
std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})}; std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
...@@ -121,7 +121,7 @@ TEST_CASE(non_packed_empty2) ...@@ -121,7 +121,7 @@ TEST_CASE(non_packed_empty2)
{ {
std::vector<migraphx::shape> ishapes = {make_shape({12, 1}, {64, 589824})}; std::vector<migraphx::shape> ishapes = {make_shape({12, 1}, {64, 589824})};
std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})}; std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
auto rshapes = migraphx::reduce_dims(ishapes); auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(eshapes == rshapes); EXPECT(eshapes == rshapes);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment