Commit b9e0366d authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 7d986afb
...@@ -172,9 +172,9 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const ...@@ -172,9 +172,9 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
} }
std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens, std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens,
std::size_t index, std::vector<std::size_t> &data_lens) const std::size_t index,
std::vector<std::size_t>& data_lens) const
{ {
} }
argument miopen_gemm::compute(context& ctx, argument miopen_gemm::compute(context& ctx,
...@@ -183,30 +183,33 @@ argument miopen_gemm::compute(context& ctx, ...@@ -183,30 +183,33 @@ argument miopen_gemm::compute(context& ctx,
{ {
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
if (output_shape.elements() == 1) if(output_shape.elements() == 1)
{ {
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta)); auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_dot(as, ctx.get_stream().get_rocblas(), generic_rocblas_dot(as,
ctx.get_stream().get_rocblas(),
args[1].get_shape().elements(), args[1].get_shape().elements(),
to_pointer(args[0]), to_pointer(args[0]),
1, 1,
to_pointer(args[1]), to_pointer(args[1]),
1, 1,
is_3inputs ? to_pointer(args[3]): to_pointer(args[2])); is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));
generic_rocblas_scal(as, ctx.get_stream().get_rocblas(), generic_rocblas_scal(as,
ctx.get_stream().get_rocblas(),
1, 1,
&alpha_r, &alpha_r,
is_3inputs ? to_pointer(args[3]): to_pointer(args[2])); is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));
1); 1);
if (is_3inputs) if(is_3inputs)
{ {
generic_rocblas_axpy(as, ctx.get_stream().get_rocblas(), generic_rocblas_axpy(as,
ctx.get_stream().get_rocblas(),
1, 1,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
...@@ -222,7 +225,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -222,7 +225,7 @@ argument miopen_gemm::compute(context& ctx,
// b is a vector, so the computation is matrix * vector // b is a vector, so the computation is matrix * vector
// could not be the case of inner product of vectors since // could not be the case of inner product of vectors since
// it is already processed above // it is already processed above
if (args[1].get_shape().lens().size() == 1) if(args[1].get_shape().lens().size() == 1)
{ {
// considering the batch input, so A could be a batch // considering the batch input, so A could be a batch
// of matrices // of matrices
...@@ -240,8 +243,10 @@ argument miopen_gemm::compute(context& ctx, ...@@ -240,8 +243,10 @@ argument miopen_gemm::compute(context& ctx,
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta)); auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg, std::size_t offset) { return to_rocblas_type(as.from(arg.data() + offset)); }; auto to_pointer = [&](auto&& arg, std::size_t offset) {
for (std::size_t batch_no = 0; batch_no < batch_num; ++batch_no) return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
{ {
if(is_3inputs) if(is_3inputs)
hipMemcpy(to_pointer(args[3] + batch_no * m), hipMemcpy(to_pointer(args[3] + batch_no * m),
...@@ -271,12 +276,14 @@ argument miopen_gemm::compute(context& ctx, ...@@ -271,12 +276,14 @@ argument miopen_gemm::compute(context& ctx,
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
// two input arguments // two input arguments
if (!is_3inputs) if(!is_3inputs)
{ {
} }
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); }; auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
if(is_3inputs) if(is_3inputs)
hipMemcpy(to_pointer(args[3]), hipMemcpy(to_pointer(args[3]),
to_pointer(args[2]), to_pointer(args[2]),
...@@ -287,7 +294,9 @@ argument miopen_gemm::compute(context& ctx, ...@@ -287,7 +294,9 @@ argument miopen_gemm::compute(context& ctx,
}); });
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); }; auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
generic_rocblas_batched_gemm(as, generic_rocblas_batched_gemm(as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
......
...@@ -19,8 +19,10 @@ struct miopen_gemm ...@@ -19,8 +19,10 @@ struct miopen_gemm
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private: private:
std::size_t compute_offset(std::vector<std::size_t>& out_lens, std::size_t index, std::vector<std::size_t> &data_lens) const; std::size_t compute_offset(std::vector<std::size_t>& out_lens,
std::size_t index,
std::vector<std::size_t>& data_lens) const;
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment