Commit 9c16a90e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent c43eba64
...@@ -844,19 +844,15 @@ struct dot ...@@ -844,19 +844,15 @@ struct dot
auto t = a.type(); auto t = a.type();
// change to support cases like {1, 1, 3, 5} X {1, 1, 5, 6}, // change to support cases like {1, 1, 3, 5} X {1, 1, 5, 6},
// which can be handled by numpy. as long as all previous // which can be handled by numpy. as long as all previous
// dims are 1 except the last two dims, the two matrices // dims are 1 except the last two dims, the two matrices
// are multipliable // are multipliable
if (std::any_of(a.lens().rbegin() + 2, a.lens().rend(), [](auto i) { if(std::any_of(a.lens().rbegin() + 2, a.lens().rend(), [](auto i) { return (i != 1); }))
return (i != 1);
}))
{ {
MIGRAPHX_THROW("DOT: first matrix, dimensions before matrix dims must be 1"); MIGRAPHX_THROW("DOT: first matrix, dimensions before matrix dims must be 1");
} }
if (std::any_of(b.lens().rbegin() + 2, b.lens().rend(), [](auto i) { if(std::any_of(b.lens().rbegin() + 2, b.lens().rend(), [](auto i) { return (i != 1); }))
return (i != 1);
}))
{ {
MIGRAPHX_THROW("DOT: second matrix, dimensions before matrix dims must be 1"); MIGRAPHX_THROW("DOT: second matrix, dimensions before matrix dims must be 1");
} }
...@@ -865,7 +861,7 @@ struct dot ...@@ -865,7 +861,7 @@ struct dot
if(a.lens()[n_dims - 1] != b.lens()[n_dims - 2]) if(a.lens()[n_dims - 1] != b.lens()[n_dims - 2])
MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}"); "} x {" + to_string_range(b.lens()) + "}");
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[n_dims - 1] = b.lens()[n_dims - 1]; out_lens[n_dims - 1] = b.lens()[n_dims - 1];
return {t, out_lens}; return {t, out_lens};
} }
......
...@@ -14,10 +14,10 @@ template <class T> ...@@ -14,10 +14,10 @@ template <class T>
static auto make_mat(tensor_view<T> x) static auto make_mat(tensor_view<T> x)
{ {
const auto& s = x.get_shape(); const auto& s = x.get_shape();
//assert(s.lens().size() == 2); // assert(s.lens().size() == 2);
std::size_t n_dims = s.lens().size(); std::size_t n_dims = s.lens().size();
std::size_t dim_0 = n_dims - 2; std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1; std::size_t dim_1 = n_dims - 1;
if(s.transposed()) if(s.transposed())
return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]}; return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]}; return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
...@@ -68,11 +68,11 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -68,11 +68,11 @@ void migemm_impl(tensor_view<T> cmat,
std::false_type) std::false_type)
{ {
std::size_t n_dims = cmat.get_shape().lens().size(); std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2; std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1; std::size_t dim_1 = n_dims - 1;
auto m = cmat.get_shape().lens()[dim_0]; auto m = cmat.get_shape().lens()[dim_0];
auto n = cmat.get_shape().lens()[dim_1]; auto n = cmat.get_shape().lens()[dim_1];
auto k = amat.get_shape().lens()[dim_1]; auto k = amat.get_shape().lens()[dim_1];
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(m == amat.get_shape().lens()[dim_0]); assert(m == amat.get_shape().lens()[dim_0]);
......
...@@ -76,19 +76,19 @@ argument miopen_gemm::compute(context& ctx, ...@@ -76,19 +76,19 @@ argument miopen_gemm::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
std::size_t n_dims = args[0].get_shape().lens().size(); std::size_t n_dims = args[0].get_shape().lens().size();
std::size_t dim_0 = n_dims - 2; std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1; std::size_t dim_1 = n_dims - 1;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int m = output_shape.lens()[dim_0]; rocblas_int m = output_shape.lens()[dim_0];
rocblas_int n = output_shape.lens()[dim_1]; rocblas_int n = output_shape.lens()[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha)); auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment