partition_op.cu 8.73 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*!
 *  Copyright (c) 2021 by Contributors
 * \file partition/cuda/partition_op.cu
 * \brief Operations on partition implemented in CUDA.
 */

#include "../partition_op.h"

#include <dgl/runtime/device_api.h>

#include "../../array/cuda/dgl_cub.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/workspace.h"

using namespace dgl::runtime;

namespace dgl {
namespace partition {
namespace impl {

/*!
 * \brief Assign each index to a processor by taking it modulo `num_proc`.
 *
 * Uses a grid-stride loop, so any 1-D launch configuration covers all
 * `num_index` entries (no assumption that the grid is large enough).
 *
 * \param index The input indices (length num_index).
 * \param num_index The number of indices.
 * \param num_proc The number of processors to distribute over.
 * \param proc_id The output processor id of each index (length num_index).
 */
template<typename IdType> __global__ void _MapProcByRemainder(
    const IdType * const index,
    const int64_t num_index,
    const int64_t num_proc,
    IdType * const proc_id) {
  // Widen to 64 bits before multiplying: gridDim.x*blockDim.x evaluated in
  // 32-bit unsigned arithmetic can wrap for very large launches.
  const int64_t stride = static_cast<int64_t>(gridDim.x)*blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x)*blockDim.x+threadIdx.x;
       idx < num_index; idx += stride) {
    proc_id[idx] = index[idx] % num_proc;
  }
}

/*!
 * \brief Assign each index to a processor by masking off its low bits.
 *
 * Intended for a power-of-two number of processors, with `mask` set to
 * (number of processors - 1), so the mask acts as a cheap remainder.
 * Uses a grid-stride loop, so any 1-D launch configuration covers all
 * `num_index` entries.
 *
 * \param index The input indices (length num_index).
 * \param num_index The number of indices.
 * \param mask The bit mask to apply to each index.
 * \param proc_id The output processor id of each index (length num_index).
 */
template<typename IdType>
__global__ void _MapProcByMaskRemainder(
    const IdType * const index,
    const int64_t num_index,
    const IdType mask,
    IdType * const proc_id) {
  // Widen to 64 bits before multiplying: gridDim.x*blockDim.x evaluated in
  // 32-bit unsigned arithmetic can wrap for very large launches.
  const int64_t stride = static_cast<int64_t>(gridDim.x)*blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x)*blockDim.x+threadIdx.x;
       idx < num_index; idx += stride) {
    proc_id[idx] = index[idx] & mask;
  }
}

/*!
 * \brief Convert global indices to local indices by dividing by `comm_size`.
 *
 * Uses a grid-stride loop, so any 1-D launch configuration covers all
 * `num_items` entries.
 *
 * \param in The global indices (length num_items).
 * \param out The resulting local indices (length num_items).
 * \param num_items The number of indices.
 * \param comm_size The number of partitions.
 */
template<typename IdType>
__global__ void _MapLocalIndexByRemainder(
    const IdType * const in,
    IdType * const out,
    const int64_t num_items,
    const int comm_size) {
  // Widen to 64 bits before multiplying: blockDim.x*blockIdx.x evaluated in
  // 32-bit unsigned arithmetic wraps once the launch exceeds 2^32 threads.
  const int64_t stride = static_cast<int64_t>(gridDim.x)*blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x)*blockDim.x+threadIdx.x;
       idx < num_items; idx += stride) {
    out[idx] = in[idx] / comm_size;
  }
}

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/*!
 * \brief Convert local indices of partition `part_id` back to global indices,
 * computing `in * comm_size + part_id` for each entry.
 *
 * Uses a grid-stride loop, so any 1-D launch configuration covers all
 * `num_items` entries.
 *
 * \param in The local indices (length num_items).
 * \param out The resulting global indices (length num_items).
 * \param part_id The partition that owns the local indices; must be less
 * than comm_size.
 * \param num_items The number of indices.
 * \param comm_size The number of partitions.
 */
template<typename IdType>
__global__ void _MapGlobalIndexByRemainder(
    const IdType * const in,
    IdType * const out,
    const int part_id,
    const int64_t num_items,
    const int comm_size) {
  assert(part_id < comm_size);

  // Widen to 64 bits before multiplying: blockDim.x*blockIdx.x evaluated in
  // 32-bit unsigned arithmetic wraps once the launch exceeds 2^32 threads.
  const int64_t stride = static_cast<int64_t>(gridDim.x)*blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x)*blockDim.x+threadIdx.x;
       idx < num_items; idx += stride) {
    out[idx] = (in[idx] * comm_size) + part_id;
  }
}

