gather_mm.cc 4.86 KB
Newer Older
1
/**
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/gather_mm.cc
 * @brief GatherMM C APIs and definitions.
 */
#include "./gather_mm.h"
7

Israt Nisa's avatar
Israt Nisa committed
8
9
10
11
12
#include <dgl/array.h>

namespace dgl {
namespace aten {

13
/** @brief Generalized SegmentMM. */
14
template <int XPU, typename IdType, typename DType>
15
16
17
18
void SegmentMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
19
20
}

21
template <int XPU, typename IdType, typename DType>
22
23
24
void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
Israt Nisa's avatar
Israt Nisa committed
25
26
}

27
/** @brief Generalized GatherMM. */
28
template <int XPU, typename IdType, typename DType>
29
30
31
32
void GatherMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
33
34
}

35
/** @brief Generalized GatherMM_scatter. */
36
template <int XPU, typename IdType, typename DType>
37
38
39
40
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
41
42
}

43
44
45
46
47
48
// Explicit CPU instantiations of GatherMM for every supported
// (index type, data type) pair, grouped by index type.

// 32-bit indices.
template void GatherMM<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);

// 64-bit indices.
template void GatherMM<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b);
Israt Nisa's avatar
Israt Nisa committed
61

62
63
64
65
66
67
// Explicit CPU instantiations of GatherMMScatter for every supported
// (index type, data type) pair, grouped by index type.

// 32-bit indices.
template void GatherMMScatter<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);

// 64-bit indices.
template void GatherMMScatter<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
Israt Nisa's avatar
Israt Nisa committed
80

81
82
83
84
85
86
// Explicit CPU instantiations of SegmentMM for every supported
// (index type, data type) pair, grouped by index type.

// 32-bit indices.
template void SegmentMM<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);

// 64-bit indices.
template void SegmentMM<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
Israt Nisa's avatar
Israt Nisa committed
99

100
101
102
103
// Explicit CPU instantiations of SegmentMMBackwardB for every supported
// (index type, data type) pair, grouped by index type.

// 32-bit indices.
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray dC, NDArray dB,
    const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray dC, NDArray dB,
    const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, BFloat16>(
    const NDArray A, const NDArray dC, NDArray dB,
    const NDArray seglen);

// 64-bit indices.
template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray dC, NDArray dB,
    const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray dC, NDArray dB,
    const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, BFloat16>(
    const NDArray A, const NDArray dC, NDArray dB,
    const NDArray seglen);

Israt Nisa's avatar
Israt Nisa committed
113
114
}  // namespace aten
}  // namespace dgl