csr_mm.cc 4.53 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cpu/csr_mm.cc
 * \brief CSR Matrix Multiplication
 */

#include <dgl/array.h>
8
#include <dgl/runtime/parallel_for.h>
9
#include <parallel_hashmap/phmap.h>
10

11
#include <vector>
12

13
14
15
16
17
#include "array_utils.h"

namespace dgl {

using dgl::runtime::NDArray;
18
using dgl::runtime::parallel_for;
19
20
21
22
23
24
25
26

namespace aten {

namespace {

// TODO(BarclayII): avoid using map for sorted CSRs
template <typename IdType>
void CountNNZPerRow(
27
28
    const IdType* A_indptr, const IdType* A_indices, const IdType* B_indptr,
    const IdType* B_indices, IdType* C_indptr_data, int64_t M) {
29
30
31
32
33
34
35
36
37
  parallel_for(0, M, [=](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      phmap::flat_hash_set<IdType> set;
      for (IdType u = A_indptr[i]; u < A_indptr[i + 1]; ++u) {
        IdType w = A_indices[u];
        for (IdType v = B_indptr[w]; v < B_indptr[w + 1]; ++v)
          set.insert(B_indices[v]);
      }
      C_indptr_data[i] = set.size();
38
    }
39
  });
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
}

/*!
 * \brief Convert per-row counts into a CSR row-pointer array in place
 *        (exclusive prefix sum).
 *
 * \param C_indptr_data On entry, entries [0, M) hold per-row counts. On exit,
 *        entry i holds the starting offset of row i and entry M holds the
 *        total. Must have room for M + 1 elements. Offsets are narrowed to
 *        IdType, so callers must make sure the returned total fits in IdType.
 * \param M Number of rows.
 * \return The total of all counts, accumulated in int64_t.
 */
template <typename IdType>
int64_t ComputeIndptrInPlace(IdType* C_indptr_data, int64_t M) {
  int64_t nnz = 0;
  // Index with int64_t to match M; an IdType counter could overflow before
  // reaching M when IdType is narrower than int64_t.
  for (int64_t i = 0; i < M; ++i) {
    const int64_t count = C_indptr_data[i];
    C_indptr_data[i] = static_cast<IdType>(nnz);
    nnz += count;
  }
  C_indptr_data[M] = static_cast<IdType>(nnz);
  return nnz;
}

template <typename IdType, typename DType>
void ComputeIndicesAndData(
57
58
59
60
    const IdType* A_indptr, const IdType* A_indices, const IdType* A_eids,
    const DType* A_data, const IdType* B_indptr, const IdType* B_indices,
    const IdType* B_eids, const DType* B_data, const IdType* C_indptr_data,
    IdType* C_indices_data, DType* C_weights_data, int64_t M) {
61
62
63
64
65
66
67
68
69
70
71
  parallel_for(0, M, [=](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      phmap::flat_hash_map<IdType, DType> map;
      for (IdType u = A_indptr[i]; u < A_indptr[i + 1]; ++u) {
        IdType w = A_indices[u];
        DType vA = A_data[A_eids ? A_eids[u] : u];
        for (IdType v = B_indptr[w]; v < B_indptr[w + 1]; ++v) {
          IdType t = B_indices[v];
          DType vB = B_data[B_eids ? B_eids[v] : v];
          map[t] += vA * vB;
        }
72
73
      }

74
75
76
77
78
79
      IdType v = C_indptr_data[i];
      for (auto it : map) {
        C_indices_data[v] = it.first;
        C_weights_data[v] = it.second;
        ++v;
      }
80
    }
81
  });
82
83
84
85
86
87
}

};  // namespace

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
88
    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
89
    NDArray B_weights) {
90
91
  CHECK_EQ(A.num_cols, B.num_rows)
      << "A's number of columns must equal to B's number of rows";
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
  const bool A_has_eid = !IsNullArray(A.data);
  const bool B_has_eid = !IsNullArray(B.data);
  const IdType* A_indptr = A.indptr.Ptr<IdType>();
  const IdType* A_indices = A.indices.Ptr<IdType>();
  const IdType* A_eids = A_has_eid ? A.data.Ptr<IdType>() : nullptr;
  const IdType* B_indptr = B.indptr.Ptr<IdType>();
  const IdType* B_indices = B.indices.Ptr<IdType>();
  const IdType* B_eids = B_has_eid ? B.data.Ptr<IdType>() : nullptr;
  const DType* A_data = A_weights.Ptr<DType>();
  const DType* B_data = B_weights.Ptr<DType>();
  const int64_t M = A.num_rows;
  const int64_t P = B.num_cols;

  IdArray C_indptr = IdArray::Empty({M + 1}, A.indptr->dtype, A.indptr->ctx);
  IdType* C_indptr_data = C_indptr.Ptr<IdType>();

108
109
  CountNNZPerRow<IdType>(
      A_indptr, A_indices, B_indptr, B_indices, C_indptr_data, M);
110
111
112
113
114
115
116
117
  int64_t nnz = ComputeIndptrInPlace<IdType>(C_indptr_data, M);
  // Allocate indices and weights array
  IdArray C_indices = IdArray::Empty({nnz}, A.indices->dtype, A.indices->ctx);
  NDArray C_weights = NDArray::Empty({nnz}, A_weights->dtype, A_weights->ctx);
  IdType* C_indices_data = C_indices.Ptr<IdType>();
  DType* C_weights_data = C_weights.Ptr<DType>();

  ComputeIndicesAndData<IdType, DType>(
118
      A_indptr, A_indices, A_eids, A_data, B_indptr, B_indices, B_eids, B_data,
119
120
      C_indptr_data, C_indices_data, C_weights_data, M);

121
  return {
122
123
      CSRMatrix(
          M, P, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
124
      C_weights};
125
126
}

127
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, float>(
128
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
129
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, float>(
130
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
131
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, double>(
132
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
133
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, double>(
134
135
136
137
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);

};  // namespace aten
};  // namespace dgl