79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/*!
 * \brief Generate a permutation that groups the indices of `in_idx` by the
 * partition they map to (`index % num_parts`), along with the number of
 * indices assigned to each partition.
 *
 * \tparam XPU The device type.
 * \tparam IdType The integer type of the indices.
 * \param array_size Not used by this implementation.
 * \param num_parts The number of partitions; must be at least 1.
 * \param in_idx The indices to partition.
 *
 * \return A pair whose first element is the permutation (length
 * in_idx->shape[0], IdType elements) and whose second element is a 64-bit
 * integer array of length num_parts holding the count for each partition.
 */
template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray>
GeneratePermutationFromRemainder(
        int64_t array_size,
        int num_parts,
        IdArray in_idx) {
  std::pair<IdArray, NDArray> result;

  const auto& ctx = in_idx->ctx;
  auto device = DeviceAPI::Get(ctx);
  cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;

  const int64_t num_in = in_idx->shape[0];

  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts <<
      ") must be at least 1.";
  if (num_parts == 1) {
    // no permutation: identity ordering, and all num_in indices are counted
    // in the single partition
    result.first = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
    result.second = aten::Full(num_in, num_parts, sizeof(int64_t)*8, ctx);

    return result;
  }

  // The item count is passed to cub below as an `int` (explicitly for the
  // histogram, via implicit conversion for the radix sort), so reject inputs
  // that do not fit before doing any work.
  // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
  // add a compile time check against the cub version to allow
  // num_in >= std::numeric_limits<int>::max().
  CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max())) <<
      "number of values to insert into histogram must be less than max "
      "value of int.";

  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8);
  result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx);
  int64_t * out_counts = static_cast<int64_t*>(result.second->data);
  if (num_in == 0) {
    // now that we've zero'd out_counts, nothing left to do for an empty
    // mapping
    return result;
  }

  // number of low-order bits needed to represent every partition id
  const int64_t part_bits =
      static_cast<int64_t>(std::ceil(std::log2(num_parts)));

  // First, generate a mapping of indexes to processors
  Workspace<IdType> proc_id_in(device, ctx, num_in);
  {
    const dim3 block(256);
    const dim3 grid((num_in+block.x-1)/block.x);

    if (num_parts < (1 << part_bits)) {
      // num_parts is not a power of 2
      CUDA_KERNEL_CALL(_MapProcByRemainder, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data),
          num_in,
          num_parts,
          proc_id_in.get());
    } else {
      // num_parts is a power of 2, so the remainder is a bit mask
      CUDA_KERNEL_CALL(_MapProcByMaskRemainder, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data),
          num_in,
          static_cast<IdType>(num_parts-1),  // bit mask
          proc_id_in.get());
    }
  }

  // then create a permutation array that groups processors together by
  // performing a radix sort over the low `part_bits` bits of the proc ids
  Workspace<IdType> proc_id_out(device, ctx, num_in);
  IdType * perm_out = static_cast<IdType*>(result.first->data);
  {
    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx);

    size_t sort_workspace_size;
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, sort_workspace_size,
        proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));

    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(sort_workspace.get(), sort_workspace_size,
        proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));
  }
  // explicitly free so workspace can be re-used
  proc_id_in.free();

  // Count the number of values to be sent to each processor by building a
  // histogram of the sorted proc_id vector.
  {
    using AtomicCount = unsigned long long; // NOLINT
    static_assert(sizeof(AtomicCount) == sizeof(*out_counts),
        "AtomicCount must be the same width as int64_t for atomicAdd "
        "in cub::DeviceHistogram::HistogramEven() to work");

    // num_parts+1 levels evenly spanning [0, num_parts) give num_parts bins of
    // width exactly 1, so bin i counts the indices of partition i. An upper
    // level of num_parts+1 would make the nominal bin width
    // (num_parts+1)/num_parts, and ids would only land in the right bins if
    // that width happened to be truncated to 1.
    const int num_levels = num_parts+1;
    const IdType lower_level = static_cast<IdType>(0);
    const IdType upper_level = static_cast<IdType>(num_parts);

    size_t hist_workspace_size;
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        nullptr,
        hist_workspace_size,
        proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts),
        num_levels,
        lower_level,
        upper_level,
        static_cast<int>(num_in),
        stream));

    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        hist_workspace.get(),
        hist_workspace_size,
        proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts),
        num_levels,
        lower_level,
        upper_level,
        static_cast<int>(num_in),
        stream));
  }

  return result;
}


