Commit 288cfd44 authored by rusty1s

add bipartite flag

parent c493caaf
@@ -46,7 +46,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
                    torch::Tensor>
 relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
                     torch::optional<torch::Tensor> optional_value,
-                    torch::Tensor idx) {
+                    torch::Tensor idx, bool bipartite) {
   CHECK_CPU(rowptr);
   CHECK_CPU(col);
@@ -131,9 +131,10 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
     }
   }

-  out_rowptr =
-      torch::cat({out_rowptr, torch::full({(int64_t)n_ids.size()},
-                                          out_col.numel(), rowptr.options())});
+  if (bipartite)
+    out_rowptr = torch::cat(
+        {out_rowptr, torch::full({(int64_t)n_ids.size()}, out_col.numel(),
+                                 rowptr.options())});

   idx = torch::cat({idx, torch::from_blob(n_ids.data(), {(int64_t)n_ids.size()},
                                           idx.options())});
...
@@ -9,4 +9,4 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
                    torch::Tensor>
 relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
                     torch::optional<torch::Tensor> optional_value,
-                    torch::Tensor idx);
+                    torch::Tensor idx, bool bipartite);
@@ -24,7 +24,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
                 torch::Tensor>
 relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
                 torch::optional<torch::Tensor> optional_value,
-                torch::Tensor idx) {
+                torch::Tensor idx, bool bipartite) {
   if (rowptr.device().is_cuda()) {
 #ifdef WITH_CUDA
     AT_ERROR("No CUDA version supported");
@@ -32,7 +32,7 @@ relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
     AT_ERROR("Not compiled with CUDA support");
 #endif
   } else {
-    return relabel_one_hop_cpu(rowptr, col, optional_value, idx);
+    return relabel_one_hop_cpu(rowptr, col, optional_value, idx, bipartite);
   }
 }
...
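The sketch below is a minimal, hedged usage example of the updated signature. It is not part of the commit: the toy CSR graph, the seed indices, and the standalone declaration are made up for illustration, and linking against the library's compiled CPU sources is assumed.

// Minimal usage sketch (assumes this translation unit is linked against the
// library's CPU sources; the toy graph and seed indices are illustrative only).
#include <torch/torch.h>
#include <tuple>

// Declaration matching the header hunk in this commit.
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
           torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
                    torch::optional<torch::Tensor> optional_value,
                    torch::Tensor idx, bool bipartite);

int main() {
  // CSR adjacency of a 4-node graph: 0->{1,2}, 1->{2}, 2->{3}, 3->{0}.
  auto rowptr = torch::tensor({0, 2, 3, 4, 5}, torch::kLong);
  auto col = torch::tensor({1, 2, 2, 3, 0}, torch::kLong);
  auto idx = torch::tensor({0, 1}, torch::kLong); // seed nodes to relabel around

  // Per the diff above: with bipartite = true the output rowptr is padded with
  // one empty row per newly collected neighbor id; with bipartite = false that
  // padding is skipped and only the rows of `idx` are kept.
  torch::Tensor out_rowptr, out_col, out_idx;
  torch::optional<torch::Tensor> out_value;
  std::tie(out_rowptr, out_col, out_value, out_idx) =
      relabel_one_hop_cpu(rowptr, col, torch::nullopt, idx, /*bipartite=*/true);

  std::tie(out_rowptr, out_col, out_value, out_idx) =
      relabel_one_hop_cpu(rowptr, col, torch::nullopt, idx, /*bipartite=*/false);
}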