/*!
 *  Copyright (c) 2019 by Contributors
 * @file graph/transform/knn.cc
 * @brief k-nearest-neighbor (KNN) interface
 */

#include "knn.h"

#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>

#include <string>

#include "../../array/check.h"

using namespace dgl::runtime;
namespace dgl {
namespace transform {

/**
 * @brief C API: k-nearest-neighbor search of query points against data points.
 *
 * Positional arguments:
 *   args[0] data_points   -- points to search in; its dtype selects FloatType
 *                            (must be a floating-point dtype).
 *   args[1] data_offsets  -- offsets into data_points (presumably per-segment
 *                            boundaries -- confirm against knn.h).
 *   args[2] query_points  -- points whose neighbors are requested; must have
 *                            the same dtype as data_points.
 *   args[3] query_offsets -- offsets into query_points (same caveat as above).
 *   args[4] k             -- number of neighbors requested, forwarded to KNN().
 *   args[5] result        -- output array written by KNN(); its dtype selects
 *                            IdType. Nothing is returned through `rv`.
 *   args[6] algorithm     -- algorithm name, forwarded verbatim to KNN().
 *
 * Every array must be contiguous and on the same device as data_points, and
 * both offset arrays must have the same dtype as result.
 */
DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const NDArray data_points = args[0];
      const IdArray data_offsets = args[1];
      const NDArray query_points = args[2];
      const IdArray query_offsets = args[3];
      const int k = args[4];
      IdArray result = args[5];
      const std::string algorithm = args[6];

      aten::CheckContiguous(
          {data_points, data_offsets, query_points, query_offsets, result},
          {"data_points", "data_offsets", "query_points", "query_offsets",
           "result"});
      aten::CheckCtx(
          data_points->ctx, {data_offsets, query_points, query_offsets, result},
          {"data_offsets", "query_points", "query_offsets", "result"});
      // The kernel below is instantiated with a single FloatType (from
      // data_points) and a single IdType (from result). Reject inputs whose
      // dtypes disagree rather than letting them presumably be read as the
      // wrong element type inside the kernel.
      CHECK_SAME_DTYPE(data_points, query_points);
      CHECK_SAME_DTYPE(result, data_offsets);
      CHECK_SAME_DTYPE(result, query_offsets);

      ATEN_XPU_SWITCH_CUDA(data_points->ctx.device_type, XPU, "KNN", {
        ATEN_FLOAT_TYPE_SWITCH(data_points->dtype, FloatType, "data_points", {
          ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
            KNN<XPU, FloatType, IdType>(
                data_points, data_offsets, query_points, query_offsets, k,
                result, algorithm);
          });
        });
      });
    });

/**
 * @brief C API: k-nearest-neighbor computation via the NNDescent kernel.
 *
 * Positional arguments:
 *   args[0] points         -- point coordinates; its dtype selects FloatType
 *                             (must be a floating-point dtype).
 *   args[1] offsets        -- offsets into points (presumably per-segment
 *                             boundaries -- confirm against knn.h); must have
 *                             the same dtype as result.
 *   args[2] result         -- output array written by NNDescent(); its dtype
 *                             selects IdType. Nothing is returned via `rv`.
 *   args[3] k              -- number of neighbors, forwarded to NNDescent().
 *   args[4] num_iters      -- forwarded unchanged to NNDescent().
 *   args[5] num_candidates -- forwarded unchanged to NNDescent().
 *   args[6] delta          -- forwarded unchanged to NNDescent().
 *
 * Every array must be contiguous and on the same device as points.
 */
DGL_REGISTER_GLOBAL("transform._CAPI_DGLNNDescent")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const NDArray points = args[0];
      const IdArray offsets = args[1];
      // Not const: this is the output buffer that NNDescent() fills.
      IdArray result = args[2];
      const int k = args[3];
      const int num_iters = args[4];
      const int num_candidates = args[5];
      const double delta = args[6];

      aten::CheckContiguous(
          {points, offsets, result}, {"points", "offsets", "result"});
      // `points` supplies the reference device, so only the other arrays
      // need to be compared against it.
      aten::CheckCtx(points->ctx, {offsets, result}, {"offsets", "result"});
      // The kernel is instantiated with a single IdType taken from result;
      // offsets must agree so it is not read as the wrong integer width.
      CHECK_SAME_DTYPE(result, offsets);

      ATEN_XPU_SWITCH_CUDA(points->ctx.device_type, XPU, "NNDescent", {
        ATEN_FLOAT_TYPE_SWITCH(points->dtype, FloatType, "points", {
          ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
            NNDescent<XPU, FloatType, IdType>(
                points, offsets, result, k, num_iters, num_candidates, delta);
          });
        });
      });
    });

}  // namespace transform
}  // namespace dgl