/*!
 *  Copyright (c) 2020 by Contributors
 * \file kernel/cpu/gather_mm.cc
 * \brief GatherMM C APIs and definitions.
 */
#include "./gather_mm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

12
/*! \brief Generalized SegmentMM. */
13
template <int XPU, typename IdType, typename DType>
14
void SegmentMM(const NDArray A,
Israt Nisa's avatar
Israt Nisa committed
15
16
17
18
          const NDArray B,
          NDArray C,
          const NDArray seglen_A,
          bool a_trans, bool b_trans) {
19
20
21
    LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
}

22
template <int XPU, typename IdType, typename DType>
23
24
25
26
27
void SegmentMMBackwardB(const NDArray A,
                        const NDArray dC,
                        NDArray dB,
                        const NDArray seglen) {
    LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
Israt Nisa's avatar
Israt Nisa committed
28
29
30
}

/*! \brief Generalized GatherMM. */
31
template <int XPU, typename IdType, typename DType>
32
void GatherMM(const NDArray A,
Israt Nisa's avatar
Israt Nisa committed
33
34
35
          const NDArray B,
          NDArray C,
          const NDArray idx_a,
36
37
          const NDArray idx_b) {
    LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
38
39
40
}

/*! \brief Generalized GatherMM_scatter. */
41
template <int XPU, typename IdType, typename DType>
42
void GatherMMScatter(const NDArray A,
Israt Nisa's avatar
Israt Nisa committed
43
44
45
46
          const NDArray B,
          NDArray C,
          const NDArray idx_a,
          const NDArray idx_b,
47
48
          const NDArray idx_c) {
    LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
49
50
}

51
template void GatherMM<kDGLCPU, int32_t, float>(
Israt Nisa's avatar
Israt Nisa committed
52
    const NDArray A, const NDArray B, NDArray C,
53
    const NDArray idx_a, const NDArray idx_b);
54
template void GatherMM<kDGLCPU, int64_t, float>(
Israt Nisa's avatar
Israt Nisa committed
55
    const NDArray A, const NDArray B, NDArray C,
56
    const NDArray idx_a, const NDArray idx_b);
57
template void GatherMM<kDGLCPU, int32_t, double>(
Israt Nisa's avatar
Israt Nisa committed
58
    const NDArray A, const NDArray B, NDArray C,
59
    const NDArray idx_a, const NDArray idx_b);
60
template void GatherMM<kDGLCPU, int64_t, double>(
Israt Nisa's avatar
Israt Nisa committed
61
    const NDArray A, const NDArray B, NDArray C,
62
    const NDArray idx_a, const NDArray idx_b);
Israt Nisa's avatar
Israt Nisa committed
63

64
template void GatherMMScatter<kDGLCPU, int32_t, float>(
Israt Nisa's avatar
Israt Nisa committed
65
    const NDArray A, const NDArray B, NDArray C,
66
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
67
template void GatherMMScatter<kDGLCPU, int64_t, float>(
Israt Nisa's avatar
Israt Nisa committed
68
    const NDArray A, const NDArray B, NDArray C,
69
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
70
template void GatherMMScatter<kDGLCPU, int32_t, double>(
Israt Nisa's avatar
Israt Nisa committed
71
    const NDArray A, const NDArray B, NDArray C,
72
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
73
template void GatherMMScatter<kDGLCPU, int64_t, double>(
Israt Nisa's avatar
Israt Nisa committed
74
    const NDArray A, const NDArray B, NDArray C,
75
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
Israt Nisa's avatar
Israt Nisa committed
76

77
template void SegmentMM<kDGLCPU, int32_t, float>(
Israt Nisa's avatar
Israt Nisa committed
78
79
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
80
template void SegmentMM<kDGLCPU, int64_t, float>(
Israt Nisa's avatar
Israt Nisa committed
81
82
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
83
template void SegmentMM<kDGLCPU, int32_t, double>(
Israt Nisa's avatar
Israt Nisa committed
84
85
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
86
template void SegmentMM<kDGLCPU, int64_t, double>(
Israt Nisa's avatar
Israt Nisa committed
87
88
89
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);

90
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
91
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
92
template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
93
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
94
template void SegmentMMBackwardB<kDGLCPU, int32_t, double>(
95
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
96
template void SegmentMMBackwardB<kDGLCPU, int64_t, double>(
97
98
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

Israt Nisa's avatar
Israt Nisa committed
99
100
}  // namespace aten
}  // namespace dgl