geometry.cc 3.74 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2019 by Contributors
 * \file geometry/geometry.cc
 * \brief DGL geometry utilities implementation
 */
#include <dgl/array.h>
7
#include <dgl/base_heterograph.h>
8
9
10
#include <dgl/runtime/ndarray.h>

#include "../array/check.h"
11
12
13
14
15
16
17
18
#include "../c_api_common.h"
#include "./geometry_op.h"

using namespace dgl::runtime;

namespace dgl {
namespace geometry {

19
20
21
22
23
24
25
26
27
28
29
/*!
 * \brief Farthest point sampling over a batch of point sets.
 *
 * Validates argument devices and shapes, then dispatches to
 * impl::FarthestPointSampler<XPU, FloatType, IdType>. FloatType is selected
 * from array's dtype, IdType from result's dtype, and XPU from array's device.
 *
 * \param array Point data. Its first dimension must equal dist's.
 * \param batch_size Number of batch elements; must equal start_idx's length.
 * \param sample_points Number of points sampled per batch element.
 * \param dist Buffer with the same first dimension as array. Presumably
 *        per-point distance scratch space used by the kernel; TODO confirm.
 * \param start_idx One entry per batch element. Presumably the index of the
 *        first sampled point in each element; TODO confirm.
 * \param result Output buffer of length batch_size * sample_points.
 */
void FarthestPointSampler(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result) {
  CHECK_EQ(array->ctx, result->ctx)
      << "Array and the result should be on the same device.";
  // The kernel is chosen from array's device only, and it also receives dist
  // and start_idx. Every buffer it touches must therefore live on that same
  // device.
  aten::CheckCtx(array->ctx, {dist, start_idx}, {"dist", "start_idx"});
  CHECK_EQ(array->shape[0], dist->shape[0])
      << "Shape of array and dist mismatch";
  CHECK_EQ(start_idx->shape[0], batch_size)
      << "Shape of start_idx and batch_size mismatch";
  CHECK_EQ(result->shape[0], batch_size * sample_points)
      << "Invalid shape of result";

  ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", {
    ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
      ATEN_XPU_SWITCH_CUDA(
          array->ctx.device_type, XPU, "FarthestPointSampler", {
            impl::FarthestPointSampler<XPU, FloatType, IdType>(
                array, batch_size, sample_points, dist, start_idx, result);
          });
    });
  });
}

42
43
/*!
 * \brief Run neighbor matching on the first relation (edge type 0) of graph.
 *
 * Dispatches on the graph's device type and index dtype. A null weight array
 * selects impl::NeighborMatching. Otherwise impl::WeightedNeighborMatching is
 * used, with FloatType taken from weight's dtype. Both implementations
 * receive the graph's CSR matrix for edge type 0 and the result buffer.
 *
 * \param graph Graph whose edge-type-0 CSR matrix is matched.
 * \param weight Edge weights, or a null array for unweighted matching.
 * \param result Output buffer handed to the implementation.
 */
void NeighborMatching(
    HeteroGraphPtr graph, const NDArray weight, IdArray result) {
  const bool unweighted = aten::IsNullArray(weight);
  if (unweighted) {
    // No weights: only the device and the index type need dispatching.
    ATEN_XPU_SWITCH_CUDA(
        graph->Context().device_type, XPU, "NeighborMatching", {
          ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
            impl::NeighborMatching<XPU, IdType>(graph->GetCSRMatrix(0), result);
          });
        });
    return;
  }

  // Weighted matching also dispatches on the floating-point weight dtype.
  ATEN_XPU_SWITCH_CUDA(
      graph->Context().device_type, XPU, "NeighborMatching", {
        ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", {
          ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
            impl::WeightedNeighborMatching<XPU, FloatType, IdType>(
                graph->GetCSRMatrix(0), weight, result);
          });
        });
      });
}

64
65
66
///////////////////////// C APIs /////////////////////////

DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler")
    .set_body([](DGLArgs args, DGLRetValue* /* rv */) {
      // Positional arguments:
      // (data, batch_size, sample_points, dist, start_idx, result).
      const NDArray points = args[0];
      const int64_t num_batches = args[1];
      const int64_t points_per_batch = args[2];
      NDArray dist_buffer = args[3];
      IdArray start_indices = args[4];
      IdArray sampled = args[5];

      // The return slot is never set; output goes through the `sampled`
      // buffer passed as FarthestPointSampler's result argument.
      FarthestPointSampler(
          points, num_batches, points_per_batch, dist_buffer, start_indices,
          sampled);
    });
78

79
DGL_REGISTER_GLOBAL("geometry._CAPI_NeighborMatching")
    .set_body([](DGLArgs args, DGLRetValue* /* rv */) {
      // Positional arguments: (graph, edge weight or null array, result).
      HeteroGraphRef graph = args[0];
      const NDArray weight = args[1];
      IdArray result = args[2];

      // sanity check
      // Names are paired positionally with the arrays, as in CheckContiguous
      // below, so there must be exactly one name per array.
      aten::CheckCtx(
          graph->Context(), {weight, result}, {"edge_weight", "result"});
      aten::CheckContiguous({weight, result}, {"edge_weight", "result"});
      CHECK_EQ(graph->NumEdgeTypes(), 1)
          << "homogeneous graph has only one edge type";
      CHECK_EQ(result->ndim, 1) << "result should be an 1D tensor.";
      // With a single edge type, the node type comes from edge 0 of the
      // metagraph. result must hold one entry per node of that type.
      auto pair = graph->meta_graph()->FindEdge(0);
      const dgl_type_t node_type = pair.first;
      CHECK_EQ(graph->NumVertices(node_type), result->shape[0])
          << "The number of nodes should be the same as the length of result "
             "tensor.";
      // Weights are optional; when present they need one entry per edge.
      if (!aten::IsNullArray(weight)) {
        CHECK_EQ(weight->ndim, 1) << "weight should be an 1D tensor.";
        CHECK_EQ(graph->NumEdges(0), weight->shape[0])
            << "number of edges in graph should be the same "
            << "as the length of edge weight tensor.";
      }

      // call implementation
      NeighborMatching(graph.sptr(), weight, result);
    });
107

108
109
}  // namespace geometry
}  // namespace dgl