Commit 61084dfe authored by rusty1s's avatar rusty1s
Browse files

bugfix

parent bbc25706
...@@ -8,8 +8,9 @@ def grid_cluster(position, size, batch=None): ...@@ -8,8 +8,9 @@ def grid_cluster(position, size, batch=None):
if batch is not None: if batch is not None:
batch = batch.type_as(position) batch = batch.type_as(position)
position = torch.cat([position, batch], dim=position.dim() - 1) size = torch.cat([size.new(1).fill_(1), size], dim=0)
size = torch.cat([size, size.new(1).fill_(1)], dim=0) dim = position.dim()
position = torch.cat([batch.unsqueeze(dim - 1), position], dim=dim - 1)
dim = position.dim() dim = position.dim()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment