"doc/vscode:/vscode.git/clone" did not exist on "b25ad4ec32ab6f7d5b9f5674d09aa01149f6c058"
gemm.cpp 1.58 KB
Newer Older
wsttiger's avatar
wsttiger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include <migraph/gpu/gemm.hpp>
#include <migraph/operators.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/gpu/miopen.hpp>
#include <stdexcept>
#include <utility>

namespace migraph {
namespace gpu {

/// Infer the output shape of the GPU GEMM.
///
/// Three inputs are required: the two GEMM operands followed by the buffer
/// that `compute` writes into and returns. Only the two operands are
/// forwarded to the wrapped operator for shape inference.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(3);
    // The trailing input is the output allocation; it plays no part in
    // deciding the result shape.
    const std::vector<shape> operands{inputs.at(0), inputs.at(1)};
    return op.compute_shape(operands);
}
wsttiger's avatar
wsttiger committed
15
16
17
/// Compute C = A * B on the GPU with rocBLAS single-precision GEMM.
///
/// rocBLAS works in column-major order. A row-major buffer read as
/// column-major is the transpose of the matrix, so this routine asks rocBLAS
/// for C^T = B^T * A^T. The operands are swapped (B first, then A) and the
/// m/n dimensions are exchanged in the call. The result lands in C's buffer
/// laid out as the row-major C.
///
/// An operand whose shape reports `transposed()` is passed with the rocBLAS
/// transpose op. Its leading dimension is then taken from `strides()[1]`
/// instead of `strides()[0]`.
///
/// @param ctx          GPU context; supplies the stream's rocBLAS handle.
/// @param output_shape Shape of the result; lens()[0] is m, lens()[1] is n.
/// @param args         {A, B, C}: A provides k via lens()[1]; C is the
///                     preallocated output buffer written by rocBLAS.
/// @return args[2], the buffer holding the result.
/// @throws std::out_of_range if fewer than three arguments are supplied.
/// @throws std::runtime_error if rocblas_sgemm does not report success.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    const bool transa = args.at(0).get_shape().transposed();
    const bool transb = args.at(1).get_shape().transposed();

    // Leading dimensions: the stride of the "outer" axis of each buffer.
    const auto lda = static_cast<rocblas_int>(args.at(0).get_shape().strides()[transa ? 1 : 0]);
    const auto ldb = static_cast<rocblas_int>(args.at(1).get_shape().strides()[transb ? 1 : 0]);
    const auto ldc = static_cast<rocblas_int>(args.at(2).get_shape().strides()[0]);

    // rocBLAS takes 32-bit dimensions; make the size_t -> rocblas_int
    // narrowing explicit.
    const auto m = static_cast<rocblas_int>(output_shape.lens()[0]);
    const auto n = static_cast<rocblas_int>(output_shape.lens()[1]);
    const auto k = static_cast<rocblas_int>(args.at(0).get_shape().lens()[1]);

    // Column-major trick: compute C^T (n x m) = op(B) * op(A).
    const auto status =
        rocblas_sgemm(ctx.get_stream().get_rocblas(),
                      transb ? rocblas_operation_transpose : rocblas_operation_none,
                      transa ? rocblas_operation_transpose : rocblas_operation_none,
                      n,
                      m,
                      k,
                      &alpha,
                      args.at(1).implicit(),
                      ldb,
                      args.at(0).implicit(),
                      lda,
                      &beta,
                      args.at(2).implicit(),
                      ldc);
    // Previously the status was discarded, so a failed launch silently
    // returned an unwritten output buffer.
    if(status != rocblas_status_success)
    {
        throw std::runtime_error("rocblas_sgemm failed in miopen_gemm::compute");
    }
    return args.at(2);
}

} // namespace gpu

} // namespace migraph