Commit 3b5c6c7f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 3eb036b2
...@@ -181,10 +181,8 @@ void miopen_gemm::fill_result(const shape& output_shape, ...@@ -181,10 +181,8 @@ void miopen_gemm::fill_result(const shape& output_shape,
{ {
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
hipMemcpy(to_pointer(result), hipMemcpy(
to_pointer(c), to_pointer(result), to_pointer(c), output_shape.bytes(), hipMemcpyDeviceToDevice);
output_shape.bytes(),
hipMemcpyDeviceToDevice);
}); });
} }
else if(c.single()) else if(c.single())
...@@ -320,8 +318,8 @@ argument miopen_gemm::compute(context& ctx, ...@@ -320,8 +318,8 @@ argument miopen_gemm::compute(context& ctx,
std::size_t batch_num = std::accumulate( std::size_t batch_num = std::accumulate(
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
...@@ -338,7 +336,8 @@ argument miopen_gemm::compute(context& ctx, ...@@ -338,7 +336,8 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[1]), to_pointer(args[1]),
1, 1,
&beta_r, &beta_r,
to_pointer(args[2], batch_no * n), 1); to_pointer(args[2], batch_no * n),
1);
} }
}); });
} }
...@@ -358,8 +357,8 @@ argument miopen_gemm::compute(context& ctx, ...@@ -358,8 +357,8 @@ argument miopen_gemm::compute(context& ctx,
std::size_t batch_num = std::accumulate( std::size_t batch_num = std::accumulate(
b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
...@@ -376,7 +375,8 @@ argument miopen_gemm::compute(context& ctx, ...@@ -376,7 +375,8 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[1], batch_no * m * n), to_pointer(args[1], batch_no * m * n),
1, 1,
&beta_r, &beta_r,
to_pointer(args[2], batch_no * m), 1); to_pointer(args[2], batch_no * m),
1);
} }
}); });
} }
...@@ -397,7 +397,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -397,7 +397,7 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = out_lens[out_lens.size() - 2]; rocblas_int m = out_lens[out_lens.size() - 2];
rocblas_int n = out_lens[out_lens.size() - 1]; rocblas_int n = out_lens[out_lens.size() - 1];
rocblas_int k = args[0].get_shape().lens()[a_lens.size() - 1]; rocblas_int k = args[0].get_shape().lens()[a_lens.size() - 1];
float beta = 0.0f; float beta = 0.0f;
auto input_dims = std::min(a_lens.size(), b_lens.size()); auto input_dims = std::min(a_lens.size(), b_lens.size());
std::size_t axis{0}; std::size_t axis{0};
for(axis = 2; axis < input_dims; ++axis) for(axis = 2; axis < input_dims; ++axis)
...@@ -449,8 +449,8 @@ argument miopen_gemm::compute(context& ctx, ...@@ -449,8 +449,8 @@ argument miopen_gemm::compute(context& ctx,
std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end()); std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
......
...@@ -20,9 +20,7 @@ struct miopen_gemm ...@@ -20,9 +20,7 @@ struct miopen_gemm
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private: private:
void fill_result(const shape& output_shape, void fill_result(const shape& output_shape, const argument& result, const argument& c) const;
const argument& result,
const argument& c) const;
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment