/**
 *  Copyright (c) 2021 by Contributors
 * @file ndarray_partition.cc
 * @brief DGL utilities for working with partitioned NDArrays.
 */

#include "ndarray_partition.h"

#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>

#include <memory>
#include <utility>

#include "partition_op.h"

using namespace dgl::runtime;

namespace dgl {
namespace partition {

NDArrayPartition::NDArrayPartition(
    const int64_t array_size, const int num_parts)
    : array_size_(array_size), num_parts_(num_parts) {}

int64_t NDArrayPartition::ArraySize() const { return array_size_; }

int NDArrayPartition::NumParts() const { return num_parts_; }

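// An NDArrayPartition in which the global ID `i` belongs to part
// `i % num_parts`. For example, with array_size = 10 and num_parts = 3,
// part 0 owns {0, 3, 6, 9}, part 1 owns {1, 4, 7}, and part 2 owns
// {2, 5, 8}.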
class RemainderPartition : public NDArrayPartition {
 public:
  RemainderPartition(const int64_t array_size, const int num_parts)
      : NDArrayPartition(array_size, num_parts) {
    // do nothing
  }

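  // Generate the permutation that groups the indices in `in_idx` by part,
  // returning the permutation and the number of indices in each part.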
  std::pair<IdArray, NDArray> GeneratePermutation(
      IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
    auto ctx = in_idx->ctx;
    if (ctx.device_type == kDGLCUDA) {
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        return impl::GeneratePermutationFromRemainder<kDGLCUDA, IdType>(
            ArraySize(), NumParts(), in_idx);
      });
    }
#endif

    LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
                  "implemented.";
    // should be unreachable
    return std::pair<IdArray, NDArray>{};
  }

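  // Convert the global indices in `in_idx` to their local IDs within the
  // part that owns them (for this partition, `i / num_parts`).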
  IdArray MapToLocal(IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
    auto ctx = in_idx->ctx;
    if (ctx.device_type == kDGLCUDA) {
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        return impl::MapToLocalFromRemainder<kDGLCUDA, IdType>(
            NumParts(), in_idx);
      });
    }
#endif

    LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
                  "implemented.";
    // should be unreachable
    return IdArray{};
  }

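  // Convert the local indices in `in_idx` within part `part_id` back to
  // their global IDs.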
  IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {
#ifdef DGL_USE_CUDA
    auto ctx = in_idx->ctx;
    if (ctx.device_type == kDGLCUDA) {
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        return impl::MapToGlobalFromRemainder<kDGLCUDA, IdType>(
            NumParts(), in_idx, part_id);
      });
    }
#endif

    LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
                  "implemented.";
    // should be unreachable
    return IdArray{};
  }

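  // The number of indices owned by part `part_id`: every part gets
  // `array_size / num_parts` indices, and the first
  // `array_size % num_parts` parts get one extra.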
  int64_t PartSize(const int part_id) const override {
    CHECK_LT(part_id, NumParts())
        << "Invalid part ID (" << part_id << ") for a partition with "
        << NumParts() << " parts.";
    return ArraySize() / NumParts() + (part_id < ArraySize() % NumParts());
  }
};

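// An NDArrayPartition described by a prefix-sum `range` array of length
// num_parts + 1, where part `p` owns the contiguous span of global IDs
// [range[p], range[p+1]). For example, range = {0, 4, 7, 10} splits an
// array of size 10 into parts of sizes 4, 3, and 3.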
class RangePartition : public NDArrayPartition {
 public:
  RangePartition(const int64_t array_size, const int num_parts, IdArray range)
      : NDArrayPartition(array_size, num_parts),
        range_(range),
        // We also need a copy of the range on the CPU to compute partition
        // sizes. We require the input range to be on the GPU: with multiple
        // GPUs we cannot know which device is the right one to copy the
        // array to, whereas there is only one CPU context, so the array can
        // always be safely copied there.
        range_cpu_(range.CopyTo(DGLContext{kDGLCPU, 0})) {
    auto ctx = range->ctx;
    if (ctx.device_type != kDGLCUDA) {
      LOG(FATAL) << "The range for an NDArrayPartition is only supported "
                    "on GPUs. Transfer the range to the target device before "
                    "creating the partition.";
    }
  }

  std::pair<IdArray, NDArray> GeneratePermutation(
      IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
    auto ctx = in_idx->ctx;
    if (ctx.device_type == kDGLCUDA) {
      if (ctx.device_type != range_->ctx.device_type ||
          ctx.device_id != range_->ctx.device_id) {
        LOG(FATAL) << "The range for the NDArrayPartition and the input "
                      "array must be on the same device: "
                   << ctx << " vs. " << range_->ctx;
      }
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
          return impl::GeneratePermutationFromRange<
              kDGLCUDA, IdType, RangeType>(
              ArraySize(), NumParts(), range_, in_idx);
        });
      });
    }
#endif

    LOG(FATAL) << "Range based partitioning for the CPU is not yet "
                  "implemented.";
    // should be unreachable
    return std::pair<IdArray, NDArray>{};
  }

  IdArray MapToLocal(IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
    auto ctx = in_idx->ctx;
    if (ctx.device_type == kDGLCUDA) {
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
          return impl::MapToLocalFromRange<kDGLCUDA, IdType, RangeType>(
              NumParts(), range_, in_idx);
        });
      });
    }
#endif

    LOG(FATAL) << "Range based partitioning for the CPU is not yet "
                  "implemented.";
    // should be unreachable
    return IdArray{};
  }

  IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {
#ifdef DGL_USE_CUDA
    auto ctx = in_idx->ctx;
    if (ctx.device_type == kDGLCUDA) {
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
          return impl::MapToGlobalFromRange<kDGLCUDA, IdType, RangeType>(
              NumParts(), range_, in_idx, part_id);
        });
      });
    }
#endif

    LOG(FATAL) << "Range based partitioning for the CPU is not yet "
                  "implemented.";
    // should be unreachable
    return IdArray{};
  }

  int64_t PartSize(const int part_id) const override {
    CHECK_LT(part_id, NumParts())
        << "Invalid part ID (" << part_id << ") for a partition with "
        << NumParts() << " parts.";
    int64_t part_size = -1;
    ATEN_ID_TYPE_SWITCH(range_cpu_->dtype, RangeType, {
      const RangeType* const ptr =
          static_cast<const RangeType*>(range_cpu_->data);
      part_size = ptr[part_id + 1] - ptr[part_id];
    });
    return part_size;
  }

 private:
  IdArray range_;
  IdArray range_cpu_;
};

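// Create a remainder-based partition of an array with `array_size` items
// split across `num_parts` parts.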
NDArrayPartitionRef CreatePartitionRemainderBased(
    const int64_t array_size, const int num_parts) {
  return NDArrayPartitionRef(
      std::make_shared<RemainderPartition>(array_size, num_parts));
}

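// Create a range-based partition from a prefix-sum `range` array of length
// `num_parts + 1` (see RangePartition above).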
NDArrayPartitionRef CreatePartitionRangeBased(
    const int64_t array_size, const int num_parts, IdArray range) {
  return NDArrayPartitionRef(
      std::make_shared<RangePartition>(array_size, num_parts, range));
}

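// The registrations below expose the partition operations to Python through
// DGL's FFI. From C++, a minimal usage sketch (assuming a CUDA build and a
// CUDA-resident IdArray `idx`) would be:
//
//   NDArrayPartitionRef part = CreatePartitionRemainderBased(10, 3);
//   IdArray local = part->MapToLocal(idx);
//   IdArray global = part->MapToGlobal(local, /*part_id=*/1);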
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      int64_t array_size = args[0];
      int num_parts = args[1];

      *rv = CreatePartitionRemainderBased(array_size, num_parts);
    });

DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRangeBased")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t array_size = args[0];
      const int num_parts = args[1];
      IdArray range = args[2];

      *rv = CreatePartitionRangeBased(array_size, num_parts, range);
    });

DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGetPartSize")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArrayPartitionRef part = args[0];
      int part_id = args[1];

      *rv = part->PartSize(part_id);
    });

DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToLocal")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArrayPartitionRef part = args[0];
      IdArray idxs = args[1];

      *rv = part->MapToLocal(idxs);
    });

DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToGlobal")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArrayPartitionRef part = args[0];
      IdArray idxs = args[1];
      const int part_id = args[2];

      *rv = part->MapToGlobal(idxs, part_id);
    });

}  // namespace partition
}  // namespace dgl