/**
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/gather_mm.cc
 * @brief GatherMM C APIs and definitions.
 */
#include "./gather_mm.h"
#include <dgl/array.h>

namespace dgl {
namespace aten {

13
/** @brief Generalized SegmentMM. */
14
template <int XPU, typename IdType, typename DType>
15
16
17
18
void SegmentMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
19
20
}

21
template <int XPU, typename IdType, typename DType>
22
23
24
void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
Israt Nisa's avatar
Israt Nisa committed
25
26
}

27
/** @brief Generalized GatherMM. */
28
template <int XPU, typename IdType, typename DType>
29
30
31
32
void GatherMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
33
34
}

35
/** @brief Generalized GatherMM_scatter. */
36
template <int XPU, typename IdType, typename DType>
37
38
39
40
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
41
42
}

43
template void GatherMM<kDGLCPU, int32_t, float>(
44
45
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
46
template void GatherMM<kDGLCPU, int64_t, float>(
47
48
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
49
template void GatherMM<kDGLCPU, int32_t, double>(
50
51
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
52
template void GatherMM<kDGLCPU, int64_t, double>(
53
54
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
Israt Nisa's avatar
Israt Nisa committed
55

56
template void GatherMMScatter<kDGLCPU, int32_t, float>(
57
58
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
59
template void GatherMMScatter<kDGLCPU, int64_t, float>(
60
61
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
62
template void GatherMMScatter<kDGLCPU, int32_t, double>(
63
64
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
65
template void GatherMMScatter<kDGLCPU, int64_t, double>(
66
67
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
Israt Nisa's avatar
Israt Nisa committed
68

69
template void SegmentMM<kDGLCPU, int32_t, float>(
70
71
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
72
template void SegmentMM<kDGLCPU, int64_t, float>(
73
74
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
75
template void SegmentMM<kDGLCPU, int32_t, double>(
76
77
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
78
template void SegmentMM<kDGLCPU, int64_t, double>(
79
80
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
Israt Nisa's avatar
Israt Nisa committed
81

82
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
83
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
84
template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
85
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
86
template void SegmentMMBackwardB<kDGLCPU, int32_t, double>(
87
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
88
template void SegmentMMBackwardB<kDGLCPU, int64_t, double>(
89
90
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

Israt Nisa's avatar
Israt Nisa committed
91
92
}  // namespace aten
}  // namespace dgl