"docs/vscode:/vscode.git/clone" did not exist on "df476d9f63891406db2d531aa5faf195193e3354"
gather_mm.cc 4.91 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
2
/**
 *  Copyright (c) 2020 by Contributors
 * @file kernel/cpu/gather_mm.cc
 * @brief GatherMM C APIs and definitions.
 */
sangwzh's avatar
sangwzh committed
7
#include "gather_mm.h"
8

Israt Nisa's avatar
Israt Nisa committed
9
10
11
12
13
#include <dgl/array.h>

namespace dgl {
namespace aten {

14
/** @brief Generalized SegmentMM. */
15
template <int XPU, typename IdType, typename DType>
16
17
18
19
void SegmentMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
20
21
}

22
template <int XPU, typename IdType, typename DType>
23
24
25
void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
Israt Nisa's avatar
Israt Nisa committed
26
27
}

28
/** @brief Generalized GatherMM. */
29
template <int XPU, typename IdType, typename DType>
30
31
32
33
void GatherMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
34
35
}

36
/** @brief Generalized GatherMM_scatter. */
37
template <int XPU, typename IdType, typename DType>
38
39
40
41
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
Israt Nisa's avatar
Israt Nisa committed
42
43
}

44
45
46
47
48
49
// Explicit instantiations of GatherMM for kDGLCPU: each of BFloat16, float
// and double is paired with both int32_t and int64_t index types.
template void GatherMM<kDGLCPU, int32_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
Israt Nisa's avatar
Israt Nisa committed
62

63
64
65
66
67
68
// Explicit instantiations of GatherMMScatter for kDGLCPU: each of BFloat16,
// float and double is paired with both int32_t and int64_t index types.
template void GatherMMScatter<kDGLCPU, int32_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
Israt Nisa's avatar
Israt Nisa committed
81

82
83
84
85
86
87
// Explicit instantiations of SegmentMM for kDGLCPU: each of BFloat16, float
// and double is paired with both int32_t and int64_t index types.
template void SegmentMM<kDGLCPU, int32_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, BFloat16>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
Israt Nisa's avatar
Israt Nisa committed
100

101
102
103
104
// Explicit instantiations of SegmentMMBackwardB for kDGLCPU: each of
// BFloat16, float and double is paired with both int32_t and int64_t index
// types.
template void SegmentMMBackwardB<kDGLCPU, int32_t, BFloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, BFloat16>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

Israt Nisa's avatar
Israt Nisa committed
114
115
}  // namespace aten
}  // namespace dgl