ndarray_partition.cc 3.78 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/*!
 *  Copyright (c) 2021 by Contributors
 * \file ndarray_partition.cc
 * \brief DGL utilities for working with the partitioned NDArrays
 */

#include "ndarray_partition.h"

#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <utility>
#include <memory>

#include "partition_op.h"

using namespace dgl::runtime;

namespace dgl {
namespace partition {

/*!
 * \brief Construct the state shared by all partition implementations.
 *
 * \param array_size Stored as-is and later returned by ArraySize().
 * \param num_parts Stored as-is and later returned by NumParts().
 */
NDArrayPartition::NDArrayPartition(const int64_t array_size,
                                   const int num_parts)
    : array_size_(array_size), num_parts_(num_parts) {}

int64_t NDArrayPartition::ArraySize() const {
  return array_size_;
}

int NDArrayPartition::NumParts() const {
  return num_parts_;
}


class RemainderPartition : public NDArrayPartition {
 public:
  RemainderPartition(
      const int64_t array_size, const int num_parts) :
    NDArrayPartition(array_size, num_parts) {
    // do nothing
  }

  std::pair<IdArray, NDArray>
  GeneratePermutation(
      IdArray in_idx) const override {
    auto ctx = in_idx->ctx;

#ifdef DGL_USE_CUDA
    if (ctx.device_type == kDLGPU) {
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        return impl::GeneratePermutationFromRemainder<kDLGPU, IdType>(
            ArraySize(), NumParts(), in_idx);
      });
    }
#endif

    LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
        "implemented.";
    // should be unreachable
    return std::pair<IdArray, NDArray>{};
  }

  IdArray MapToLocal(
      IdArray in_idx) const override {
    auto ctx = in_idx->ctx;
#ifdef DGL_USE_CUDA
    if (ctx.device_type == kDLGPU) {
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        return impl::MapToLocalFromRemainder<kDLGPU, IdType>(
            NumParts(), in_idx);
      });
    }
#endif

    LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
        "implemented.";
    // should be unreachable
    return IdArray{};
  }
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105

  IdArray MapToGlobal(
      IdArray in_idx,
      const int part_id) const override {
    auto ctx = in_idx->ctx;
#ifdef DGL_USE_CUDA
    if (ctx.device_type == kDLGPU) {
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        return impl::MapToGlobalFromRemainder<kDLGPU, IdType>(
            NumParts(), in_idx, part_id);
      });
    }
#endif

    LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
        "implemented.";
    // should be unreachable
    return IdArray{};
  }

  int64_t PartSize(const int part_id) const override {
    CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id << ") for "
        "partition of size " << NumParts() << ".";
    return ArraySize() / NumParts() + (part_id < ArraySize() % NumParts());
  }
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
};

/*!
 * \brief Create a reference to a new RemainderPartition.
 *
 * \param array_size Forwarded to the RemainderPartition constructor.
 * \param num_parts Forwarded to the RemainderPartition constructor.
 *
 * \return A NDArrayPartitionRef wrapping the newly allocated partition.
 */
NDArrayPartitionRef CreatePartitionRemainderBased(
    const int64_t array_size,
    const int num_parts) {
  auto partition = std::make_shared<RemainderPartition>(array_size, num_parts);
  return NDArrayPartitionRef(std::move(partition));
}

// Packed-function entry point: args = (array_size, num_parts); the return
// value is the partition created by CreatePartitionRemainderBased().
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const int64_t total_size = args[0];
  const int part_count = args[1];
  *rv = CreatePartitionRemainderBased(total_size, part_count);
});

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Packed-function entry point: args = (partition, part_id); the return value
// is the result of NDArrayPartition::PartSize(part_id).
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGetPartSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  NDArrayPartitionRef partition = args[0];
  const int which_part = args[1];
  *rv = partition->PartSize(which_part);
});

// Packed-function entry point: args = (partition, indices); the return value
// is the result of NDArrayPartition::MapToLocal(indices).
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToLocal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  NDArrayPartitionRef partition = args[0];
  IdArray global_ids = args[1];
  *rv = partition->MapToLocal(global_ids);
});

// Packed-function entry point: args = (partition, indices, part_id); the
// return value is the result of NDArrayPartition::MapToGlobal(indices,
// part_id).
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToGlobal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  NDArrayPartitionRef partition = args[0];
  IdArray local_ids = args[1];
  const int which_part = args[2];
  *rv = partition->MapToGlobal(local_ids, which_part);
});


149
150
}  // namespace partition
}  // namespace dgl