Commit 48eed837 authored by rusty1s's avatar rusty1s
Browse files

backward compatibility with torch-sparse==0.6.9

parent d42a18a7
...@@ -12,6 +12,21 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; } ...@@ -12,6 +12,21 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
#endif #endif
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return partition_cpu(rowptr, col, optional_value, nullptr, num_parts,
recursive);
}
}
torch::Tensor partition2(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight, torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive) { int64_t num_parts, bool recursive) {
...@@ -46,4 +61,5 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col, ...@@ -46,4 +61,5 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
static auto registry = torch::RegisterOperators() static auto registry = torch::RegisterOperators()
.op("torch_sparse::partition", &partition) .op("torch_sparse::partition", &partition)
.op("torch_sparse::partition2", &partition2)
.op("torch_sparse::mt_partition", &mt_partition); .op("torch_sparse::mt_partition", &mt_partition);
...@@ -11,6 +11,11 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col, ...@@ -11,6 +11,11 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive); int64_t num_parts, bool recursive);
torch::Tensor partition2(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive);
torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col, torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive); int64_t num_parts, bool recursive);
......
...@@ -46,8 +46,11 @@ def partition( ...@@ -46,8 +46,11 @@ def partition(
node_weight = node_weight.view(-1).detach().cpu() node_weight = node_weight.view(-1).detach().cpu()
if node_weight.is_floating_point(): if node_weight.is_floating_point():
node_weight = weight2metis(node_weight) node_weight = weight2metis(node_weight)
cluster = torch.ops.torch_sparse.partition2(rowptr, col, value,
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, node_weight, node_weight, num_parts,
recursive)
else:
cluster = torch.ops.torch_sparse.partition(rowptr, col, value,
num_parts, recursive) num_parts, recursive)
cluster = cluster.to(src.device()) cluster = cluster.to(src.device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment