Commit c20e6b58 authored by rusty1s

rename

parent 235d586b
@@ -2,7 +2,7 @@ from __future__ import division
 import torch
-from .utils.ffi import get_dynamic_func
+from .utils.ffi import get_typed_func
 from .utils.consecutive import consecutive
@@ -70,7 +70,7 @@ def _grid_cluster(position, size, cluster_size):
     cluster = cluster_size.new(torch.Size(list(position.size())[:-1]))
     cluster = cluster.unsqueeze(dim=-1)
-    func = get_dynamic_func('grid', position)
+    func = get_typed_func('grid', position)
     func(C, cluster, position, size, cluster_size)
     cluster = cluster.squeeze(dim=-1)
...
@@ -6,7 +6,7 @@ def get_func(name, tensor):
     return getattr(ffi, 'cluster_{}{}'.format(name, cuda))
-def get_dynamic_func(name, tensor):
+def get_typed_func(name, tensor):
     typename = type(tensor).__name__.replace('Tensor', '')
     cuda = 'cuda_' if tensor.is_cuda else ''
     return getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
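
For reference, the renamed helper resolves a per-type FFI symbol from the tensor's class name and its device flag. Below is a minimal, self-contained sketch of that name mangling; the stand-in tensor classes and the typed_name helper are illustrative only (the real code looks the symbol up on the package's generated ffi module):

# Hypothetical stand-ins for torch.FloatTensor / torch.cuda.FloatTensor, used only
# to show how the symbol name is built.
CPUFloatTensor = type('FloatTensor', (object,), {'is_cuda': False})
CUDAFloatTensor = type('FloatTensor', (object,), {'is_cuda': True})

def typed_name(name, tensor):
    # 'FloatTensor' -> 'Float'; CUDA tensors add a 'cuda_' infix, mirroring get_typed_func.
    typename = type(tensor).__name__.replace('Tensor', '')
    cuda = 'cuda_' if tensor.is_cuda else ''
    return 'cluster_{}_{}{}'.format(name, cuda, typename)

print(typed_name('grid', CPUFloatTensor()))   # cluster_grid_Float
print(typed_name('grid', CUDAFloatTensor()))  # cluster_grid_cuda_Float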