Commit 84b46170 authored by rusty1s's avatar rusty1s
Browse files

bugfix

parent 014c4bae
...@@ -39,9 +39,8 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row, ...@@ -39,9 +39,8 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
if (col_count == 0) if (col_count == 0)
continue; continue;
if (replace) { if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
for (int64_t j = 0; j < num_samples; j++) { for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t offset = col_start + rand() % col_count;
const int64_t &v = row_data[offset]; const int64_t &v = row_data[offset];
const auto res = to_local_node.insert({v, samples.size()}); const auto res = to_local_node.insert({v, samples.size()});
if (res.second) if (res.second)
...@@ -52,8 +51,9 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row, ...@@ -52,8 +51,9 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
edges.push_back(offset); edges.push_back(offset);
} }
} }
} else if (num_samples >= col_count) { } else if (replace) {
for (int64_t offset = col_start; offset < col_end; offset++) { for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
const int64_t &v = row_data[offset]; const int64_t &v = row_data[offset];
const auto res = to_local_node.insert({v, samples.size()}); const auto res = to_local_node.insert({v, samples.size()});
if (res.second) if (res.second)
...@@ -111,14 +111,14 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row, ...@@ -111,14 +111,14 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
} }
template <bool replace, bool directed> template <bool replace, bool directed>
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>, tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>> c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const std::vector<node_t> &node_types, hetero_sample(const vector<node_t> &node_types,
const std::vector<edge_t> &edge_types, const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict, const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict, const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict, const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict, const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops) { const int64_t num_hops) {
// Create a mapping to convert single string relations to edge type triplets: // Create a mapping to convert single string relations to edge type triplets:
...@@ -129,9 +129,9 @@ hetero_sample(const std::vector<node_t> &node_types, ...@@ -129,9 +129,9 @@ hetero_sample(const std::vector<node_t> &node_types,
// Initialize some data structures for the sampling process: // Initialize some data structures for the sampling process:
unordered_map<node_t, vector<int64_t>> samples_dict; unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict; unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
for (const auto &k : node_types) { for (const auto &node_type : node_types) {
samples_dict[k]; samples_dict[node_type];
to_local_node_dict[k]; to_local_node_dict[node_type];
} }
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict; unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
...@@ -167,7 +167,7 @@ hetero_sample(const std::vector<node_t> &node_types, ...@@ -167,7 +167,7 @@ hetero_sample(const std::vector<node_t> &node_types,
const auto &edge_type = to_edge_type[rel_type]; const auto &edge_type = to_edge_type[rel_type];
const auto &src_node_type = get<0>(edge_type); const auto &src_node_type = get<0>(edge_type);
const auto &dst_node_type = get<2>(edge_type); const auto &dst_node_type = get<2>(edge_type);
const auto &num_samples = kv.value()[ell]; const auto num_samples = kv.value()[ell];
const auto &dst_samples = samples_dict.at(dst_node_type); const auto &dst_samples = samples_dict.at(dst_node_type);
auto &src_samples = samples_dict.at(src_node_type); auto &src_samples = samples_dict.at(src_node_type);
auto &to_local_src_node = to_local_node_dict.at(src_node_type); auto &to_local_src_node = to_local_node_dict.at(src_node_type);
...@@ -190,9 +190,8 @@ hetero_sample(const std::vector<node_t> &node_types, ...@@ -190,9 +190,8 @@ hetero_sample(const std::vector<node_t> &node_types,
if (col_count == 0) if (col_count == 0)
continue; continue;
if (replace) { if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
for (int64_t j = 0; j < num_samples; j++) { for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t offset = col_start + rand() % col_count;
const int64_t &v = row_data[offset]; const int64_t &v = row_data[offset];
const auto res = to_local_src_node.insert({v, src_samples.size()}); const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) if (res.second)
...@@ -203,8 +202,9 @@ hetero_sample(const std::vector<node_t> &node_types, ...@@ -203,8 +202,9 @@ hetero_sample(const std::vector<node_t> &node_types,
edges.push_back(offset); edges.push_back(offset);
} }
} }
} else if (num_samples >= col_count) { } else if (replace) {
for (int64_t offset = col_start; offset < col_end; offset++) { for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
const int64_t &v = row_data[offset]; const int64_t &v = row_data[offset];
const auto res = to_local_src_node.insert({v, src_samples.size()}); const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second) if (res.second)
...@@ -302,15 +302,14 @@ neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row, ...@@ -302,15 +302,14 @@ neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
} }
} }
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>, tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>> c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample_cpu( hetero_neighbor_sample_cpu(
const std::vector<node_t> &node_types, const vector<node_t> &node_types, const vector<edge_t> &edge_types,
const std::vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict, const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict, const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict, const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict, const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed) { const int64_t num_hops, const bool replace, const bool directed) {
if (replace && directed) { if (replace && directed) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment