"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "0d633a42f47671e11bf0ce28c72d871cdadc1cd1"
Commit 3c259af5 authored by rusty1s's avatar rusty1s
Browse files

metis initial commit

parent e78637ea
#include "metis_cpu.h"
#include <metis.h>
#include "utils.h"
torch::Tensor partition_kway_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
int64_t nvtxs = rowptr.numel() - 1;
auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>();
int64_t ncon = 1;
int64_t objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part_data = part.data_ptr<int64_t>();
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL, &num_parts,
NULL, NULL, NULL, &objval, part_data);
return part;
}
#pragma once
#include <torch/extension.h>
torch::Tensor partition_kway_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts);
#include <Python.h>
#include <torch/script.h>
#include "cpu/metis_cpu.h"
#include <metis.h>
#ifdef _WIN32
PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif
torch::Tensor partition_kway(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return partition_kway_cpu(rowptr, col, num_parts);
}
}
static auto registry = torch::RegisterOperators().op(
"torch_sparse::partition_kway", &partition_kway);
...@@ -59,6 +59,7 @@ def get_extensions(): ...@@ -59,6 +59,7 @@ def get_extensions():
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args, extra_link_args=extra_link_args,
libraries=['metis'],
) )
extensions += [extension] extensions += [extension]
......
...@@ -7,7 +7,9 @@ __version__ = '0.5.1' ...@@ -7,7 +7,9 @@ __version__ = '0.5.1'
expected_torch_version = (1, 4) expected_torch_version = (1, 4)
try: try:
for library in ['_version', '_convert', '_diag', '_spmm', '_spspmm']: for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis'
]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec( torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin) library, [osp.dirname(__file__)]).origin)
except OSError as e: except OSError as e:
...@@ -45,12 +47,14 @@ from .narrow import narrow, __narrow_diag__ # noqa ...@@ -45,12 +47,14 @@ from .narrow import narrow, __narrow_diag__ # noqa
from .select import select # noqa from .select import select # noqa
from .index_select import index_select, index_select_nnz # noqa from .index_select import index_select, index_select_nnz # noqa
from .masked_select import masked_select, masked_select_nnz # noqa from .masked_select import masked_select, masked_select_nnz # noqa
from .permute import permute # noqa
from .diag import remove_diag, set_diag, fill_diag # noqa from .diag import remove_diag, set_diag, fill_diag # noqa
from .add import add, add_, add_nnz, add_nnz_ # noqa from .add import add, add_, add_nnz, add_nnz_ # noqa
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa from .reduce import sum, mean, min, max # noqa
from .matmul import matmul # noqa from .matmul import matmul # noqa
from .cat import cat, cat_diag # noqa from .cat import cat, cat_diag # noqa
from .metis import partition_kway # noqa
from .convert import to_torch_sparse, from_torch_sparse # noqa from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa from .convert import to_scipy, from_scipy # noqa
...@@ -71,6 +75,7 @@ __all__ = [ ...@@ -71,6 +75,7 @@ __all__ = [
'index_select_nnz', 'index_select_nnz',
'masked_select', 'masked_select',
'masked_select_nnz', 'masked_select_nnz',
'permute',
'remove_diag', 'remove_diag',
'set_diag', 'set_diag',
'fill_diag', 'fill_diag',
...@@ -89,6 +94,7 @@ __all__ = [ ...@@ -89,6 +94,7 @@ __all__ = [
'matmul', 'matmul',
'cat', 'cat',
'cat_diag', 'cat_diag',
'partition_kway',
'to_torch_sparse', 'to_torch_sparse',
'from_torch_sparse', 'from_torch_sparse',
'to_scipy', 'to_scipy',
......
from typing import Tuple
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
@torch.jit.script
def partition_kway(
src: SparseTensor,
num_parts: int) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
cluster = torch.ops.torch_sparse.partition_kway(rowptr, col, num_parts)
cluster = cluster.to(src.device())
cluster, perm = cluster.sort()
out = permute(src, perm)
partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
return out, partptr, perm
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def permute(src: SparseTensor, perm: torch.Tensor) -> SparseTensor:
assert src.is_symmetric()
row, col, value = src.coo()
row = perm[row]
col = perm[col]
if value is not None:
value = value[row]
rowcount = src.storage._rowcount
if rowcount is not None:
rowcount = rowcount[perm]
colcount = src.storage._colcount
if colcount is not None:
colcount = colcount[perm]
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
colptr=None, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=False)
return src.from_storage(storage)
SparseTensor.permute = lambda self, perm: permute(self, perm)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment