Commit 50909651 authored by rusty1s's avatar rusty1s
Browse files

replaced scatter add

parent e1180216
......@@ -8,7 +8,9 @@ def batch_slices(batch, sizes=False, include_ends=True):
"""
Calculates size, start and end indices for each element in a batch.
"""
size = torch.scatter_add_(torch.ones_like(batch), batch)
batch_size = batch.max().item() + 1
size = batch.new_zeros(batch_size).scatter_add_(0, batch,
torch.ones_like(batch))
cumsum = torch.cumsum(size, dim=0)
starts = cumsum - size
ends = cumsum - 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment