python_binding.cc 3.08 KB
Newer Older
1
2
3
4
5
6
/**
 *  Copyright (c) 2023 by Contributors
 * @file python_binding.cc
 * @brief Graph bolt library Python binding.
 */

7
#include <graphbolt/fused_csc_sampling_graph.h>
8
#include <graphbolt/isin.h>
9
#include <graphbolt/serialize.h>
10
#include <graphbolt/unique_and_compact.h>
11

12
13
#include "./index_select.h"

14
15
16
17
namespace graphbolt {
namespace sampling {

/**
 * @brief Registers GraphBolt's TorchScript custom classes and free operators
 * in the `graphbolt` library.
 *
 * Custom classes registered via `m.class_<T>(...)` and functions registered
 * via `m.def(...)` become reachable from Python through PyTorch's
 * custom-class / custom-operator machinery under the `graphbolt` namespace.
 */
TORCH_LIBRARY(graphbolt, m) {
  // Result container for sampling calls. Exposed with a default constructor
  // and read/write access to each of its tensor members.
  m.class_<SampledSubgraph>("SampledSubgraph")
      .def(torch::init<>())
      .def_readwrite("indptr", &SampledSubgraph::indptr)
      .def_readwrite("indices", &SampledSubgraph::indices)
      .def_readwrite(
          "original_row_node_ids", &SampledSubgraph::original_row_node_ids)
      .def_readwrite(
          "original_column_node_ids",
          &SampledSubgraph::original_column_node_ids)
      .def_readwrite("original_edge_ids", &SampledSubgraph::original_edge_ids)
      .def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge);

  // The sampling graph itself: accessors, setters, sampling entry points and
  // shared-memory export.
  m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
      // Read accessors.
      .def("num_nodes", &FusedCSCSamplingGraph::NumNodes)
      .def("num_edges", &FusedCSCSamplingGraph::NumEdges)
      .def("csc_indptr", &FusedCSCSamplingGraph::CSCIndptr)
      .def("indices", &FusedCSCSamplingGraph::Indices)
      .def("node_type_offset", &FusedCSCSamplingGraph::NodeTypeOffset)
      .def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
      .def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes)
      // Setters.
      .def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr)
      .def("set_indices", &FusedCSCSamplingGraph::SetIndices)
      .def("set_node_type_offset", &FusedCSCSamplingGraph::SetNodeTypeOffset)
      .def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge)
      .def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
      // Sampling.
      .def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
      .def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
      .def(
          "sample_negative_edges_uniform",
          &FusedCSCSamplingGraph::SampleNegativeEdgesUniform)
      .def("copy_to_shared_memory", &FusedCSCSamplingGraph::CopyToSharedMemory)
      // Pickle support. The object is serialized as a nested
      // string -> (string -> Tensor) dictionary produced by GetState() and
      // rebuilt by default-constructing a graph and calling SetState().
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<FusedCSCSamplingGraph>& self)
              -> torch::Dict<
                  std::string, torch::Dict<std::string, torch::Tensor>> {
            return self->GetState();
          },
          // __setstate__
          [](torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>
                 state) -> c10::intrusive_ptr<FusedCSCSamplingGraph> {
            auto g = c10::make_intrusive<FusedCSCSamplingGraph>();
            g->SetState(state);
            return g;
          });

  // Graph construction, (de)serialization and shared-memory loading.
  m.def("from_fused_csc", &FusedCSCSamplingGraph::FromCSC);
  m.def("load_fused_csc_sampling_graph", &LoadFusedCSCSamplingGraph);
  m.def("save_fused_csc_sampling_graph", &SaveFusedCSCSamplingGraph);
  m.def(
      "load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);

  // Standalone tensor utilities.
  m.def("unique_and_compact", &UniqueAndCompact);
  m.def("isin", &IsIn);
  m.def("index_select", &ops::IndexSelect);
}

}  // namespace sampling
}  // namespace graphbolt