Commit b9e0366d authored by Shucai Xiao
Browse files

clang format

parent 7d986afb
......@@ -171,10 +171,10 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
return op.compute_shape(inputs);
}
std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens,
std::size_t index, std::vector<std::size_t> &data_lens) const
std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens,
std::size_t index,
std::vector<std::size_t>& data_lens) const
{
}
argument miopen_gemm::compute(context& ctx,
......@@ -183,50 +183,53 @@ argument miopen_gemm::compute(context& ctx,
{
bool is_3inputs = (args.size() == 4);
if (output_shape.elements() == 1)
if(output_shape.elements() == 1)
{
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_dot(as, ctx.get_stream().get_rocblas(),
args[1].get_shape().elements(),
generic_rocblas_dot(as,
ctx.get_stream().get_rocblas(),
args[1].get_shape().elements(),
to_pointer(args[0]),
1,
to_pointer(args[1]),
1,
is_3inputs ? to_pointer(args[3]): to_pointer(args[2]));
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));
generic_rocblas_scal(as, ctx.get_stream().get_rocblas(),
1,
&alpha_r,
is_3inputs ? to_pointer(args[3]): to_pointer(args[2]));
generic_rocblas_scal(as,
ctx.get_stream().get_rocblas(),
1,
&alpha_r,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));
1);
if (is_3inputs)
{
generic_rocblas_axpy(as, ctx.get_stream().get_rocblas(),
1,
&beta_r,
to_pointer(args[2]),
1,
to_pointer(args[3]),
1);
}
if(is_3inputs)
{
generic_rocblas_axpy(as,
ctx.get_stream().get_rocblas(),
1,
&beta_r,
to_pointer(args[2]),
1,
to_pointer(args[3]),
1);
}
});
return is_3inputs ? args[3] : args[2];
}
// b is a vector, so the computation is matrix * vector
// could not be the case of inner product of vectors since
// could not be the case of inner product of vectors since
// it is already processed above
if (args[1].get_shape().lens().size() == 1)
if(args[1].get_shape().lens().size() == 1)
{
// considering the batch input, so A could be a batch
// of matrices
auto a_lens = args[0].get_shape().lens();
auto a_lens = args[0].get_shape().lens();
std::size_t n_dims = a_lens.size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
......@@ -236,18 +239,20 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int k = a_lens[dim_1];
auto batch_num = std::accumulate(
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg, std::size_t offset) { return to_rocblas_type(as.from(arg.data() + offset)); };
for (std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
auto to_pointer = [&](auto&& arg, std::size_t offset) {
return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
{
if(is_3inputs)
hipMemcpy(to_pointer(args[3] + batch_no * m),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice);
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice);
else
hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
}
......@@ -271,12 +276,14 @@ argument miopen_gemm::compute(context& ctx,
bool is_3inputs = (args.size() == 4);
// two input arguments
if (!is_3inputs)
if(!is_3inputs)
{
}
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); };
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
if(is_3inputs)
hipMemcpy(to_pointer(args[3]),
to_pointer(args[2]),
......@@ -287,7 +294,9 @@ argument miopen_gemm::compute(context& ctx,
});
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { return to_rocblas_type(as.from(arg.data() + offset)); };
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
generic_rocblas_batched_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
......
......@@ -19,8 +19,10 @@ struct miopen_gemm
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private:
std::size_t compute_offset(std::vector<std::size_t>& out_lens, std::size_t index, std::vector<std::size_t> &data_lens) const;
private:
std::size_t compute_offset(std::vector<std::size_t>& out_lens,
std::size_t index,
std::vector<std::size_t>& data_lens) const;
};
} // namespace gpu
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment