"src/vscode:/vscode.git/clone" did not exist on "5a2e0f715c4700959971b53c2634132917c00332"
gather_mm.cc 5.61 KB
Newer Older
Israt Nisa's avatar
Israt Nisa committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
 /*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/gather_mm.cc
 * \brief GatherMM C APIs and definitions.
 */
#include "./gather_mm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

/*!
 * \brief Declare \p DType as a floating-point type chosen from a bit width and
 *        run the trailing statements with that typedef in scope.
 *
 * Widths 16 and 32 both select `float`; width 64 selects `double`. Any other
 * width is reported through LOG(FATAL).
 *
 * \param bits  Expression giving the bit width. It is expanded several times,
 *              so it should have no side effects; it is always parenthesized,
 *              so compound expressions such as `x & 0xff` are safe.
 * \param DType Identifier to declare as the selected type.
 * \param ...   Statements executed inside the branch where \p DType is defined.
 */
#define SWITCH_BITS(bits, DType, ...)                           \
  do {                                                          \
    if ((bits) == 16 || (bits) == 32) {                         \
      typedef float DType;                                      \
      { __VA_ARGS__ }                                           \
    } else if ((bits) == 64) {                                  \
      typedef double DType;                                     \
      { __VA_ARGS__ }                                           \
    } else {                                                    \
      LOG(FATAL) << "Data type not recognized with bits " << (bits); \
    }                                                           \
  } while (0)


26
/*! \brief Generalized SegmentMM. */
Israt Nisa's avatar
Israt Nisa committed
27
template <int XPU, typename IdType, int bits>
28
void SegmentMM(const NDArray A,
Israt Nisa's avatar
Israt Nisa committed
29
30
31
32
          const NDArray B,
          NDArray C,
          const NDArray seglen_A,
          bool a_trans, bool b_trans) {
33
34
35
36
37
38
39
40
41
    LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
}

/*!
 * \brief SegmentMMBackwardB (CPU placeholder).
 *
 * No CPU kernel is implemented here: the body only emits a fatal log message
 * and does not touch any of its arguments.
 *
 * \tparam XPU    Device-type tag the template is instantiated for.
 * \tparam IdType Integer type parameter (unused in this body).
 * \tparam bits   Bit-width parameter (unused in this body).
 * \param A      Input array (unused).
 * \param dC     Input array (unused; presumably the gradient w.r.t. the output).
 * \param dB     Output array (unused; presumably the gradient w.r.t. B).
 * \param seglen Segment lengths (unused).
 */
template <int XPU, typename IdType, int bits>
void SegmentMMBackwardB(const NDArray A,
                        const NDArray dC,
                        NDArray dB,
                        const NDArray seglen) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
}

/*! \brief Generalized GatherMM. */
template <int XPU, typename IdType, int bits>
46
void GatherMM(const NDArray A,
Israt Nisa's avatar
Israt Nisa committed
47
48
49
          const NDArray B,
          NDArray C,
          const NDArray idx_a,
50
51
          const NDArray idx_b) {
    LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
52
53
54
55
}

/*! \brief Generalized GatherMM_scatter. */
template <int XPU, typename IdType, int bits>
56
void GatherMMScatter(const NDArray A,
Israt Nisa's avatar
Israt Nisa committed
57
58
59
60
          const NDArray B,
          NDArray C,
          const NDArray idx_a,
          const NDArray idx_b,
61
62
          const NDArray idx_c) {
    LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
63
64
}

65
// Explicit CPU instantiations of GatherMM: one for each combination of
// index type (int32_t, int64_t) and bit width (16, 32, 64).
template void GatherMM<kDLCPU, int32_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
Israt Nisa's avatar
Israt Nisa committed
83

84
// Explicit CPU instantiations of GatherMMScatter: one for each combination of
// index type (int32_t, int64_t) and bit width (16, 32, 64).
template void GatherMMScatter<kDLCPU, int32_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
Israt Nisa's avatar
Israt Nisa committed
102

103
// Explicit CPU instantiations of SegmentMM: one for each combination of
// index type (int32_t, int64_t) and bit width (16, 32, 64).
template void SegmentMM<kDLCPU, int32_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int32_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int32_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);

122
123
124
125
126
127
128
129
130
131
132
133
134
// Explicit CPU instantiations of SegmentMMBackwardB, grouped by index type.
// Each index type is paired with every bit width (16, 32, 64).

// 32-bit indices.
template void SegmentMMBackwardB<kDLCPU, int32_t, 16>(
    const NDArray A, const NDArray dC,
    NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 32>(
    const NDArray A, const NDArray dC,
    NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 64>(
    const NDArray A, const NDArray dC,
    NDArray dB, const NDArray seglen);

// 64-bit indices.
template void SegmentMMBackwardB<kDLCPU, int64_t, 16>(
    const NDArray A, const NDArray dC,
    NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 32>(
    const NDArray A, const NDArray dC,
    NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 64>(
    const NDArray A, const NDArray dC,
    NDArray dB, const NDArray seglen);

Israt Nisa's avatar
Israt Nisa committed
135
136
}  // namespace aten
}  // namespace dgl