gather_mm.cc 5.63 KB
Newer Older
Israt Nisa's avatar
Israt Nisa committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
 /*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/gather_mm.cc
 * \brief GatherMM C APIs and definitions.
 */
#include "./gather_mm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

/*!
 * \brief Dispatch on a floating-point bit width.
 *
 * Binds the alias \a DType to a floating-point type selected by \a bits and
 * expands the trailing statements with that alias in scope:
 *   - 16 or 32 -> float
 *   - 64       -> double
 * Any other width aborts via LOG(FATAL) with the offending value.
 *
 * NOTE(review): 16-bit requests are mapped to float rather than a half type,
 * presumably because no half-precision CPU type is used here -- confirm.
 *
 * \note \a bits may be evaluated more than once; pass an expression without
 *       side effects.
 */
#define SWITCH_BITS(bits, DType, ...)                                  \
  do {                                                                 \
    if ((bits) == 16 || (bits) == 32) {                                \
      using DType = float;                                             \
      { __VA_ARGS__ }                                                  \
    } else if ((bits) == 64) {                                         \
      using DType = double;                                            \
      { __VA_ARGS__ }                                                  \
    } else {                                                           \
      /* Parenthesized so e.g. `c ? 8 : 12` streams the width value. */ \
      LOG(FATAL) << "Data type not recognized with bits " << (bits);   \
    }                                                                  \
  } while (0)


26
/*! \brief Generalized SegmentMM. */
Israt Nisa's avatar
Israt Nisa committed
27
template <int XPU, typename IdType, int bits>
28
void SegmentMM(const NDArray A,
Israt Nisa's avatar
Israt Nisa committed
29
30
31
32
          const NDArray B,
          NDArray C,
          const NDArray seglen_A,
          bool a_trans, bool b_trans) {
33
34
35
36
37
38
39
40
41
    LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
}

/*!
 * \brief Backward pass of SegmentMM producing dB (CPU stub).
 *
 * No CPU kernel is implemented: every call aborts via LOG(FATAL).
 * The template is explicitly instantiated for kDGLCPU later in this file.
 *
 * \param A      Unused on CPU.
 * \param dC     Unused on CPU.
 * \param dB     Unused on CPU.
 * \param seglen Unused on CPU.
 */
template <int XPU, typename IdType, int bits>
void SegmentMMBackwardB(const NDArray A,
                        const NDArray dC,
                        NDArray dB,
                        const NDArray seglen) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
}

/*! \brief Generalized GatherMM. */
template <int XPU, typename IdType, int bits>
46
void GatherMM(const NDArray A,
Israt Nisa's avatar
Israt Nisa committed
47
48
49
          const NDArray B,
          NDArray C,
          const NDArray idx_a,
50
51
          const NDArray idx_b) {
    LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
52
53
54
55
}

/*! \brief Generalized GatherMM_scatter. */
template <int XPU, typename IdType, int bits>
56
void GatherMMScatter(const NDArray A,
Israt Nisa's avatar
Israt Nisa committed
57
58
59
60
          const NDArray B,
          NDArray C,
          const NDArray idx_a,
          const NDArray idx_b,
61
62
          const NDArray idx_c) {
    LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
63
64
}

65
template void GatherMM<kDGLCPU, int32_t, 16>(
Israt Nisa's avatar
Israt Nisa committed
66
    const NDArray A, const NDArray B, NDArray C,
67
    const NDArray idx_a, const NDArray idx_b);
68
template void GatherMM<kDGLCPU, int64_t, 16>(
Israt Nisa's avatar
Israt Nisa committed
69
    const NDArray A, const NDArray B, NDArray C,
70
    const NDArray idx_a, const NDArray idx_b);
71
template void GatherMM<kDGLCPU, int32_t, 32>(
Israt Nisa's avatar
Israt Nisa committed
72
    const NDArray A, const NDArray B, NDArray C,
73
    const NDArray idx_a, const NDArray idx_b);
74
template void GatherMM<kDGLCPU, int64_t, 32>(
Israt Nisa's avatar
Israt Nisa committed
75
    const NDArray A, const NDArray B, NDArray C,
76
    const NDArray idx_a, const NDArray idx_b);
77
template void GatherMM<kDGLCPU, int32_t, 64>(
Israt Nisa's avatar
Israt Nisa committed
78
    const NDArray A, const NDArray B, NDArray C,
79
    const NDArray idx_a, const NDArray idx_b);
80
template void GatherMM<kDGLCPU, int64_t, 64>(
Israt Nisa's avatar
Israt Nisa committed
81
    const NDArray A, const NDArray B, NDArray C,
82
    const NDArray idx_a, const NDArray idx_b);
Israt Nisa's avatar
Israt Nisa committed
83

84
template void GatherMMScatter<kDGLCPU, int32_t, 16>(
Israt Nisa's avatar
Israt Nisa committed
85
    const NDArray A, const NDArray B, NDArray C,
86
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
87
template void GatherMMScatter<kDGLCPU, int64_t, 16>(
Israt Nisa's avatar
Israt Nisa committed
88
    const NDArray A, const NDArray B, NDArray C,
89
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
90
template void GatherMMScatter<kDGLCPU, int32_t, 32>(
Israt Nisa's avatar
Israt Nisa committed
91
    const NDArray A, const NDArray B, NDArray C,
92
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
93
template void GatherMMScatter<kDGLCPU, int64_t, 32>(
Israt Nisa's avatar
Israt Nisa committed
94
    const NDArray A, const NDArray B, NDArray C,
95
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
96
template void GatherMMScatter<kDGLCPU, int32_t, 64>(
Israt Nisa's avatar
Israt Nisa committed
97
    const NDArray A, const NDArray B, NDArray C,
98
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
99
template void GatherMMScatter<kDGLCPU, int64_t, 64>(
Israt Nisa's avatar
Israt Nisa committed
100
    const NDArray A, const NDArray B, NDArray C,
101
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
Israt Nisa's avatar
Israt Nisa committed
102

103
template void SegmentMM<kDGLCPU, int32_t, 16>(
Israt Nisa's avatar
Israt Nisa committed
104
105
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
106
template void SegmentMM<kDGLCPU, int64_t, 16>(
Israt Nisa's avatar
Israt Nisa committed
107
108
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
109
template void SegmentMM<kDGLCPU, int32_t, 32>(
Israt Nisa's avatar
Israt Nisa committed
110
111
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
112
template void SegmentMM<kDGLCPU, int64_t, 32>(
Israt Nisa's avatar
Israt Nisa committed
113
114
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
115
template void SegmentMM<kDGLCPU, int32_t, 64>(
Israt Nisa's avatar
Israt Nisa committed
116
117
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
118
template void SegmentMM<kDGLCPU, int64_t, 64>(
Israt Nisa's avatar
Israt Nisa committed
119
120
121
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);

122
template void SegmentMMBackwardB<kDGLCPU, int32_t, 16>(
123
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
124
template void SegmentMMBackwardB<kDGLCPU, int64_t, 16>(
125
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
126
template void SegmentMMBackwardB<kDGLCPU, int32_t, 32>(
127
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
128
template void SegmentMMBackwardB<kDGLCPU, int64_t, 32>(
129
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
130
template void SegmentMMBackwardB<kDGLCPU, int32_t, 64>(
131
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
132
template void SegmentMMBackwardB<kDGLCPU, int64_t, 64>(
133
134
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

Israt Nisa's avatar
Israt Nisa committed
135
136
}  // namespace aten
}  // namespace dgl