partition_op.h 5.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/*!
 *  Copyright (c) 2021 by Contributors
 * \file partition_op.h
 * \brief DGL utilities for working with the partitioned NDArrays
 */


#ifndef DGL_PARTITION_PARTITION_OP_H_
#define DGL_PARTITION_PARTITION_OP_H_

#include <dgl/array.h>
#include <utility>

namespace dgl {
namespace partition {
namespace impl {

/**
 * @brief Create a permutation that groups indices by the part id when used for
 * slicing, via the remainder. That is, for the input indices A, find I
 * such that A[I] is grouped by part ID.
 *
 * For example, if we have the set of indices [3, 9, 2, 4, 1, 7] and two
 * partitions, the permutation vector would be [2, 3, 0, 1, 4, 5]: positions
 * 2 and 3 hold the indices with remainder 0 (2 and 4), and positions
 * 0, 1, 4 and 5 hold the indices with remainder 1 (3, 9, 1 and 7).
 *
 * @tparam XPU The type of device to run on.
 * @tparam IdType The type of the index.
 * @param array_size The total size of the partitioned array.
 * @param num_parts The number of parts the array is divided into.
 * @param in_idx The array of indices to group by part id.
 *
 * @return The permutation to group the indices by part id, and the number of
 * indices in each part.
 */
template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder(
        int64_t array_size,
        int num_parts,
        IdArray in_idx);

/**
 * @brief Generate the set of local indices from the global indices, using
 * the remainder-based partitioning. That is, for each position `i` in
 * `global_idx`, the local index is computed as `global_idx[i] / num_parts`.
 *
 * @tparam XPU The type of device to run on.
 * @tparam IdType The type of the index.
 * @param num_parts The number of parts the array is divided into.
 * @param global_idx The array of global indices to map.
 *
 * @return The array of local indices.
 */
template <DLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(
    int num_parts,
    IdArray global_idx);

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/**
 * @brief Generate the set of global indices from the local indices, using
 * the remainder-based partitioning. That is, for each position `i` in
 * `local_idx`, the global index is computed as
 * `local_idx[i] * num_parts + part_id`.
 *
 * @tparam XPU The type of device to run on.
 * @tparam IdType The type of the index.
 * @param num_parts The number of parts the array is divided into.
 * @param local_idx The array of local indices to map.
 * @param part_id The id of the current part.
 *
 * @return The array of global indices.
 */
template <DLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
    int num_parts,
    IdArray local_idx,
    int part_id);

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
/**
 * @brief Create a permutation that groups indices by the part id when used for
 * slicing. That is, for the input indices A, find I such that A[I] is grouped
 * by part ID.
 *
 * For example, if we have a range of [0, 5, 10] and the set of indices
 * [3, 9, 2, 4, 1, 7], the permutation vector would be [0, 2, 3, 4, 1, 5]:
 * positions 0, 2, 3 and 4 hold the indices below 5 (3, 2, 4 and 1), and
 * positions 1 and 5 hold the remaining indices (9 and 7).
 *
 * @tparam XPU The type of device to run on.
 * @tparam IdType The type of the index.
 * @tparam RangeType The type of the range.
 * @param array_size The total size of the partitioned array.
 * @param num_parts The number of parts the array is divided into.
 * @param range The exclusive prefix-sum, representing the range of rows
 * assigned to each partition. Must be on the same context as `in_idx`.
 * @param in_idx The array of indices to group by part id.
 *
 * @return The permutation to group the indices by part id, and the number of
 * indices in each part.
 */
template <DLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, IdArray>
GeneratePermutationFromRange(
        int64_t array_size,
        int num_parts,
        IdArray range,
        IdArray in_idx);

/**
 * @brief Generate the set of local indices from the global indices, using
 * the per-partition ranges given by `range`.
 *
 * NOTE(review): the implementation is not visible from this header. Since
 * `range` is an exclusive prefix-sum of the rows assigned to each partition,
 * the local index for position `i` is presumably `global_idx[i] - range[p]`,
 * where `p` is the part whose range contains `global_idx[i]` -- confirm
 * against the implementation.
 *
 * @tparam XPU The type of device to run on.
 * @tparam IdType The type of the index.
 * @tparam RangeType The type of the range.
 * @param num_parts The number of parts the array is divided into.
 * @param range The exclusive prefix-sum, representing the range of rows
 * assigned to each partition. Must be on the same context as `global_idx`.
 * @param global_idx The array of global indices to map.
 *
 * @return The array of local indices.
 */
template <DLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange(
    int num_parts,
    IdArray range,
    IdArray global_idx);

/**
 * @brief Generate the set of global indices from the local indices, using
 * the per-partition ranges given by `range`.
 *
 * NOTE(review): the implementation is not visible from this header. Since
 * `range` is an exclusive prefix-sum of the rows assigned to each partition,
 * the global index for position `i` is presumably
 * `local_idx[i] + range[part_id]` -- confirm against the implementation.
 *
 * @tparam XPU The type of device to run on.
 * @tparam IdType The type of the index.
 * @tparam RangeType The type of the range.
 * @param num_parts The number of parts the array is divided into.
 * @param range The exclusive prefix-sum, representing the range of rows
 * assigned to each partition. Must be on the same context as `local_idx`.
 * @param local_idx The array of local indices to map.
 * @param part_id The id of the current part.
 *
 * @return The array of global indices.
 */
template <DLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange(
    int num_parts,
    IdArray range,
    IdArray local_idx,
    int part_id);


151
152
153
154
155
156

}  // namespace impl
}  // namespace partition
}  // namespace dgl

#endif  // DGL_PARTITION_PARTITION_OP_H_