// Explicit GPU instantiations of GeneratePermutationFromRemainder for 32- and
// 64-bit indices.
template std::pair<IdArray, NDArray>
GeneratePermutationFromRemainder<kDLGPU, int32_t>(
    int64_t array_size, int num_parts, IdArray in_idx);
template std::pair<IdArray, NDArray>
GeneratePermutationFromRemainder<kDLGPU, int64_t>(
    int64_t array_size, int num_parts, IdArray in_idx);


/*!
 * \brief Convert global indices to partition-local indices, where indices are
 * assigned to partitions by remainder (so local index = global / num_parts).
 *
 * \tparam XPU The device type.
 * \tparam IdType The integer type of the indices.
 * \param num_parts The number of partitions; must be at least 1.
 * \param global_idx The global indices.
 *
 * \return A new array of local indices, or `global_idx` itself when there is
 * only one partition.
 */
template <DLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(
    const int num_parts,
    IdArray global_idx) {
  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts <<
      ") must be at least 1.";

  const auto& ctx = global_idx->ctx;
  cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;

  if (num_parts == 1) {
    // no mapping to be done
    return global_idx;
  }

  const int64_t num_items = global_idx->shape[0];
  IdArray local_idx = aten::NewIdArray(num_items, ctx, sizeof(IdType)*8);

  // Skip the launch for empty input: a grid of zero blocks is not a valid
  // launch configuration.
  if (num_items > 0) {
    const dim3 block(128);
    const dim3 grid((num_items+block.x-1)/block.x);

    CUDA_KERNEL_CALL(
        _MapLocalIndexByRemainder,
        grid,
        block,
        0,
        stream,
        static_cast<const IdType*>(global_idx->data),
        static_cast<IdType*>(local_idx->data),
        num_items,
        num_parts);
  }

  return local_idx;
}

// Explicit GPU instantiations of MapToLocalFromRemainder for 32- and 64-bit
// indices.
template IdArray MapToLocalFromRemainder<kDLGPU, int32_t>(
    int num_parts, IdArray in_idx);
template IdArray MapToLocalFromRemainder<kDLGPU, int64_t>(
    int num_parts, IdArray in_idx);

256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
/*!
 * \brief Convert local indices owned by partition `part_id` back to global
 * indices, where indices are assigned to partitions by remainder
 * (global = local * num_parts + part_id).
 *
 * \tparam XPU The device type.
 * \tparam IdType The integer type of the indices.
 * \param num_parts The number of partitions.
 * \param local_idx The local indices.
 * \param part_id The partition owning the indices; must be in
 * [0, num_parts).
 *
 * \return A new array of global indices, or `local_idx` itself when there is
 * only one partition.
 */
template <DLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
    const int num_parts,
    IdArray local_idx,
    const int part_id) {
  CHECK_LT(part_id, num_parts) << "Invalid partition id " << part_id <<
      "/" << num_parts;
  CHECK_GE(part_id, 0) << "Invalid partition id " << part_id <<
      "/" << num_parts;

  const auto& ctx = local_idx->ctx;
  cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;

  if (num_parts == 1) {
    // no mapping to be done
    return local_idx;
  }

  const int64_t num_items = local_idx->shape[0];
  IdArray global_idx = aten::NewIdArray(num_items, ctx, sizeof(IdType)*8);

  // Skip the launch for empty input: a grid of zero blocks is not a valid
  // launch configuration.
  if (num_items > 0) {
    const dim3 block(128);
    const dim3 grid((num_items+block.x-1)/block.x);

    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRemainder,
        grid,
        block,
        0,
        stream,
        static_cast<const IdType*>(local_idx->data),
        static_cast<IdType*>(global_idx->data),
        part_id,
        num_items,
        num_parts);
  }

  return global_idx;
}

// Explicit GPU instantiations of MapToGlobalFromRemainder for 32- and 64-bit
// indices.
template IdArray MapToGlobalFromRemainder<kDLGPU, int32_t>(
    int num_parts, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRemainder<kDLGPU, int64_t>(
    int num_parts, IdArray in_idx, int part_id);


307
308
309
310

}  // namespace impl
}  // namespace partition
}  // namespace dgl