// Operator entry points for neighbor sampling. Each function in this file is
// a thin wrapper that forwards its arguments unchanged to the matching `*_cpu`
// implementation; the wrappers are then registered as `torch_sparse::*`
// operators at the bottom of the file.
//
// NOTE(review): in this copy of the file every `<...>` span is missing -- the
// two bare `#include` directives have no header name, and `std::tuple`,
// `std::vector` and `c10::Dict` have no template arguments (note the stray `>`
// tokens). Restore them from the original source before compiling; they are
// deliberately not guessed here.

// Only needed when building against Python (see the PyInit stubs below).
#ifdef WITH_PYTHON
#include
#endif
#include

// Presumably declares the `*_cpu` functions called below -- confirm.
#include "cpu/neighbor_sample_cpu.h"

// On Windows builds with Python enabled, define a module-init symbol whose
// name depends on whether this is the CUDA or the CPU build. The function
// does nothing except return NULL.
// NOTE(review): presumably this exists only so that a `PyInit_*` symbol is
// present for the extension module on Windows -- confirm against the build
// scripts.
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__neighbor_sample_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__neighbor_sample_cpu(void) { return NULL; }
#endif // WITH_CUDA
#endif // WITH_PYTHON
#endif // _WIN32

// Forwards to `neighbor_sample_cpu` and returns its result unchanged.
// Returns 'output_node', 'row', 'col', 'output_edge'
//
// All six arguments are passed through in the same order. `num_neighbors` is
// taken by (const) value rather than by reference, unlike the tensor
// arguments.
// NOTE(review): the by-value parameter presumably mirrors the declaration of
// the CPU implementation -- confirm before changing it.
SPARSE_API std::tuple
neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row,
                const torch::Tensor &input_node,
                const std::vector num_neighbors, const bool replace,
                const bool directed) {
  return neighbor_sample_cpu(colptr, row, input_node, num_neighbors, replace,
                             directed);
}

// Variant of `neighbor_sample` that takes per-type dictionaries plus explicit
// lists of node and edge types. Forwards all nine arguments, in order, to
// `hetero_neighbor_sample_cpu` and returns its four-element tuple unchanged.
// NOTE(review): the tuple elements presumably correspond to the
// 'output_node', 'row', 'col', 'output_edge' results of `neighbor_sample`,
// keyed by type -- confirm against the CPU implementation.
SPARSE_API std::tuple, c10::Dict, c10::Dict, c10::Dict>
hetero_neighbor_sample(
    const std::vector &node_types,
    const std::vector &edge_types,
    const c10::Dict &colptr_dict,
    const c10::Dict &row_dict,
    const c10::Dict &input_node_dict,
    const c10::Dict> &num_neighbors_dict,
    const int64_t num_hops, const bool replace, const bool directed) {
  return hetero_neighbor_sample_cpu(
      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
      num_neighbors_dict, num_hops, replace, directed);
}

// Same parameter list as `hetero_neighbor_sample` with one additional
// argument, `node_time_dict`, inserted before `num_hops`. Forwards all ten
// arguments, in order, to `hetero_temporal_neighbor_sample_cpu` and returns
// its result unchanged.
// NOTE(review): unlike the two wrappers above, this one is not marked
// `SPARSE_API`. Check whether that is intentional; if the macro controls
// symbol export, this function may not be visible to users of the library.
std::tuple, c10::Dict, c10::Dict, c10::Dict>
hetero_temporal_neighbor_sample(
    const std::vector &node_types,
    const std::vector &edge_types,
    const c10::Dict &colptr_dict,
    const c10::Dict &row_dict,
    const c10::Dict &input_node_dict,
    const c10::Dict> &num_neighbors_dict,
    const c10::Dict &node_time_dict,
    const int64_t num_hops, const bool replace, const bool directed) {
  return hetero_temporal_neighbor_sample_cpu(
      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
      num_neighbors_dict, node_time_dict, num_hops, replace, directed);
}

// Registers the three wrappers above as operators under the `torch_sparse::`
// prefix. `registry` is a file-local static, so the registration happens when
// this object is initialized; the variable is otherwise unused in this file.
static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::neighbor_sample", &neighbor_sample)
        .op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample)
        .op("torch_sparse::hetero_temporal_neighbor_sample",
            &hetero_temporal_neighbor_sample);