python_binding.cc 2.48 KB
Newer Older
1
2
3
4
5
6
7
/**
 *  Copyright (c) 2023 by Contributors
 * @file python_binding.cc
 * @brief Graph bolt library Python binding.
 */

#include <graphbolt/csc_sampling_graph.h>
8
#include <graphbolt/serialize.h>
9
#include <graphbolt/unique_and_compact.h>
10
11
12
13
14

namespace graphbolt {
namespace sampling {

/**
 * @brief Registers the `graphbolt` TorchScript library.
 *
 * Binds the SampledSubgraph and CSCSamplingGraph custom classes (reachable
 * from Python as `torch.classes.graphbolt.<Name>`) and a set of free-function
 * operators (reachable as `torch.ops.graphbolt.<name>`).
 *
 * NOTE(review): the original text of this block contained stray line-number
 * tokens left over from a web-view copy; they have been removed so the
 * registration compiles.
 */
TORCH_LIBRARY(graphbolt, m) {
  // Default-constructible holder; every field is exposed to Python as a
  // read/write attribute bound directly to the C++ member.
  m.class_<SampledSubgraph>("SampledSubgraph")
      .def(torch::init<>())
      .def_readwrite("indptr", &SampledSubgraph::indptr)
      .def_readwrite("indices", &SampledSubgraph::indices)
      .def_readwrite(
          "original_row_node_ids", &SampledSubgraph::original_row_node_ids)
      .def_readwrite(
          "original_column_node_ids",
          &SampledSubgraph::original_column_node_ids)
      .def_readwrite("original_edge_ids", &SampledSubgraph::original_edge_ids)
      .def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge);

  // Graph class: accessors, sampling entry points, shared-memory export and
  // pickling support. No `torch::init` is registered here; Python-side
  // instances come from the factory/loader operators below or from unpickling.
  m.class_<CSCSamplingGraph>("CSCSamplingGraph")
      .def("num_nodes", &CSCSamplingGraph::NumNodes)
      .def("num_edges", &CSCSamplingGraph::NumEdges)
      .def("csc_indptr", &CSCSamplingGraph::CSCIndptr)
      .def("indices", &CSCSamplingGraph::Indices)
      .def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset)
      .def("type_per_edge", &CSCSamplingGraph::TypePerEdge)
      .def("edge_attributes", &CSCSamplingGraph::EdgeAttributes)
      .def("in_subgraph", &CSCSamplingGraph::InSubgraph)
      .def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors)
      .def(
          "sample_negative_edges_uniform",
          &CSCSamplingGraph::SampleNegativeEdgesUniform)
      .def("copy_to_shared_memory", &CSCSamplingGraph::CopyToSharedMemory)
      // Pickling round-trips through GetState()/SetState(). The state is a
      // nested dictionary: string -> (string -> tensor).
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<CSCSamplingGraph>& self)
              -> torch::Dict<
                  std::string, torch::Dict<std::string, torch::Tensor>> {
            return self->GetState();
          },
          // __setstate__: default-construct an empty graph, then restore it
          // from the saved state.
          [](torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>
                 state) -> c10::intrusive_ptr<CSCSamplingGraph> {
            auto g = c10::make_intrusive<CSCSamplingGraph>();
            g->SetState(state);
            return g;
          });

  // Free-function operators: graph construction, (de)serialization,
  // shared-memory loading and ID compaction.
  m.def("from_csc", &CSCSamplingGraph::FromCSC);
  m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
  m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph);
  m.def("load_from_shared_memory", &CSCSamplingGraph::LoadFromSharedMemory);
  m.def("unique_and_compact", &UniqueAndCompact);
}

}  // namespace sampling
}  // namespace graphbolt