"docs/vscode:/vscode.git/clone" did not exist on "088985936518be7e25795a30d8ab33affa9db6ed"
geometry.cc 3.82 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2019 by Contributors
 * @file geometry/geometry.cc
 * @brief DGL geometry utilities implementation
 */
#include <dgl/array.h>
8
#include <dgl/base_heterograph.h>
9
#include <dgl/packed_func_ext.h>
10
11
12
#include <dgl/runtime/ndarray.h>

#include "../array/check.h"
13
#include "../c_api_common.h"
sangwzh's avatar
sangwzh committed
14
#include "geometry_op.h"
15
16
17
18
19
20

using namespace dgl::runtime;

namespace dgl {
namespace geometry {

21
22
23
24
25
26
27
28
29
30
31
/**
 * @brief Farthest point sampling over a batch of point sets.
 *
 * Validates argument shapes and devices, then dispatches to
 * impl::FarthestPointSampler specialized on the float dtype of @p array,
 * the id dtype of @p result, and the device of @p array.
 *
 * @param array Input points; array->shape[0] must equal dist->shape[0].
 * @param batch_size Number of batch elements; must equal start_idx->shape[0].
 * @param sample_points Number of samples drawn per batch element.
 * @param dist Working buffer with the same leading dimension as @p array
 *        (presumably per-point distances -- TODO confirm against impl).
 * @param start_idx One starting index per batch element.
 * @param result Output ids filled by the implementation; must hold
 *        batch_size * sample_points entries.
 */
void FarthestPointSampler(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result) {
  // The kernel is selected from array's device alone, so every buffer handed
  // to it must live on that same device.
  CHECK_EQ(array->ctx, result->ctx)
      << "Array and the result should be on the same device.";
  CHECK_EQ(array->ctx, dist->ctx)
      << "Array and dist should be on the same device.";
  CHECK_EQ(array->ctx, start_idx->ctx)
      << "Array and start_idx should be on the same device.";
  CHECK_EQ(array->shape[0], dist->shape[0])
      << "Shape of array and dist mismatch";
  CHECK_EQ(start_idx->shape[0], batch_size)
      << "Shape of start_idx and batch_size mismatch";
  CHECK_EQ(result->shape[0], batch_size * sample_points)
      << "Invalid shape of result";

  ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", {
    ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
      ATEN_XPU_SWITCH_CUDA(
          array->ctx.device_type, XPU, "FarthestPointSampler", {
            impl::FarthestPointSampler<XPU, FloatType, IdType>(
                array, batch_size, sample_points, dist, start_idx, result);
          });
    });
  });
}

44
45
/**
 * @brief Run neighbor matching on edge type 0 of @p graph.
 *
 * Dispatch is on the graph's device and id dtype. If @p weight is a null
 * array, the unweighted implementation is used. Otherwise the weighted one
 * is used, further specialized on the weight's float dtype.
 *
 * @param graph Input graph; its CSR matrix for edge type 0 is passed down.
 * @param weight Optional edge weights; a null array selects unweighted mode.
 * @param result Output array filled by the implementation.
 */
void NeighborMatching(
    HeteroGraphPtr graph, const NDArray weight, IdArray result) {
  const auto device = graph->Context().device_type;

  // No weights supplied: use the unweighted matching and stop here.
  if (aten::IsNullArray(weight)) {
    ATEN_XPU_SWITCH_CUDA(device, XPU, "NeighborMatching", {
      ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
        impl::NeighborMatching<XPU, IdType>(graph->GetCSRMatrix(0), result);
      });
    });
    return;
  }

  // Weighted matching: additionally specialize on the weight's dtype.
  ATEN_XPU_SWITCH_CUDA(device, XPU, "NeighborMatching", {
    ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", {
      ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
        impl::WeightedNeighborMatching<XPU, FloatType, IdType>(
            graph->GetCSRMatrix(0), weight, result);
      });
    });
  });
}

66
67
68
///////////////////////// C APIs /////////////////////////

// C API entry for FarthestPointSampler.
// Positional arguments: (data, batch_size, sample_points, dist, start_idx,
// result). Output goes through the result array; rv is left unset.
DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const NDArray points = args[0];
      const int64_t num_batches = args[1];
      const int64_t num_samples = args[2];
      NDArray dist_buffer = args[3];
      IdArray start_ids = args[4];
      IdArray sampled = args[5];

      FarthestPointSampler(
          points, num_batches, num_samples, dist_buffer, start_ids, sampled);
    });
80

81
// C API entry for NeighborMatching.
// Positional arguments: (graph, edge weight, result). The weight may be a
// null array, which selects the unweighted matching. Arguments are validated
// here before dispatch. Output goes through the result array; rv is unset.
DGL_REGISTER_GLOBAL("geometry._CAPI_NeighborMatching")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const NDArray weight = args[1];
      IdArray result = args[2];

      // sanity check
      // The name list must line up one-to-one with the array list, as in the
      // CheckContiguous call below: one separate string per array.
      aten::CheckCtx(
          graph->Context(), {weight, result}, {"edge_weight", "result"});
      aten::CheckContiguous({weight, result}, {"edge_weight", "result"});
      CHECK_EQ(graph->NumEdgeTypes(), 1)
          << "homogeneous graph has only one edge type";
      CHECK_EQ(result->ndim, 1) << "result should be an 1D tensor.";
      // pair.first is presumably the source node type of edge type 0; its
      // node count must match the length of result.
      auto pair = graph->meta_graph()->FindEdge(0);
      const dgl_type_t node_type = pair.first;
      CHECK_EQ(graph->NumVertices(node_type), result->shape[0])
          << "The number of nodes should be the same as the length of result "
             "tensor.";
      // Weight shape is only validated when weights were actually supplied.
      if (!aten::IsNullArray(weight)) {
        CHECK_EQ(weight->ndim, 1) << "weight should be an 1D tensor.";
        CHECK_EQ(graph->NumEdges(0), weight->shape[0])
            << "number of edges in graph should be the same "
            << "as the length of edge weight tensor.";
      }

      // call implementation
      NeighborMatching(graph.sptr(), weight, result);
    });
109

110
111
}  // namespace geometry
}  // namespace dgl