hg_sample_cpu.cpp 5.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#include "hg_sample_cpu.h"

#include <algorithm>
#include <random>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "utils.h"

namespace {

// Add the budget contribution of a node that has just entered the sample.
//
// Consider every relation (source, rel, dest) in `rowptr_store` whose destination type equals
// `added_node_type`. The CSR row `added_node_idx` of that relation lists neighbor ids of type
// `source` in `col_store`. Each neighbor that is not already sampled gets 1 / (row length)
// added to (*budget_store)[source][neighbor].
//
// Params:
//   added_node_idx        - id of the node that was just added (row index into rowptr).
//   added_node_type       - node type of `added_node_idx`.
//   rowptr_store          - edge type index -> CSR row pointer tensor (read as int64).
//   col_store             - edge type index -> CSR column tensor (read as int64).
//   edge_type_idx_to_name - edge type index -> (source, relation, destination) tuple.
//   sampled_nodes_store   - node type -> ids already sampled; these never receive budget.
//   budget_store          - node type -> (node id -> accumulated budget); updated in place.
// Throws:
//   std::out_of_range if an edge type index is missing from `edge_type_idx_to_name` or
//   `col_store`, if `added_node_idx` is not a valid row, or if the row's rowptr entries point
//   outside `col`.
void update_budget(
	int64_t added_node_idx,
	const node_t &added_node_type,
	const c10::Dict<int64_t, torch::Tensor> &rowptr_store,
	const c10::Dict<int64_t, torch::Tensor> &col_store,
	const c10::Dict<int64_t, edge_t> &edge_type_idx_to_name,
	const std::unordered_map<node_t, std::unordered_set<int64_t>> &sampled_nodes_store,
	std::unordered_map<node_t, std::unordered_map<int64_t, float>> *budget_store
) {
	for (const auto &i : rowptr_store) {
		const auto &edge_type_idx = i.key();
		const auto &edge_type = edge_type_idx_to_name.at(edge_type_idx);
		const auto &source_node_type = std::get<0>(edge_type);
		const auto &dest_node_type = std::get<2>(edge_type);

		// Only relations pointing into the added node's type have it as a CSR row.
		if (added_node_type != dest_node_type) {
			continue;
		}

		const torch::Tensor &row_ptr = i.value();
		const torch::Tensor &col = col_store.at(edge_type_idx);

		// Guard the raw reads row_ptr[idx] and row_ptr[idx + 1] below.
		if (added_node_idx < 0 || added_node_idx + 1 >= row_ptr.numel()) {
			throw std::out_of_range("update_budget: node index out of range of rowptr");
		}
		const int64_t *row_ptr_raw = row_ptr.data_ptr<int64_t>();
		const int64_t *col_raw = col.data_ptr<int64_t>();

		const int64_t row_start_idx = row_ptr_raw[added_node_idx];
		const int64_t row_end_idx = row_ptr_raw[added_node_idx + 1];
		// Guard the raw reads col[row_start_idx .. row_end_idx).
		if (row_start_idx < 0 || row_start_idx > row_end_idx || row_end_idx > col.numel()) {
			throw std::out_of_range("update_budget: rowptr entries out of range of col");
		}

		// Read-only lookup: find() never inserts an empty set the way operator[] would.
		const auto sampled_it = sampled_nodes_store.find(source_node_type);
		const std::unordered_set<int64_t> *sampled_nodes =
			sampled_it == sampled_nodes_store.end() ? nullptr : &sampled_it->second;

		// operator[] deliberately creates the (possibly empty) budget map for the source type.
		std::unordered_map<int64_t, float> &budget = (*budget_store)[source_node_type];

		if (row_start_idx == row_end_idx) {
			continue;  // the added node has no neighbors in this relation
		}

		// Every neighbor receives the inverse of the added node's degree in this relation.
		const float norm_deg = 1.0f / static_cast<float>(row_end_idx - row_start_idx);
		for (int64_t j = row_start_idx; j < row_end_idx; j++) {
			const int64_t neighbor = col_raw[j];
			if (sampled_nodes == nullptr || sampled_nodes->count(neighbor) == 0) {
				budget[neighbor] += norm_deg;
			}
		}
	}
}

// Draw `n` weighted samples (with replacement) from one node type's budget map and return the
// distinct node ids that were hit. Node i is drawn with probability
//   prob[i] = budget[i]^2 / l2_norm(budget)^2,
// so the result holds at most `n` ids (repeated draws of a node collapse in the set).
// Returns an empty set when n <= 0, the budget is empty, or every budget entry is zero.
std::unordered_set<int64_t> sample_nodes(const std::unordered_map<int64_t, float> &budget, int n) {
	std::unordered_set<int64_t> sampled_nodes;
	if (n <= 0 || budget.empty()) {
		return sampled_nodes;
	}

	// The squared L2 norm is the total probability mass.
	float norm = 0.0f;
	for (const auto &kv : budget) {
		norm += kv.second * kv.second;
	}
	if (!(norm > 0.0f)) {
		return sampled_nodes;  // all-zero (or NaN) budget: nothing to draw from
	}

	// Seed one engine per thread once; building a std::random_device on every call is costly.
	thread_local std::mt19937 gen{std::random_device{}()};
	std::uniform_real_distribution<float> dist(0.0f, norm);

	// Generate n sorted random values in [0, norm).
	std::vector<float> samples(n);
	std::generate(samples.begin(), samples.end(), [&] { return dist(gen); });
	std::sort(samples.begin(), samples.end());

	// Walk the budget map while accumulating cum_prob[i]. The j-th sample belongs to node i iff
	// cum_prob[i-1] <= samples[j] < cum_prob[i]. Because the samples are sorted, advancing the
	// two cursors alternately assigns every sample in a single linear pass.
	sampled_nodes.reserve(samples.size());
	auto j = samples.begin();
	float cum_prob = 0.0f;
	for (const auto &kv : budget) {
		cum_prob += kv.second * kv.second;

		// Test for end() first: dereferencing samples.end() is undefined behavior.
		while (j != samples.end() && *j < cum_prob) {
			sampled_nodes.insert(kv.first);
			++j;
		}

		// Stop early once every sample has been assigned.
		if (j == samples.end()) {
			break;
		}
	}

	return sampled_nodes;
}

}  // namespace

