# sub.py — last committed by rusty1s.
from .utils import gen_output


def scatter_sub_(output, index, input, dim=0):
    """Subtract ``input`` from ``output`` in place at the positions in ``index``.

    Values are routed along dimension ``dim`` exactly as in
    ``Tensor.scatter_add_``, but with their sign flipped. When several
    entries of ``index`` point at the same location, every one of them is
    subtracted, i.e. their negated contributions accumulate.

    Args:
        output: Tensor that is modified in place.
        index: Index tensor selecting target positions along ``dim``.
        input: Source values to subtract.
        dim: Dimension along which to scatter (default ``0``).

    Returns:
        ``output`` itself, as returned by ``scatter_add_``.
    """
    negated_src = input.neg()
    return output.scatter_add_(dim, index, negated_src)


def scatter_sub(index, input, dim=0, size=None, fill_value=0):
    """Out-of-place scatter-subtract: build a fresh output, then subtract into it.

    The output tensor is created by ``gen_output(index, input, dim, size,
    fill_value)``. Its exact shape/dtype rules live in ``.utils``; presumably
    it is filled with ``fill_value`` and ``size`` sets the extent along
    ``dim`` (TODO confirm against ``gen_output``). ``input`` is then
    subtracted into that tensor at the positions given by ``index`` along
    ``dim`` via ``scatter_sub_``, so duplicate indices accumulate their
    negated contributions.

    Args:
        index: Index tensor selecting target positions along ``dim``.
        input: Source values to subtract.
        dim: Dimension along which to scatter (default ``0``).
        size: Forwarded to ``gen_output`` (default ``None``).
        fill_value: Forwarded to ``gen_output`` (default ``0``).

    Returns:
        The newly generated output tensor after the in-place subtraction.
    """
    output = gen_output(index, input, dim, size, fill_value)
    return scatter_sub_(output, index, input, dim)