Commit b9e0366d authored by Shucai Xiao
Browse files

clang format

parent 7d986afb
...@@ -171,10 +171,10 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const ...@@ -171,10 +171,10 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens, std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens,
std::size_t index, std::vector<std::size_t> &data_lens) const std::size_t index,
std::vector<std::size_t>& data_lens) const
{ {
} }
argument miopen_gemm::compute(context& ctx, argument miopen_gemm::compute(context& ctx,
...@@ -183,50 +183,53 @@ argument miopen_gemm::compute(context& ctx, ...@@ -183,50 +183,53 @@ argument miopen_gemm::compute(context& ctx,
{ {
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
if (output_shape.elements() == 1) if(output_shape.elements() == 1)
{ {
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta)); auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_dot(as, ctx.get_stream().get_rocblas(), generic_rocblas_dot(as,
args[1].get_shape().elements(), ctx.get_stream().get_rocblas(),
args[1].get_shape().elements(),
to_pointer(args[0]), to_pointer(args[0]),
1, 1,
to_pointer(args[1]), to_pointer(args[1]),
1, 1,
is_3inputs ? to_pointer(args[3]): to_pointer(args[2])); is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));
generic_rocblas_scal(as, ctx.get_stream().get_rocblas(), generic_rocblas_scal(as,
1, ctx.get_stream().get_rocblas(),
&alpha_r, 1,
is_3inputs ? to_pointer(args[3]): to_pointer(args[2])); &alpha_r,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));
1); 1);
if (is_3inputs) if(is_3inputs)
{ {
generic_rocblas_axpy(as, ctx.get_stream().get_rocblas(), generic_rocblas_axpy(as,
1, ctx.get_stream().get_rocblas(),
&beta_r, 1,
to_pointer(args[2]), &beta_r,
1, to_pointer(args[2]),
to_pointer(args[3]), 1,
1); to_pointer(args[3]),
} 1);
}
}); });
return is_3inputs ? args[3] : args[2]; return is_3inputs ? args[3] : args[2];
} }
// b is a vector, so the computation is matrix * vector // b is a vector, so the computation is matrix * vector
// could not be the case of inner product of vectors since // could not be the case of inner product of vectors since
// it is already processed above // it is already processed above
if (args[1].get_shape().lens().size() == 1) if(args[1].get_shape().lens().size() == 1)
{ {
// considering the batch input, so A could be a batch // considering the batch input, so A could be a batch
// of matrices // of matrices
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
std::size_t n_dims = a_lens.size(); std::size_t n_dims = a_lens.size();
std::size_t dim_0 = n_dims - 2; std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1; std::size_t dim_1 = n_dims - 1;
...@@ -236,18 +239,20 @@ argument miopen_gemm::compute(context& ctx, ...@@ -236,18 +239,20 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int k = a_lens[dim_1]; rocblas_int k = a_lens[dim_1];
auto batch_num = std::accumulate( auto batch_num = std::accumulate(
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta)); auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg, std::size_t offset) { return to_rocblas_type(as.from(arg.data() + offset)); }; auto to_pointer = [&](auto&& arg, std::size_t offset) {
for (std::size_t batch_no = 0; batch_no < batch_num; ++batch_no) return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
{ {
if(is_3inputs) if(is_3inputs)
hipMemcpy(to_pointer(args[3] + batch_no * m), hipMemcpy(to_pointer(args[3] + batch_no * m),
to_pointer(args[2]), to_pointer(args[2]),
output_shape.bytes(), output_shape.bytes(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
else else
hipMemset(to_pointer(args[2]), 0, output_shape.bytes()); hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
} }
...@@ -271,12 +276,14 @@ argument miopen_gemm::compute(context& ctx, ...@@ -271,12 +276,14 @@ argument miopen_gemm::compute(context& ctx,
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
// two input arguments // two input arguments
if (!is_3inputs) if(!is_3inputs)
{ {
} }
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); }; auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
if(is_3inputs) if(is_3inputs)
hipMemcpy(to_pointer(args[3]), hipMemcpy(to_pointer(args[3]),
to_pointer(args[2]), to_pointer(args[2]),
...@@ -287,7 +294,9 @@ argument miopen_gemm::compute(context& ctx, ...@@ -287,7 +294,9 @@ argument miopen_gemm::compute(context& ctx,
}); });
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); }; auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
generic_rocblas_batched_gemm(as, generic_rocblas_batched_gemm(as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
......
...@@ -19,8 +19,10 @@ struct miopen_gemm ...@@ -19,8 +19,10 @@ struct miopen_gemm
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private: private:
std::size_t compute_offset(std::vector<std::size_t>& out_lens, std::size_t index, std::vector<std::size_t> &data_lens) const; std::size_t compute_offset(std::vector<std::size_t>& out_lens,
std::size_t index,
std::vector<std::size_t>& data_lens) const;
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment