Commit 645f3ddf authored by rusty1s

clean up

parent 59121c84
test_random.py → test_serial.py
-import torch
-from torch_cluster import random_cluster
+# import torch
+# from torch_cluster import serial_cluster


-def test_random():
-    edge_index = torch.LongTensor([[0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6],
-                                   [2, 3, 6, 5, 0, 0, 4, 5, 3, 1, 3, 6, 0, 3]])
+def test_serial():
+    pass
+    # edge_index = torch.LongTensor([[0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6],
+    #                                [2, 3, 6, 5, 0, 0, 4, 5, 3, 1, 3, 6, 0, 3]])
     # edge_attr = torch.Tensor([2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2])
-    rid = torch.arange(edge_index.max() + 1, out=edge_index.new())
+    # rid = torch.arange(edge_index.max() + 1, out=edge_index.new())
     # output = random_cluster(edge_index, rid=rid, perm_edges=False)
     # expected_output = [0, 1, 2, 0, 3, 1, 4]
...
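The commented-out expectation still documents the intended result: reading expected_output = [0, 1, 2, 0, 3, 1, 4] position by position, nodes 0 and 3 share cluster 0, nodes 1 and 5 share cluster 1, and nodes 2, 4 and 6 end up as singleton clusters 2, 3 and 4. A throwaway snippet (illustration only, not part of the commit) that recovers that grouping:

    from collections import defaultdict

    expected_output = [0, 1, 2, 0, 3, 1, 4]  # one cluster id per node 0..6
    groups = defaultdict(list)
    for node, cid in enumerate(expected_output):
        groups[cid].append(node)
    print(dict(groups))  # {0: [0, 3], 1: [1, 5], 2: [2], 3: [4], 4: [6]}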
torch_cluster/__init__.py
 from .functions.grid import sparse_grid_cluster, dense_grid_cluster
-from .functions.random import random_cluster
+from .functions.serial import serial_cluster

 __version__ = '0.2.6'

 __all__ = [
-    'sparse_grid_cluster', 'dense_grid_cluster', 'random_cluster',
+    'sparse_grid_cluster', 'dense_grid_cluster', 'serial_cluster',
     '__version__'
 ]
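With the export switched over, callers import the new name from the package root. A minimal usage sketch, assuming the compiled extension is available and reusing the graph from the test file above; the exact shape of the return value is an assumption, since serial_cluster's body is truncated further down:

    import torch
    from torch_cluster import serial_cluster

    edge_index = torch.LongTensor([[0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6],
                                   [2, 3, 6, 5, 0, 0, 4, 5, 3, 1, 3, 6, 0, 3]])
    cluster = serial_cluster(edge_index)  # presumably one cluster id per node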
torch_cluster/functions/grid.py
@@ -2,7 +2,8 @@ from __future__ import division
 import torch

-from .utils import get_dynamic_func, consecutive
+from .utils.ffi import get_dynamic_func
+from .utils.consecutive import consecutive


 def _preprocess(position, size, batch=None, start=None):
...
torch_cluster/functions/random.py → torch_cluster/functions/serial.py
-from .utils import get_func, consecutive
-from .degree import node_degree
-from .permute import permute
+from .utils.permute import random_permute
+from .utils.degree import node_degree
+from .utils.ffi import get_func
+from .utils.consecutive import consecutive


-def random_cluster(edge_index, batch=None, num_nodes=None):
+def serial_cluster(edge_index, batch=None, num_nodes=None):
     num_nodes = edge_index.max() + 1 if num_nodes is None else num_nodes
-    row, col = permute(edge_index, num_nodes)
+    row, col = random_permute(edge_index, num_nodes)
     degree = node_degree(row, num_nodes, out=row.new())
     cluster = edge_index.new(num_nodes).fill_(-1)
...
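Up to the truncation, the body assembles everything a greedy edge contraction needs: a randomly permuted edge list, per-node degrees, and a cluster tensor filled with -1 for "not yet assigned". The real work happens in the C extension behind get_func; the following pure-Python sketch only illustrates what a serial (node-by-node) clustering pass of this shape could look like, and is not the extension's actual algorithm:

    def serial_cluster_sketch(edge_index, num_nodes):
        # Build adjacency lists from the (row, col) edge list.
        row, col = edge_index
        adj = [[] for _ in range(num_nodes)]
        for u, v in zip(row.tolist(), col.tolist()):
            adj[u].append(v)

        cluster = [-1] * num_nodes  # -1 marks unassigned nodes, as in the diff
        next_id = 0
        for u in range(num_nodes):  # visit nodes one after another
            if cluster[u] != -1:
                continue
            cluster[u] = next_id
            for v in adj[u]:  # merge u with its first unassigned neighbor
                if cluster[v] == -1:
                    cluster[v] = next_id
                    break
            next_id += 1
        return cluster

Assigning ids in visit order keeps them consecutive, which is presumably what the imported consecutive helper guarantees for the real implementation.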
torch_cluster/functions/utils.py (deleted; its helpers move into a utils/ package)
-import torch
-from torch_unique import unique
-
-from .._ext import ffi
-
-
-def get_func(name, tensor):
-    cuda = '_cuda' if tensor.is_cuda else ''
-    return getattr(ffi, 'cluster_{}{}'.format(name, cuda))
-
-
-def get_dynamic_func(name, tensor):
-    typename = type(tensor).__name__.replace('Tensor', '')
-    cuda = 'cuda_' if tensor.is_cuda else ''
-    return getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
-
-
-def get_type(max, cuda):
-    if max <= 255:
...
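The truncated get_type apparently picks the smallest integer tensor type that can hold values up to max (255 being the unsigned 8-bit bound). A hedged sketch of such a helper; everything past the first branch, and the returned type names, are assumptions:

    def get_type_sketch(max, cuda):
        # Assumed intent: smallest dtype able to represent values up to `max`.
        if max <= 255:
            return 'cuda.ByteTensor' if cuda else 'ByteTensor'
        if max <= 32767:
            return 'cuda.ShortTensor' if cuda else 'ShortTensor'
        return 'cuda.LongTensor' if cuda else 'LongTensor'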
torch_cluster/functions/utils/ffi.py (new file)
+from ..._ext import ffi
+
+
+def get_func(name, tensor):
+    cuda = '_cuda' if tensor.is_cuda else ''
+    return getattr(ffi, 'cluster_{}{}'.format(name, cuda))
+
+
+def get_dynamic_func(name, tensor):
+    typename = type(tensor).__name__.replace('Tensor', '')
+    cuda = 'cuda_' if tensor.is_cuda else ''
+    return getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
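Both helpers just assemble a symbol name and fetch it from the compiled extension: get_func keys only on whether the tensor lives on the GPU, while get_dynamic_func also bakes the scalar type into the name. A small illustration of the strings the format calls produce under the old-style torch tensor classes this code targets (which symbols the extension actually exports is not claimed here):

    import torch

    def symbol_for(name, tensor):
        # Mirrors get_dynamic_func's naming: cluster_<name>_[cuda_]<Type>
        typename = type(tensor).__name__.replace('Tensor', '')
        cuda = 'cuda_' if tensor.is_cuda else ''
        return 'cluster_{}_{}{}'.format(name, cuda, typename)

    print(symbol_for('grid', torch.FloatTensor()))  # cluster_grid_Float
    print(symbol_for('grid', torch.LongTensor()))   # cluster_grid_Long
    # a CUDA tensor would yield e.g. cluster_grid_cuda_Float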