hgt_sample_cpu.h 1.45 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// #pragma once

// #include <torch/extension.h>

// // Node type is a string and the edge type is a triplet of strings
// // representing (source_node_type, relation_type, dest_node_type).
// typedef std::string node_t;
// typedef std::tuple<std::string, std::string, std::string> edge_t;

// // As of PyTorch 1.9.0, c10::Dict does not support tuples or complex data
// // types as key type. We work around this by representing edge types using
// // a single int64_t and a c10::Dict that maps the int64_t index to edge_t.
// void hg_sample_cpu(
//     const c10::Dict<int64_t, torch::Tensor> &rowptr_store,
//     const c10::Dict<int64_t, torch::Tensor> &col_store,
//     const c10::Dict<node_t, torch::Tensor> &origin_nodes_store,
//     const c10::Dict<int64_t, edge_t> &edge_type_idx_to_name,
//     int n,
//     int num_layers
// );
//
#pragma once

#include <torch/extension.h>

// Key type used for dictionaries indexed by node type.
using node_t = std::string;
// Key type used for dictionaries indexed by relation (edge type) as a single
// string.
using rel_t = std::string;
// Edge type expressed as a triplet of strings.
// NOTE(review): presumably (source_node_type, relation_type, dest_node_type)
// -- confirm against the implementation.
using edge_t = std::tuple<std::string, std::string, std::string>;

// Two-underscore separator string.
// NOTE(review): presumably used to join/split the parts of an edge_t when it
// is encoded as a single rel_t key -- confirm against the implementation.
// Being a namespace-scope `const`, it has internal linkage, so every
// translation unit including this header gets its own copy.
const std::string delim("__");

// Declaration of `hgt_sample_cpu`; the definition lives outside this header.
//
// Parameters (every dictionary is taken by const reference):
//   rowptr_dict      - tensors keyed by relation type (rel_t).
//                      NOTE(review): presumably per-relation compressed row
//                      pointers -- confirm against the implementation.
//   col_dict         - tensors keyed by relation type (rel_t).
//                      NOTE(review): presumably per-relation column indices
//                      matching `rowptr_dict` -- confirm.
//   input_node_dict  - tensors keyed by node type (node_t).
//                      NOTE(review): presumably the seed nodes per node
//                      type -- confirm.
//   num_samples_dict - a vector of int64_t per node type (node_t).
//                      NOTE(review): presumably one sample count per hop --
//                      confirm, including its relationship to `num_hops`.
//   num_hops         - 64-bit integer.
//                      NOTE(review): presumably the number of sampling
//                      iterations -- confirm.
//
// Returns a 4-tuple of dictionaries of tensors: the first is keyed by node
// type (node_t), the remaining three by relation type (rel_t).
// NOTE(review): the meaning of each returned dictionary is not visible from
// this declaration -- document it once confirmed against the definition.
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
           c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
               const c10::Dict<rel_t, torch::Tensor> &col_dict,
               const c10::Dict<node_t, torch::Tensor> &input_node_dict,
               const c10::Dict<node_t, std::vector<int64_t>> &num_samples_dict,
               int64_t num_hops);