/**
 *  Copyright (c) 2022 by Contributors
 * @file array/cpu/gather_mm.h
 * @brief GATHER_MM CPU kernel function header.
 */
#ifndef DGL_ARRAY_CPU_GATHER_MM_H_
#define DGL_ARRAY_CPU_GATHER_MM_H_

#include <dgl/array.h>
#include <dgl/bcast.h>

#include <cassert>
#include <utility>

namespace dgl {
namespace aten {
namespace cpu {

/**
 * @brief Transpose a dense row-major matrix.
 *
 * @param in  Input matrix with N rows and M columns, row-major.
 * @param out Output buffer receiving the M x N transpose, row-major. It must
 *            hold at least N * M elements and must not alias @p in (elements
 *            of @p in are read while @p out is written in parallel).
 * @param N   Number of rows of @p in (= number of columns of @p out).
 * @param M   Number of columns of @p in (= number of rows of @p out).
 */
template <typename DType>
void transpose(
    const DType *in, DType *out, const int64_t N, const int64_t M) {
  // 64-bit indexing: N * M can exceed INT_MAX for large relation slices.
  const int64_t total = N * M;
#pragma omp parallel for
  for (int64_t n = 0; n < total; n++) {
    const int64_t i = n / N;  // row of out == column of in
    const int64_t j = n % N;  // column of out == row of in
    out[n] = in[M * j + i];
  }
}

/**
 * @brief Naive dense matrix product C = A x B, all operands row-major.
 *
 * @param A Left operand with M rows and K columns.
 * @param B Right operand with K rows and N columns.
 * @param C Output with M rows and N columns; every element is overwritten
 *          (set to zero when K == 0).
 * @param M Number of rows of A and C.
 * @param N Number of columns of B and C.
 * @param K Inner dimension: columns of A and rows of B.
 */
template <typename DType>
void matmul(
    const DType *A, const DType *B, DType *C, const int64_t M, const int64_t N,
    const int64_t K) {
  // Rows of C are independent, so they are split across threads. Indices are
  // 64-bit because i * K and k * N can exceed INT_MAX on large inputs.
#pragma omp parallel for
  for (int64_t i = 0; i < M; i++) {
    for (int64_t j = 0; j < N; j++) {
      DType local_accum = 0;
      for (int64_t k = 0; k < K; k++) {
        local_accum += A[i * K + k] * B[k * N + j];
      }
      C[i * N + j] = local_accum;
    }
  }
}

48
/**
49
 * @brief CPU kernel of Gather_mm. The input matrix A is expected to be
Israt Nisa's avatar
Israt Nisa committed
50
 *        sorted according to relation type.
51
52
53
54
55
56
57
 * @param A The input dense matrix of dimension m x k
 * @param B The input dense matrix of dimension k x n
 * @param C The output dense matrix od dimension m x n
 * @param A_dim1_per_rel The number of rows in each relation in A
 * @param B_dim1_per_rel The number of rows in each relation in B
 * @param a_trans Matrix A to be transposed
 * @param b_trans Matrix B to be transposed
Israt Nisa's avatar
Israt Nisa committed
58
59
 */
template <int XPU, typename IdType, typename DType>
60
61
62
63
64
65
66
67
68
69
void gatherMM_SortedEtype(
    const NDArray A, const NDArray B, NDArray C, const NDArray A_dim1_per_rel,
    const NDArray B_dim1_per_rel, bool a_trans, bool b_trans) {
  assert(A_dim1_per_rel.NumElements() == B_dim1_per_rel.NumElements());
  int64_t num_rel = A_dim1_per_rel.NumElements();
  const DType *A_data = A.Ptr<DType>();
  const DType *B_data = B.Ptr<DType>();
  const IdType *A_rel_data = A_dim1_per_rel.Ptr<IdType>();
  const IdType *B_rel_data = B_dim1_per_rel.Ptr<IdType>();
  DType *C_data = C.Ptr<DType>();
Israt Nisa's avatar
Israt Nisa committed
70

71
72
73
74
75
76
77
78
79
80
  int64_t A_offset = 0, B_offset = 0, C_offset = 0;
  int64_t m, n, k, h_col, w_row;
  for (int etype = 0; etype < num_rel; ++etype) {
    assert(
        (a_trans)                  ? A_rel_data[etype]
        : A->shape[1] == (b_trans) ? B->shape[1]
                                   : B_rel_data[etype]);
    m = A_rel_data[etype];  // rows of A
    n = B->shape[1];        // cols of B
    k = B_rel_data[etype];  // rows of B == cols of A
Israt Nisa's avatar
Israt Nisa committed
81

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    NDArray A_trans, B_trans;
    if (a_trans) {
      A_trans = NDArray::Empty({m * k}, A->dtype, A->ctx);
      transpose<DType>(
          A_data + A_offset, static_cast<DType *>(A_trans->data), m, k);
    }
    if (b_trans) {
      B_trans = NDArray::Empty({k * n}, B->dtype, B->ctx);
      transpose<DType>(
          B_data + B_offset, static_cast<DType *>(B_trans->data), k, n);
    }
    if (a_trans || b_trans) {
      int64_t tmp = k;
      if (a_trans) std::swap(m, k);
      if (b_trans) {
        k = tmp;
        std::swap(n, k);
      }
Israt Nisa's avatar
Israt Nisa committed
100
    }
101
102
103
104
105
106
107
108
    matmul<DType>(
        (a_trans) ? static_cast<DType *>(A_trans->data) : A_data + A_offset,
        (b_trans) ? static_cast<DType *>(B_trans->data) : B_data + B_offset,
        C_data + C_offset, m, n, k);
    A_offset += m * k;
    B_offset += k * n;
    C_offset += m * n;
  }
Israt Nisa's avatar
Israt Nisa committed
109
110
111
112
113
114
115
}

}  // namespace cpu
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_GATHER_MM_H_