Commit 43fde05a authored by rusty1s's avatar rusty1s
Browse files

added assert checks

parent 61084dfe
......@@ -4,23 +4,25 @@ from .utils import get_func
def grid_cluster(position, size, batch=None):
# TODO: Check types and sizes
if batch is not None:
batch = batch.type_as(position)
size = torch.cat([size.new(1).fill_(1), size], dim=0)
dim = position.dim()
position = torch.cat([batch.unsqueeze(dim - 1), position], dim=dim - 1)
# Allow one-dimensional positions.
if position.dim() == 1:
position = position.unsqueeze(-1)
dim = position.dim()
assert size.dim() == 1, 'Size tensor must be one-dimensional'
assert position.size(-1) == size.size(-1), (
'Last dimension of position tensor must have same size as size tensor')
# Allow one-dimensional positions.
if dim == 1:
position = position.unsqueeze(1)
dim += 1
# If given, append batch to position tensor.
if batch is not None:
batch = batch.unsqueeze(-1).type_as(position)
assert position.size()[:-1] == batch.size()[:-1], (
'Position tensor must have same size as batch tensor apart from '
'the last dimension')
position = torch.cat([batch, position], dim=-1)
size = torch.cat([size.new(1).fill_(1), size], dim=-1)
# Translate to minimal positive positions.
min = position.min(dim=dim - 2, keepdim=True)[0]
min = position.min(dim=-2, keepdim=True)[0]
position = position - min
# Compute cluster count for each dimension.
......@@ -37,8 +39,9 @@ def grid_cluster(position, size, batch=None):
cluster = c_max.new(torch.Size(s))
# Fill cluster tensor and reshape.
size = size.type_as(position)
func = get_func('grid', position)
func(C, cluster, position, size, c_max)
cluster = cluster.squeeze(dim=dim - 1)
cluster = cluster.squeeze(dim=-1)
return cluster
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment