Commit da65da5b authored by rusty1s's avatar rusty1s
Browse files

changed arg order

parent 2a8339db
...@@ -27,6 +27,9 @@ def test_permute_cpu(): ...@@ -27,6 +27,9 @@ def test_permute_cpu():
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_sort_gpu(): # pragma: no cover def test_sort_gpu(): # pragma: no cover
# Note that `sort` is not stable on the GPU, so it does not preserve the
# relative ordering of equivalent row elements. Thus, the expected column
# vector differs from the CPU version (which is stable).
row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3]) row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2]) col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
row, col = sort(row, col) row, col = sort(row, col)
...@@ -38,6 +41,7 @@ def test_sort_gpu(): # pragma: no cover ...@@ -38,6 +41,7 @@ def test_sort_gpu(): # pragma: no cover
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_permute_gpu(): # pragma: no cover def test_permute_gpu(): # pragma: no cover
# Equivalent to `sort`, `permute` is not stable on the GPU (see above).
row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3]) row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2]) col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
node_rid = torch.cuda.LongTensor([2, 1, 3, 0]) node_rid = torch.cuda.LongTensor([2, 1, 3, 0])
......
...@@ -2,7 +2,7 @@ from __future__ import division ...@@ -2,7 +2,7 @@ from __future__ import division
import torch import torch
from .utils.ffi import get_typed_func from .utils.ffi import _get_typed_func
from .utils.consecutive import consecutive from .utils.consecutive import consecutive
...@@ -70,7 +70,7 @@ def _grid_cluster(position, size, cluster_size): ...@@ -70,7 +70,7 @@ def _grid_cluster(position, size, cluster_size):
cluster = cluster_size.new(torch.Size(list(position.size())[:-1])) cluster = cluster_size.new(torch.Size(list(position.size())[:-1]))
cluster = cluster.unsqueeze(dim=-1) cluster = cluster.unsqueeze(dim=-1)
func = get_typed_func('grid', position) func = _get_typed_func('grid', position)
func(C, cluster, position, size, cluster_size) func(C, cluster, position, size, cluster_size)
cluster = cluster.squeeze(dim=-1) cluster = cluster.squeeze(dim=-1)
......
from .utils.permute import permute from .utils.permute import permute
from .utils.degree import node_degree from .utils.degree import node_degree
from .utils.ffi import get_func from .utils.ffi import _get_func
from .utils.consecutive import consecutive from .utils.consecutive import consecutive
...@@ -11,7 +11,7 @@ def serial_cluster(edge_index, batch=None, num_nodes=None): ...@@ -11,7 +11,7 @@ def serial_cluster(edge_index, batch=None, num_nodes=None):
degree = node_degree(row, num_nodes, out=row.new()) degree = node_degree(row, num_nodes, out=row.new())
cluster = edge_index.new(num_nodes).fill_(-1) cluster = edge_index.new(num_nodes).fill_(-1)
func = get_func('random', cluster) func = _get_func('random', cluster)
func(cluster, row, col, degree) func(cluster, row, col, degree)
cluster, u = consecutive(cluster) cluster, u = consecutive(cluster)
......
from ..._ext import ffi from ..._ext import ffi
def get_func(name, tensor): def _get_func(name, tensor):
cuda = '_cuda' if tensor.is_cuda else '' cuda = '_cuda' if tensor.is_cuda else ''
return getattr(ffi, 'cluster_{}{}'.format(name, cuda)) return getattr(ffi, 'cluster_{}{}'.format(name, cuda))
def get_typed_func(name, tensor): def _get_typed_func(name, tensor):
typename = type(tensor).__name__.replace('Tensor', '') typename = type(tensor).__name__.replace('Tensor', '')
cuda = 'cuda_' if tensor.is_cuda else '' cuda = 'cuda_' if tensor.is_cuda else ''
return getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename)) return getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
def ffi_serial(output, row, col, degree, weight=None):
    """Dispatch the serial clustering FFI kernel and return ``output``.

    When ``weight`` is ``None`` the untyped kernel (selected by ``row``'s
    device) is invoked without a weight argument; otherwise the kernel is
    looked up by ``weight``'s tensor type and receives ``weight`` last,
    matching the updated C signature order (degree before weight).
    """
    if weight is not None:
        # Typed kernel: resolved from the weight tensor's type name.
        kernel = _get_typed_func('serial', weight)
        kernel(output, row, col, degree, weight)
    else:
        # Unweighted kernel: only device (CPU/CUDA) selects the symbol.
        kernel = _get_func('serial', row)
        kernel(output, row, col, degree)
    return output
def ffi_grid(C, output, position, size, count):
    """Run the grid clustering FFI kernel and return ``output``.

    The kernel symbol is chosen from ``position``'s tensor type (and
    device); ``output`` is filled in place by the C call.
    """
    kernel = _get_typed_func('grid', position)
    kernel(C, output, position, size, count)
    return output
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define THC_GENERIC_FILE "generic/serial.cu" #define THC_GENERIC_FILE "generic/serial.cu"
#else #else
void cluster_(serial)(THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCTensor *weight, THCudaLongTensor *degree) { void cluster_(serial)(THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCTensor *weight) {
} }
#endif #endif
......
...@@ -4,13 +4,13 @@ extern "C" { ...@@ -4,13 +4,13 @@ extern "C" {
void cluster_serial_kernel(THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree); void cluster_serial_kernel(THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree);
void cluster_serial_kernel_Float (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaTensor *weight, THCudaLongTensor *degree); void cluster_serial_kernel_Float (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaTensor *weight);
void cluster_serial_kernel_Double(THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaDoubleTensor *weight, THCudaLongTensor *degree); void cluster_serial_kernel_Double(THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaDoubleTensor *weight);
void cluster_serial_kernel_Byte (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaByteTensor *weight, THCudaLongTensor *degree); void cluster_serial_kernel_Byte (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaByteTensor *weight);
void cluster_serial_kernel_Char (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaCharTensor *weight, THCudaLongTensor *degree); void cluster_serial_kernel_Char (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaCharTensor *weight);
void cluster_serial_kernel_Short (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaShortTensor *weight, THCudaLongTensor *degree); void cluster_serial_kernel_Short (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaShortTensor *weight);
void cluster_serial_kernel_Int (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaIntTensor *weight, THCudaLongTensor *degree); void cluster_serial_kernel_Int (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaIntTensor *weight);
void cluster_serial_kernel_Long (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *weight, THCudaLongTensor *degree); void cluster_serial_kernel_Long (THCState *state, THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaLongTensor *weight);
#ifdef __cplusplus #ifdef __cplusplus
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define TH_GENERIC_FILE "generic/serial_cpu.c" #define TH_GENERIC_FILE "generic/serial_cpu.c"
#else #else
void cluster_(serial)(THLongTensor *output, THLongTensor *row, THLongTensor *col, THTensor *weight, THLongTensor *degree) { void cluster_(serial)(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree, THTensor *weight) {
real *weight_data = weight->storage->data + weight->storageOffset; real *weight_data = weight->storage->data + weight->storageOffset;
real max_weight, w; real max_weight, w;
int64_t d, c; int64_t d, c;
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define THC_GENERIC_FILE "generic/serial_cuda.c" #define THC_GENERIC_FILE "generic/serial_cuda.c"
#else #else
void cluster_(serial)(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCTensor *weight, THCudaLongTensor *degree) { void cluster_(serial)(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCTensor *weight) {
cluster_kernel_(serial)(state, output, row, col, weight, degree); cluster_kernel_(serial)(state, output, row, col, degree, weight);
} }
#endif #endif
......
void cluster_serial(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree); void cluster_serial(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree);
void cluster_serial_Float (THLongTensor *output, THLongTensor *row, THLongTensor *col, THFloatTensor *weight, THLongTensor *degree); void cluster_serial_Float (THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree, THFloatTensor *weight);
void cluster_serial_Double(THLongTensor *output, THLongTensor *row, THLongTensor *col, THDoubleTensor *weight, THLongTensor *degree); void cluster_serial_Double(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree, THDoubleTensor *weight);
void cluster_serial_Byte (THLongTensor *output, THLongTensor *row, THLongTensor *col, THByteTensor *weight, THLongTensor *degree); void cluster_serial_Byte (THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree, THByteTensor *weight);
void cluster_serial_Char (THLongTensor *output, THLongTensor *row, THLongTensor *col, THCharTensor *weight, THLongTensor *degree); void cluster_serial_Char (THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree, THCharTensor *weight);
void cluster_serial_Short (THLongTensor *output, THLongTensor *row, THLongTensor *col, THShortTensor *weight, THLongTensor *degree); void cluster_serial_Short (THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree, THShortTensor *weight);
void cluster_serial_Int (THLongTensor *output, THLongTensor *row, THLongTensor *col, THIntTensor *weight, THLongTensor *degree); void cluster_serial_Int (THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree, THIntTensor *weight);
void cluster_serial_Long (THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *weight, THLongTensor *degree); void cluster_serial_Long (THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree, THLongTensor *weight);
void cluster_serial_cuda(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree); void cluster_serial_cuda(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree);
void cluster_serial_cuda_Float (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaTensor *weight, THCudaLongTensor *degree); void cluster_serial_cuda_Float (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaTensor *weight);
void cluster_serial_cuda_Double(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaDoubleTensor *weight, THCudaLongTensor *degree); void cluster_serial_cuda_Double(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaDoubleTensor *weight);
void cluster_serial_cuda_Byte (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaByteTensor *weight, THCudaLongTensor *degree); void cluster_serial_cuda_Byte (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaByteTensor *weight);
void cluster_serial_cuda_Char (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaCharTensor *weight, THCudaLongTensor *degree); void cluster_serial_cuda_Char (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaCharTensor *weight);
void cluster_serial_cuda_Short (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaShortTensor *weight, THCudaLongTensor *degree); void cluster_serial_cuda_Short (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaShortTensor *weight);
void cluster_serial_cuda_Int (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaIntTensor *weight, THCudaLongTensor *degree); void cluster_serial_cuda_Int (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaIntTensor *weight);
void cluster_serial_cuda_Long (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *weight, THCudaLongTensor *degree); void cluster_serial_cuda_Long (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree, THCudaLongTensor *weight);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment