// gemm.cpp
#include <migraph/gpu/gemm.hpp>
#include <migraph/operators.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/gpu/miopen.hpp>
#include <stdexcept>
#include <string>
#include <utility>

namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {

/// Infer the output shape of the GPU GEMM.
///
/// Expects exactly three inputs: A, B, and the output buffer that compute()
/// writes into and returns. Only A and B take part in shape inference, and
/// that inference is delegated to the wrapped operator `op`.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(3);
    const auto& lhs = inputs.at(0);
    const auto& rhs = inputs.at(1);
    return op.compute_shape({lhs, rhs});
}
/// Compute C = A * B on the GPU with rocblas_sgemm (single precision).
///
/// rocBLAS follows the BLAS column-major convention. These shapes are assumed
/// to be row-major: the non-transposed leading dimension is taken from
/// strides()[0]. A row-major matrix read as column-major is its transpose, so
/// the call computes C^T = B^T * A^T. B is passed as the first operand, A as the
/// second, and m/n are swapped, which leaves C in row-major layout.
///
/// @param ctx           Context providing the stream's rocBLAS handle.
/// @param output_shape  Shape of C. lens()[0] is used as m, lens()[1] as n.
/// @param args          {A, B, C}. C is written by rocBLAS and returned.
/// @return args[2], the output buffer holding the product.
/// @throws std::runtime_error if rocblas_sgemm reports a non-success status.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    // Plain product: C = 1 * (A * B) + 0 * C.
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    const shape& a_shape = args[0].get_shape();
    const shape& b_shape = args[1].get_shape();
    const shape& c_shape = args[2].get_shape();

    const bool transa = a_shape.transposed();
    const bool transb = b_shape.transposed();

    // Leading dimension comes from strides()[0] normally, or from
    // strides()[1] when the shape is transposed. Dimensions are converted
    // explicitly, because rocBLAS takes rocblas_int rather than std::size_t.
    const auto lda = static_cast<rocblas_int>(a_shape.strides()[transa ? 1 : 0]);
    const auto ldb = static_cast<rocblas_int>(b_shape.strides()[transb ? 1 : 0]);
    const auto ldc = static_cast<rocblas_int>(c_shape.strides()[0]);
    const auto m   = static_cast<rocblas_int>(output_shape.lens()[0]);
    const auto n   = static_cast<rocblas_int>(output_shape.lens()[1]);
    const auto k   = static_cast<rocblas_int>(a_shape.lens()[1]);

    // Operands are swapped (B first, A second; n before m). See the header
    // comment for why.
    const rocblas_status status =
        rocblas_sgemm(ctx.get_stream().get_rocblas(),
                      transb ? rocblas_operation_transpose : rocblas_operation_none,
                      transa ? rocblas_operation_transpose : rocblas_operation_none,
                      n,
                      m,
                      k,
                      &alpha,
                      args[1].implicit(),
                      ldb,
                      args[0].implicit(),
                      lda,
                      &beta,
                      args[2].implicit(),
                      ldc);

    // Previously the status was discarded, so a failed launch silently
    // returned an unwritten buffer. Surface the failure instead.
    if(status != rocblas_status_success)
        throw std::runtime_error("rocblas_sgemm failed with status " +
                                 std::to_string(static_cast<int>(status)));

    return args[2];
}

} // namespace gpu
} // inline namespace MIGRAPH_INLINE_NS
} // namespace migraph