// TODO: Add the appropriate return type
// Layer-wise, budget-based heterogeneous graph sampling on CPU tensors.
//
// The seed nodes in `origin_nodes_store` are added to the sample first. Every node that enters
// the sample distributes budget to its not-yet-sampled neighbors (see update_budget). Each of
// the `num_layers` layers then draws `n` times per node type from that type's budget
// (see sample_nodes).
//
// Params:
//   rowptr_store          - edge type index -> CSR row pointer (1-D contiguous CPU tensor).
//   col_store             - edge type index -> CSR column indices (1-D contiguous CPU tensor).
//   origin_nodes_store    - node type -> seed node ids (1-D contiguous CPU tensor).
//   edge_type_idx_to_name - edge type index -> (source, relation, destination) tuple.
//   n                     - number of draws per node type and layer; must be >= 0.
//   num_layers            - number of sampling layers.
void hg_sample_cpu(
	const c10::Dict<int64_t, torch::Tensor> &rowptr_store,
	const c10::Dict<int64_t, torch::Tensor> &col_store,
	const c10::Dict<node_t, torch::Tensor> &origin_nodes_store,
	const c10::Dict<int64_t, edge_t> &edge_type_idx_to_name,
	int n,
	int num_layers
) {
	// Verify input. All tensors below are read through raw pointers, so they must be 1-D and
	// contiguous.
	CHECK_INPUT(n >= 0);
	for (const auto &kv : rowptr_store) {
		CHECK_CPU(kv.value());
		CHECK_INPUT(kv.value().dim() == 1 && kv.value().is_contiguous());
	}

	for (const auto &kv : col_store) {
		CHECK_CPU(kv.value());
		CHECK_INPUT(kv.value().dim() == 1 && kv.value().is_contiguous());
	}

	for (const auto &kv : origin_nodes_store) {
		CHECK_CPU(kv.value());
		CHECK_INPUT(kv.value().dim() == 1 && kv.value().is_contiguous());
	}

	using NodeSetStore = std::unordered_map<node_t, std::unordered_set<int64_t>>;

	NodeSetStore sampled_nodes_store;
	std::unordered_map<node_t, std::unordered_map<int64_t, float>> budget_store;

	// Distribute budget from every node in `added` (node type -> node ids). `added` must not be
	// sampled_nodes_store itself, because update_budget may add keys to that map.
	const auto add_budget_for = [&](const NodeSetStore &added) {
		for (const auto &kv : added) {
			for (const int64_t node_idx : kv.second) {
				update_budget(
					node_idx,
					kv.first,
					rowptr_store,
					col_store,
					edge_type_idx_to_name,
					sampled_nodes_store,
					&budget_store
				);
			}
		}
	};

	// Seed the sample with the origin nodes. The set collapses duplicate ids, so each seed
	// contributes budget exactly once.
	for (const auto &kv : origin_nodes_store) {
		const torch::Tensor &origin_nodes = kv.value();
		const int64_t *raw_origin_nodes = origin_nodes.data_ptr<int64_t>();
		sampled_nodes_store[kv.key()].insert(raw_origin_nodes, raw_origin_nodes + origin_nodes.numel());
	}
	add_budget_for(NodeSetStore(sampled_nodes_store));  // iterate a copy (see add_budget_for)

	// Sampling process
	for (int l = 0; l < num_layers; l++) {
		// First draw this layer's samples for every node type, using the budgets as they stood
		// at the start of the layer. Budget updates are deferred until after the loop. Calling
		// update_budget here could insert keys into budget_store while it is being iterated,
		// which invalidates the range-for iterator. It would also make a layer's result depend on
		// unordered_map iteration order.
		NodeSetStore new_samples_store;
		for (auto &kv : budget_store) {
			const node_t &node_type = kv.first;
			auto &budget = kv.second;

			std::unordered_set<int64_t> new_samples = sample_nodes(budget, n);
			if (new_samples.empty()) {
				continue;
			}

			// Move the sampled nodes out of the budget and into the sampled node store.
			auto &sampled_nodes = sampled_nodes_store[node_type];
			for (const int64_t sample : new_samples) {
				sampled_nodes.insert(sample);
				budget.erase(sample);
			}
			new_samples_store.emplace(node_type, std::move(new_samples));
		}

		// Then let the newly sampled nodes contribute budget for the next layer.
		add_budget_for(new_samples_store);
	}

	// Re-index
	// TODO: re-index the sampled nodes in sampled_nodes_store and return the result.
}