/**
 *  Copyright 2020-2021 Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file graph/transform/cuda/cuda_to_block.cu
 * @brief Functions to convert a set of edges into a graph block with local
 * ids.
 *
 * Tested via python wrapper: python/dgl/path/to/to_block.py
 */
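// A minimal illustration of the remapping this file implements (values are
// hypothetical, and it is assumed that the destination seeds are placed at the
// front of the source frontier when include_rhs_in_lhs is true): with
// destination seeds {10, 20} and edges {(5, 10), (7, 20), (10, 20)} in global
// ids, the block's source frontier becomes {10, 20, 5, 7} and the edges are
// rewritten with local ids as {(2, 0), (3, 1), (0, 1)}.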

#include <cuda_runtime.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>

#include <algorithm>
#include <memory>
#include <utility>

#include "../../../runtime/cuda/cuda_common.h"
#include "../../heterograph.h"
#include "../to_block.h"
#include "cuda_map_edges.cuh"

using namespace dgl::aten;
using namespace dgl::runtime::cuda;
using namespace dgl::transform::cuda;
using TensorDispatcher = dgl::runtime::TensorDispatcher;

namespace dgl {
namespace transform {

namespace {

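// Helper that builds the per-node-type lhs/rhs hash tables of a DeviceNodeMap
// from the given node lists, using the caller-provided CUDA stream.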
template <typename IdType>
class DeviceNodeMapMaker {
 public:
  explicit DeviceNodeMapMaker(const std::vector<int64_t>& maxNodesPerType)
      : max_num_nodes_(0) {
    max_num_nodes_ =
        *std::max_element(maxNodesPerType.begin(), maxNodesPerType.end());
  }

  /**
   * @brief This function builds node maps for each node type, preserving the
   * order of the input nodes. Here it is assumed the lhs_nodes are not unique,
   * and thus a unique list is generated.
   *
   * @param lhs_nodes The set of source input nodes.
   * @param rhs_nodes The set of destination input nodes.
   * @param node_maps The node maps to be constructed.
   * @param count_lhs_device The number of unique source nodes (on the GPU).
   * @param lhs_device The unique source nodes (on the GPU).
   * @param stream The stream to operate on.
   */
  void Make(
      const std::vector<IdArray>& lhs_nodes,
      const std::vector<IdArray>& rhs_nodes,
      DeviceNodeMap<IdType>* const node_maps, int64_t* const count_lhs_device,
      std::vector<IdArray>* const lhs_device, cudaStream_t stream) {
    const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();

    CUDA_CALL(cudaMemsetAsync(
        count_lhs_device, 0, num_ntypes * sizeof(*count_lhs_device), stream));

    // possibly duplicate lhs nodes
    const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size());
    for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {
      const IdArray& nodes = lhs_nodes[ntype];
      if (nodes->shape[0] > 0) {
        CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);
        node_maps->LhsHashTable(ntype).FillWithDuplicates(
            nodes.Ptr<IdType>(), nodes->shape[0],
            (*lhs_device)[ntype].Ptr<IdType>(), count_lhs_device + ntype,
            stream);
      }
    }

    // unique rhs nodes
    const int64_t rhs_num_ntypes = static_cast<int64_t>(rhs_nodes.size());
    for (int64_t ntype = 0; ntype < rhs_num_ntypes; ++ntype) {
      const IdArray& nodes = rhs_nodes[ntype];
      if (nodes->shape[0] > 0) {
        node_maps->RhsHashTable(ntype).FillWithUnique(
            nodes.Ptr<IdType>(), nodes->shape[0], stream);
      }
    }
  }

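  // Illustrative example of the duplicate-tolerant fill above (hypothetical
  // values, assuming the hash table keeps first-occurrence order as the brief
  // describes): for one node type with lhs input {4, 9, 4, 7}, the unique list
  // {4, 9, 7} is written to (*lhs_device)[ntype] and the count 3 is stored in
  // count_lhs_device[ntype].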
  /**
   * @brief This function builds node maps for each node type, preserving the
   * order of the input nodes. Here it is assumed both lhs_nodes and rhs_nodes
   * are unique.
   *
   * @param lhs_nodes The set of source input nodes.
   * @param rhs_nodes The set of destination input nodes.
   * @param node_maps The node maps to be constructed.
   * @param stream The stream to operate on.
   */
  void Make(
      const std::vector<IdArray>& lhs_nodes,
      const std::vector<IdArray>& rhs_nodes,
      DeviceNodeMap<IdType>* const node_maps, cudaStream_t stream) {
    const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();

    // unique lhs nodes
    const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size());
    for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {
      const IdArray& nodes = lhs_nodes[ntype];
      if (nodes->shape[0] > 0) {
        CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);
        node_maps->LhsHashTable(ntype).FillWithUnique(
            nodes.Ptr<IdType>(), nodes->shape[0], stream);
      }
    }

    // unique rhs nodes
    const int64_t rhs_num_ntypes = static_cast<int64_t>(rhs_nodes.size());
    for (int64_t ntype = 0; ntype < rhs_num_ntypes; ++ntype) {
      const IdArray& nodes = rhs_nodes[ntype];
      if (nodes->shape[0] > 0) {
        node_maps->RhsHashTable(ntype).FillWithUnique(
            nodes.Ptr<IdType>(), nodes->shape[0], stream);
      }
    }
  }

 private:
  IdType max_num_nodes_;
};

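// Functor passed to ProcessToBlock: builds the device node maps for the block
// and remaps every edge endpoint from its global id to a block-local id.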
template <typename IdType>
struct CUDAIdsMapper {
  std::tuple<std::vector<IdArray>, std::vector<IdArray>> operator()(
      const HeteroGraphPtr& graph, bool include_rhs_in_lhs, int64_t num_ntypes,
      const DGLContext& ctx, const std::vector<int64_t>& maxNodesPerType,
      const std::vector<EdgeArray>& edge_arrays,
      const std::vector<IdArray>& src_nodes,
      const std::vector<IdArray>& rhs_nodes,
      std::vector<IdArray>* const lhs_nodes_ptr,
      std::vector<int64_t>* const num_nodes_per_type_ptr) {
    std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr;
    std::vector<int64_t>& num_nodes_per_type = *num_nodes_per_type_ptr;
    const bool generate_lhs_nodes = lhs_nodes.empty();
    auto device = runtime::DeviceAPI::Get(ctx);
    cudaStream_t stream = runtime::getCurrentCUDAStream();

    // Allocate space for map creation process.
    DeviceNodeMapMaker<IdType> maker(maxNodesPerType);
    DeviceNodeMap<IdType> node_maps(maxNodesPerType, num_ntypes, ctx, stream);
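    // When lhs_nodes has to be generated, each buffer is allocated at the
    // per-type maximum and shrunk in place later (by adjusting shape[0]) once
    // the unique counts are known.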
    if (generate_lhs_nodes) {
      lhs_nodes.reserve(num_ntypes);
      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
        lhs_nodes.emplace_back(
            NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8));
      }
    }

    cudaEvent_t copyEvent;
    NDArray new_len_tensor;

    // Populate the mappings.
    if (generate_lhs_nodes) {
      int64_t* count_lhs_device = static_cast<int64_t*>(
          device->AllocWorkspace(ctx, sizeof(int64_t) * num_ntypes * 2));

      maker.Make(
          src_nodes, rhs_nodes, &node_maps, count_lhs_device, &lhs_nodes,
          stream);

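      // The unique-node counts are copied device-to-host asynchronously into
      // new_len_tensor (pinned memory when the TensorDispatcher is available),
      // and copyEvent marks when that copy finishes. The edge remapping below
      // can then be queued without stalling the host, which synchronizes on
      // the event only when it needs the counts to resize lhs_nodes.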
      CUDA_CALL(cudaEventCreate(&copyEvent));
      if (TensorDispatcher::Global()->IsAvailable()) {
        new_len_tensor = NDArray::PinnedEmpty(
            {num_ntypes}, DGLDataTypeTraits<int64_t>::dtype,
            DGLContext{kDGLCPU, 0});
      } else {
        // use pageable memory, it will unnecessarily block but be functional
        new_len_tensor = NDArray::Empty(
            {num_ntypes}, DGLDataTypeTraits<int64_t>::dtype,
            DGLContext{kDGLCPU, 0});
      }
      CUDA_CALL(cudaMemcpyAsync(
          new_len_tensor->data, count_lhs_device,
          sizeof(*num_nodes_per_type.data()) * num_ntypes,
          cudaMemcpyDeviceToHost, stream));
      CUDA_CALL(cudaEventRecord(copyEvent, stream));

      device->FreeWorkspace(ctx, count_lhs_device);
    } else {
      maker.Make(lhs_nodes, rhs_nodes, &node_maps, stream);

      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
        num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0];
      }
    }

    // Map node numberings from global to local, and build pointer for CSR.
    auto ret = MapEdges(graph, edge_arrays, node_maps, stream);

    if (generate_lhs_nodes) {
      // wait for the previous copy
      CUDA_CALL(cudaEventSynchronize(copyEvent));
      CUDA_CALL(cudaEventDestroy(copyEvent));

      // Resize lhs nodes.
      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
        num_nodes_per_type[ntype] =
            static_cast<int64_t*>(new_len_tensor->data)[ntype];
        lhs_nodes[ntype]->shape[0] = num_nodes_per_type[ntype];
      }
    }

    return ret;
  }
};

template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU(
    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes_ptr) {
  return dgl::transform::ProcessToBlock<IdType>(
      graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes_ptr,
      CUDAIdsMapper<IdType>());
}

}  // namespace

// Use explicit names to get around MSVC's broken mangling that thinks the
// following two functions are the same. Using template<> fails to export the
// symbols.
std::tuple<HeteroGraphPtr, std::vector<IdArray>>
// ToBlock<kDGLCUDA, int32_t>
ToBlockGPU32(
    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes) {
  return ToBlockGPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

std::tuple<HeteroGraphPtr, std::vector<IdArray>>
// ToBlock<kDGLCUDA, int64_t>
ToBlockGPU64(
    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes) {
  return ToBlockGPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}
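// Illustrative call pattern for the two exports above (hypothetical caller,
// not part of this file); passing an empty lhs_nodes vector asks the mapper to
// generate the source frontier:
//
//   std::vector<IdArray> lhs_nodes;
//   auto result = ToBlockGPU64(graph, rhs_nodes, /*include_rhs_in_lhs=*/true,
//                              &lhs_nodes);
//   HeteroGraphPtr block = std::get<0>(result);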

}  // namespace transform
}  // namespace dgl