Commit 50909651 authored by rusty1s's avatar rusty1s
Browse files

replaced scatter add

parent e1180216
...@@ -8,7 +8,9 @@ def batch_slices(batch, sizes=False, include_ends=True): ...@@ -8,7 +8,9 @@ def batch_slices(batch, sizes=False, include_ends=True):
""" """
Calculates size, start and end indices for each element in a batch. Calculates size, start and end indices for each element in a batch.
""" """
size = torch.scatter_add_(torch.ones_like(batch), batch) batch_size = batch.max().item() + 1
size = batch.new_zeros(batch_size).scatter_add_(0, batch,
torch.ones_like(batch))
cumsum = torch.cumsum(size, dim=0) cumsum = torch.cumsum(size, dim=0)
starts = cumsum - size starts = cumsum - size
ends = cumsum - 1 ends = cumsum - 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment