/*!
 *  Copyright (c) 2019 by Contributors
 * \file geometry/geometry.cc
 * \brief DGL geometry utilities implementation
 */
#include <dgl/array.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/base_heterograph.h>
#include "../c_api_common.h"
#include "./geometry_op.h"
#include "../array/check.h"

using namespace dgl::runtime;

namespace dgl {
namespace geometry {

/*!
 * \brief Farthest point sampling over a batch of point clouds packed along
 *        dimension 0 of \p array.
 *
 * Validates the buffers, then dispatches to impl::FarthestPointSampler
 * specialized on the device of \p array, the floating-point dtype of \p array
 * and the index dtype of \p result.
 *
 * \param array Point data of all clouds packed along dimension 0; must be a
 *        2D tensor.
 * \param batch_size Number of clouds packed in \p array; must be positive and
 *        divide array->shape[0].
 * \param sample_points Number of points to sample from each cloud.
 * \param dist Buffer with the same length as array->shape[0]; must share
 *        \p array's dtype and device.
 * \param start_idx One entry per cloud (length \p batch_size); must share
 *        \p result's dtype and device.
 * \param result Output buffer of length batch_size * sample_points.
 */
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
    NDArray dist, IdArray start_idx, IdArray result) {
  // Shape sanity first: everything below indexes array->shape.
  CHECK_EQ(array->ndim, 2) << "array should be a 2D tensor.";
  CHECK_GT(batch_size, 0) << "batch_size should be positive.";
  CHECK_EQ(array->shape[0] % batch_size, 0)
    << "Number of points in array should be divisible by batch_size.";
  CHECK_LE(batch_size * sample_points, array->shape[0])
    << "Cannot sample more points than each point cloud contains.";

  CHECK_EQ(array->ctx, result->ctx) << "Array and the result should be on the same device.";
  // The kernel is instantiated only on array's float dtype and result's id
  // dtype (see the dispatch below). dist and start_idx can therefore only be
  // interpreted correctly when they share those dtypes and that device.
  aten::CheckCtx(array->ctx, {dist, start_idx}, {"dist", "start_idx"});
  CHECK(dist->dtype == array->dtype) << "dist should have the same dtype as array.";
  CHECK(start_idx->dtype == result->dtype)
    << "start_idx should have the same dtype as result.";
  // NOTE(review): assumes the kernel indexes these buffers densely; reject
  // non-contiguous views up front, as the NeighborMatching C API does.
  aten::CheckContiguous({array, dist, start_idx, result},
                        {"array", "dist", "start_idx", "result"});

  CHECK_EQ(array->shape[0], dist->shape[0]) << "Shape of array and dist mismatch";
  CHECK_EQ(start_idx->shape[0], batch_size) << "Shape of start_idx and batch_size mismatch";
  CHECK_EQ(result->shape[0], batch_size * sample_points) << "Invalid shape of result";

  ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", {
    ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
      ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "FarthestPointSampler", {
        impl::FarthestPointSampler<XPU, FloatType, IdType>(
            array, batch_size, sample_points, dist, start_idx, result);
      });
    });
  });
}

/*!
 * \brief Run neighbor matching on the graph's first CSR matrix.
 *
 * When \p weight is a null array the unweighted kernel is used. Otherwise
 * the weighted kernel is specialized on the dtype of \p weight. Both paths
 * dispatch on the graph's device and index type and receive \p result as the
 * output buffer.
 */
void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result) {
  const auto device = graph->Context().device_type;

  // No edge weights supplied: unweighted matching.
  if (aten::IsNullArray(weight)) {
    ATEN_XPU_SWITCH_CUDA(device, XPU, "NeighborMatching", {
      ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
        impl::NeighborMatching<XPU, IdType>(graph->GetCSRMatrix(0), result);
      });
    });
    return;
  }

  // Edge weights supplied: weighted matching, typed on the weight dtype.
  ATEN_XPU_SWITCH_CUDA(device, XPU, "NeighborMatching", {
    ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", {
      ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
        impl::WeightedNeighborMatching<XPU, FloatType, IdType>(
            graph->GetCSRMatrix(0), weight, result);
      });
    });
  });
}

///////////////////////// C APIs /////////////////////////

DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Unpack the positional FFI arguments:
    //   0: points, 1: batch size, 2: samples per cloud,
    //   3: distance buffer, 4: start indices, 5: output indices.
    const NDArray points = args[0];
    const int64_t num_clouds = args[1];
    const int64_t num_samples = args[2];
    NDArray dist_buffer = args[3];
    IdArray start_indices = args[4];
    IdArray sampled = args[5];

    // Validation and dispatch live in FarthestPointSampler. Nothing is stored
    // in rv; presumably the caller reads the output back from `sampled`.
    FarthestPointSampler(points, num_clouds, num_samples,
                         dist_buffer, start_indices, sampled);
  });

DGL_REGISTER_GLOBAL("geometry._CAPI_NeighborMatching")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Positional FFI arguments: graph, edge weights (may be a null array),
    // and the per-node output buffer.
    HeteroGraphRef graph = args[0];
    const NDArray weight = args[1];
    IdArray result = args[2];

    // sanity check
    // CheckCtx/CheckContiguous pair names with arrays by position, so each
    // array needs its own entry. A single "edge_weight, result" string
    // mislabels weight and leaves result without a name.
    aten::CheckCtx(graph->Context(), {weight, result}, {"edge_weight", "result"});
    aten::CheckContiguous({weight, result}, {"edge_weight", "result"});
    CHECK_EQ(graph->NumEdgeTypes(), 1) << "homogeneous graph has only one edge type";
    CHECK_EQ(result->ndim, 1) << "result should be an 1D tensor.";
    // The matching kernels are dispatched on graph->DataType() (see
    // NeighborMatching), and result is handed to them as the output buffer.
    // Its dtype must therefore agree with the graph's index type.
    CHECK(result->dtype == graph->DataType())
      << "result should have the same dtype as the graph's index type.";
    auto pair = graph->meta_graph()->FindEdge(0);
    const dgl_type_t node_type = pair.first;
    CHECK_EQ(graph->NumVertices(node_type), result->shape[0])
      << "The number of nodes should be the same as the length of result tensor.";
    if (!aten::IsNullArray(weight)) {
      CHECK_EQ(weight->ndim, 1) << "weight should be an 1D tensor.";
      CHECK_EQ(graph->NumEdges(0), weight->shape[0])
        << "number of edges in graph should be the same "
        << "as the length of edge weight tensor.";
    }

    // call implementation
    NeighborMatching(graph.sptr(), weight, result);
  });

}  // namespace geometry
}  // namespace dgl