// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 * @file python_binding.cc
 * @brief GraphBolt library Python binding.
 */

#include <graphbolt/fused_csc_sampling_graph.h>
#include <graphbolt/isin.h>
#include <graphbolt/serialize.h>
#include <graphbolt/unique_and_compact.h>

#ifdef GRAPHBOLT_USE_CUDA
#include "cuda/max_uva_threads.h"
#endif

#include "./cnumpy.h"
#include "./expand_indptr.h"
#include "./index_select.h"
#include "./random.h"

#ifdef GRAPHBOLT_USE_CUDA
#include "cuda/gpu_cache.h"
#endif

namespace graphbolt {
namespace sampling {

TORCH_LIBRARY(graphbolt, m) {
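  // Expose the sampled-subgraph container to Python; its fields are plain
  // read/write tensor attributes.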
  m.class_<FusedSampledSubgraph>("FusedSampledSubgraph")
      .def(torch::init<>())
      .def_readwrite("indptr", &FusedSampledSubgraph::indptr)
      .def_readwrite("indices", &FusedSampledSubgraph::indices)
      .def_readwrite(
          "original_row_node_ids", &FusedSampledSubgraph::original_row_node_ids)
      .def_readwrite(
          "original_column_node_ids",
          &FusedSampledSubgraph::original_column_node_ids)
      .def_readwrite(
          "original_edge_ids", &FusedSampledSubgraph::original_edge_ids)
      .def_readwrite("type_per_edge", &FusedSampledSubgraph::type_per_edge)
      .def_readwrite("etype_offsets", &FusedSampledSubgraph::etype_offsets);
  m.class_<storage::OnDiskNpyArray>("OnDiskNpyArray")
      .def("index_select", &storage::OnDiskNpyArray::IndexSelect);
  m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
      .def("num_nodes", &FusedCSCSamplingGraph::NumNodes)
      .def("num_edges", &FusedCSCSamplingGraph::NumEdges)
      .def("csc_indptr", &FusedCSCSamplingGraph::CSCIndptr)
      .def("indices", &FusedCSCSamplingGraph::Indices)
      .def("node_type_offset", &FusedCSCSamplingGraph::NodeTypeOffset)
      .def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
      .def("node_type_to_id", &FusedCSCSamplingGraph::NodeTypeToID)
      .def("edge_type_to_id", &FusedCSCSamplingGraph::EdgeTypeToID)
      .def("node_attributes", &FusedCSCSamplingGraph::NodeAttributes)
      .def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes)
      .def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr)
      .def("set_indices", &FusedCSCSamplingGraph::SetIndices)
      .def("set_node_type_offset", &FusedCSCSamplingGraph::SetNodeTypeOffset)
      .def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge)
      .def("set_node_type_to_id", &FusedCSCSamplingGraph::SetNodeTypeToID)
      .def("set_edge_type_to_id", &FusedCSCSamplingGraph::SetEdgeTypeToID)
      .def("set_node_attributes", &FusedCSCSamplingGraph::SetNodeAttributes)
      .def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
      .def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
      .def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
      .def(
          "temporal_sample_neighbors",
          &FusedCSCSamplingGraph::TemporalSampleNeighbors)
      .def("copy_to_shared_memory", &FusedCSCSamplingGraph::CopyToSharedMemory)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<FusedCSCSamplingGraph>& self)
              -> torch::Dict<
                  std::string, torch::Dict<std::string, torch::Tensor>> {
            return self->GetState();
          },
          // __setstate__
          [](torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>
                 state) -> c10::intrusive_ptr<FusedCSCSamplingGraph> {
            auto g = c10::make_intrusive<FusedCSCSamplingGraph>();
            g->SetState(state);
            return g;
          });
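  // GPU cache bindings are compiled only when CUDA support is enabled.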
#ifdef GRAPHBOLT_USE_CUDA
  m.class_<cuda::GpuCache>("GpuCache")
      .def("query", &cuda::GpuCache::Query)
      .def("replace", &cuda::GpuCache::Replace);
  m.def("gpu_cache", &cuda::GpuCache::Create);
#endif
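  // Free-standing operators: factories and tensor ops.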
  m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);
  m.def(
      "load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
  m.def("unique_and_compact", &UniqueAndCompact);
  m.def("unique_and_compact_batched", &UniqueAndCompactBatched);
  m.def("isin", &IsIn);
  m.def("index_select", &ops::IndexSelect);
  m.def("index_select_csc", &ops::IndexSelectCSC);
  m.def("ondisk_npy_array", &storage::OnDiskNpyArray::Create);
  m.def("set_seed", &RandomEngine::SetManualSeed);
#ifdef GRAPHBOLT_USE_CUDA
  m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
#endif
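  // When the PyTorch version supports it, point it at the Python module that
  // hosts the abstract (fake-tensor) implementations of these operators.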
#ifdef HAS_IMPL_ABSTRACT_PYSTUB
  m.impl_abstract_pystub("dgl.graphbolt.base", "//dgl.graphbolt.base");
#endif
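  // expand_indptr is declared schema-first (and tagged PT2-compliant when
  // supported); its backend implementations are registered separately.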
  m.def(
      "expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, "
      "SymInt? output_size) -> Tensor"
#ifdef HAS_PT2_COMPLIANT_TAG
      ,
      {at::Tag::pt2_compliant_tag}
#endif
  );
}

}  // namespace sampling
}  // namespace graphbolt