Commit 84b46170 authored by rusty1s's avatar rusty1s
Browse files

bugfix

parent 014c4bae
......@@ -39,9 +39,8 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
if (col_count == 0)
continue;
if (replace) {
for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t &v = row_data[offset];
const auto res = to_local_node.insert({v, samples.size()});
if (res.second)
......@@ -52,8 +51,9 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
edges.push_back(offset);
}
}
} else if (num_samples >= col_count) {
for (int64_t offset = col_start; offset < col_end; offset++) {
} else if (replace) {
for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
const int64_t &v = row_data[offset];
const auto res = to_local_node.insert({v, samples.size()});
if (res.second)
......@@ -111,14 +111,14 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
}
template <bool replace, bool directed>
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const vector<node_t> &node_types,
const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops) {
// Create a mapping to convert single string relations to edge type triplets:
......@@ -129,9 +129,9 @@ hetero_sample(const std::vector<node_t> &node_types,
// Initialize some data structures for the sampling process:
unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
for (const auto &k : node_types) {
samples_dict[k];
to_local_node_dict[k];
for (const auto &node_type : node_types) {
samples_dict[node_type];
to_local_node_dict[node_type];
}
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
......@@ -167,7 +167,7 @@ hetero_sample(const std::vector<node_t> &node_types,
const auto &edge_type = to_edge_type[rel_type];
const auto &src_node_type = get<0>(edge_type);
const auto &dst_node_type = get<2>(edge_type);
const auto &num_samples = kv.value()[ell];
const auto num_samples = kv.value()[ell];
const auto &dst_samples = samples_dict.at(dst_node_type);
auto &src_samples = samples_dict.at(src_node_type);
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
......@@ -190,9 +190,8 @@ hetero_sample(const std::vector<node_t> &node_types,
if (col_count == 0)
continue;
if (replace) {
for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t &v = row_data[offset];
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
......@@ -203,8 +202,9 @@ hetero_sample(const std::vector<node_t> &node_types,
edges.push_back(offset);
}
}
} else if (num_samples >= col_count) {
for (int64_t offset = col_start; offset < col_end; offset++) {
} else if (replace) {
for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
const int64_t &v = row_data[offset];
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
......@@ -302,15 +302,14 @@ neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
}
}
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample_cpu(
const std::vector<node_t> &node_types,
const std::vector<edge_t> &edge_types,
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed) {
if (replace && directed) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment