mul.py 441 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
from .scatter import scatter
from .utils import gen_output


def scatter_mul_(output, index, input, dim=0):
    """Run the ``'mul'`` scatter operation on a caller-supplied ``output``.

    This is a thin wrapper: it forwards its arguments to :func:`scatter`
    with the operation name fixed to ``'mul'``. Whenever several indices
    reference the same location, their contributions are multiplied
    together.

    Args:
        output: Destination passed through to :func:`scatter`.
        index: Indices selecting where each entry of ``input`` goes.
        input: Source values to scatter.
        dim: Dimension handed to :func:`scatter` (defaults to ``0``).

    Returns:
        Whatever :func:`scatter` returns for the ``'mul'`` operation.
    """
    # NOTE(review): the trailing underscore suggests ``output`` is updated
    # in place -- confirm against the implementation of ``scatter``.
    result = scatter('mul', dim, output, index, input)
    return result


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Multiply-scatter ``input`` into a freshly generated output.

    The output is created by :func:`gen_output` from ``index``, ``input``,
    ``dim``, ``max_index`` and ``fill_value``; it is then handed to
    :func:`scatter_mul_` together with ``index``, ``input`` and ``dim``.

    Args:
        index: Indices selecting where each entry of ``input`` goes.
        input: Source values to scatter.
        dim: Dimension forwarded to both helpers (defaults to ``0``).
        max_index: Forwarded to :func:`gen_output` (defaults to ``None``).
        fill_value: Forwarded to :func:`gen_output`; defaults to ``1``.

    Returns:
        The result of :func:`scatter_mul_` applied to the generated output.
    """
    # NOTE(review): ``fill_value=1`` is presumably chosen as the
    # multiplicative identity so untouched slots stay neutral -- confirm
    # against ``gen_output``.
    return scatter_mul_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )