Commit 9c16a90e authored by Shucai Xiao

clang format

parent c43eba64
@@ -844,19 +844,15 @@ struct dot
         auto t = a.type();
         // change to support cases like {1, 1, 3, 5} X {1, 1, 5, 6},
-        // which can be handled by numpy. as long as all previous
+        // which can be handled by numpy. as long as all previous
         // dims are 1 except the last two dims, the two matrices
         // are multipliable
-        if (std::any_of(a.lens().rbegin() + 2, a.lens().rend(), [](auto i) {
-                return (i != 1);
-            }))
+        if(std::any_of(a.lens().rbegin() + 2, a.lens().rend(), [](auto i) { return (i != 1); }))
         {
             MIGRAPHX_THROW("DOT: first matrix, dimensions before matrix dims must be 1");
         }
-        if (std::any_of(b.lens().rbegin() + 2, b.lens().rend(), [](auto i) {
-                return (i != 1);
-            }))
+        if(std::any_of(b.lens().rbegin() + 2, b.lens().rend(), [](auto i) { return (i != 1); }))
         {
             MIGRAPHX_THROW("DOT: second matrix, dimensions before matrix dims must be 1");
         }
@@ -865,7 +861,7 @@ struct dot
         if(a.lens()[n_dims - 1] != b.lens()[n_dims - 2])
             MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
                            "} x {" + to_string_range(b.lens()) + "}");
-        auto out_lens = a.lens();
+        auto out_lens        = a.lens();
         out_lens[n_dims - 1] = b.lens()[n_dims - 1];
         return {t, out_lens};
     }
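For readers skimming the diff: the two hunks above only reformat whitespace. The shape rule they touch is that every dimension before the last two must be 1, the inner dimensions must match, and the output copies the first operand's lens with the last entry replaced by the second operand's. Below is a minimal standalone sketch of that rule, using hypothetical names and std::runtime_error in place of MIGRAPHX_THROW; it is an illustration, not the MIGraphX source.

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical illustration only. Assumes both operands have the same rank,
// which is at least 2, so that rbegin() + 2 is a valid iterator.
std::vector<std::size_t> dot_output_lens(const std::vector<std::size_t>& a,
                                         const std::vector<std::size_t>& b)
{
    auto leading_dims_are_one = [](const std::vector<std::size_t>& l) {
        return std::all_of(l.rbegin() + 2, l.rend(), [](auto i) { return i == 1; });
    };
    if(not leading_dims_are_one(a) or not leading_dims_are_one(b))
        throw std::runtime_error("dimensions before the matrix dims must be 1");
    std::size_t n_dims = a.size();
    if(a[n_dims - 1] != b[n_dims - 2])
        throw std::runtime_error("inner dimensions do not match");
    auto out        = a;               // e.g. {1, 1, 3, 5}
    out[n_dims - 1] = b[n_dims - 1];   // with b = {1, 1, 5, 6} this yields {1, 1, 3, 6}
    return out;
}

Calling it with {1, 1, 3, 5} and {1, 1, 5, 6} returns {1, 1, 3, 6}; any leading dimension other than 1 throws, matching the MIGRAPHX_THROW branches in the hunk.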
@@ -14,10 +14,10 @@ template <class T>
 static auto make_mat(tensor_view<T> x)
 {
     const auto& s = x.get_shape();
-    //assert(s.lens().size() == 2);
+    // assert(s.lens().size() == 2);
     std::size_t n_dims = s.lens().size();
-    std::size_t dim_0 = n_dims - 2;
-    std::size_t dim_1 = n_dims - 1;
+    std::size_t dim_0  = n_dims - 2;
+    std::size_t dim_1  = n_dims - 1;
     if(s.transposed())
         return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
     return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
@@ -68,11 +68,11 @@ void migemm_impl(tensor_view<T> cmat,
                  std::false_type)
 {
     std::size_t n_dims = cmat.get_shape().lens().size();
-    std::size_t dim_0 = n_dims - 2;
-    std::size_t dim_1 = n_dims - 1;
-    auto m = cmat.get_shape().lens()[dim_0];
-    auto n = cmat.get_shape().lens()[dim_1];
-    auto k = amat.get_shape().lens()[dim_1];
+    std::size_t dim_0  = n_dims - 2;
+    std::size_t dim_1  = n_dims - 1;
+    auto m             = cmat.get_shape().lens()[dim_0];
+    auto n             = cmat.get_shape().lens()[dim_1];
+    auto k             = amat.get_shape().lens()[dim_1];
     assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
     assert(m == amat.get_shape().lens()[dim_0]);
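Both hunks in this file are whitespace-only. The idea they format is that a tensor whose leading dimensions are all 1 can be viewed as a single row-major matrix described by its last two lens and the stride of the major one, with rows and columns swapped when the shape is marked transposed. A rough sketch of that mapping with hypothetical stand-in types (not the MIGraphX matrix or tensor_view API):

#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for the metadata make_mat reads.
struct matrix_desc
{
    std::size_t rows;
    std::size_t cols;
    std::size_t ld; // leading dimension: stride between consecutive rows
};

// lens/strides describe an n-d tensor; only the last two dims carry the
// matrix data when every dimension in front of them is 1.
matrix_desc make_mat_desc(const std::vector<std::size_t>& lens,
                          const std::vector<std::size_t>& strides,
                          bool transposed)
{
    std::size_t n_dims = lens.size();
    std::size_t dim_0  = n_dims - 2;
    std::size_t dim_1  = n_dims - 1;
    if(transposed) // swap rows/cols and take the stride of the minor dim
        return {lens[dim_1], lens[dim_0], strides[dim_1]};
    return {lens[dim_0], lens[dim_1], strides[dim_0]};
}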
@@ -76,19 +76,19 @@ argument miopen_gemm::compute(context& ctx,
                               const shape& output_shape,
                               const std::vector<argument>& args) const
 {
-    float alpha = 1.0f;
-    float beta = 0.0f;
-    bool transa = args[0].get_shape().transposed();
-    bool transb = args[1].get_shape().transposed();
+    float alpha        = 1.0f;
+    float beta         = 0.0f;
+    bool transa        = args[0].get_shape().transposed();
+    bool transb        = args[1].get_shape().transposed();
     std::size_t n_dims = args[0].get_shape().lens().size();
-    std::size_t dim_0 = n_dims - 2;
-    std::size_t dim_1 = n_dims - 1;
-    rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
-    rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
-    rocblas_int ldc = args[2].get_shape().strides()[dim_0];
-    rocblas_int m = output_shape.lens()[dim_0];
-    rocblas_int n = output_shape.lens()[dim_1];
-    rocblas_int k = args[0].get_shape().lens()[dim_1];
+    std::size_t dim_0  = n_dims - 2;
+    std::size_t dim_1  = n_dims - 1;
+    rocblas_int lda    = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
+    rocblas_int ldb    = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
+    rocblas_int ldc    = args[2].get_shape().strides()[dim_0];
+    rocblas_int m      = output_shape.lens()[dim_0];
+    rocblas_int n      = output_shape.lens()[dim_1];
+    rocblas_int k      = args[0].get_shape().lens()[dim_1];
     output_shape.visit_type([&](auto as) {
         auto alpha_r = to_rocblas_type(as(alpha));
         auto beta_r = to_rocblas_type(as(beta));
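This hunk, too, only re-aligns assignments. For reference, the GEMM parameters are read directly off the shapes: m and n come from the output's last two lens, k from the first input's last dimension, and each leading dimension is the stride of whichever of the last two dims is major for that (possibly transposed) argument. A hedged sketch of just that derivation, with hypothetical stand-in types and stopping short of the actual rocBLAS call:

#include <cstddef>
#include <vector>

// Hypothetical stand-in for the shape metadata the operator reads.
struct shape_info
{
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;
    bool transposed = false;
};

// The six scalars a GEMM needs, named after the usual BLAS parameters.
struct gemm_params
{
    std::size_t m, n, k;
    std::size_t lda, ldb, ldc;
};

// Mirrors the assignments in the hunk above: args[0] -> a, args[1] -> b,
// args[2] -> c (the output buffer), output_shape -> out.
gemm_params derive_gemm_params(const shape_info& a,
                               const shape_info& b,
                               const shape_info& c,
                               const shape_info& out)
{
    std::size_t n_dims = a.lens.size();
    std::size_t dim_0  = n_dims - 2;
    std::size_t dim_1  = n_dims - 1;
    gemm_params p{};
    p.m   = out.lens[dim_0];
    p.n   = out.lens[dim_1];
    p.k   = a.lens[dim_1]; // inner dimension shared by a and b
    p.lda = a.strides[a.transposed ? dim_1 : dim_0];
    p.ldb = b.strides[b.transposed ? dim_1 : dim_0];
    p.ldc = c.strides[dim_0];
    return p;
}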