# mean.py — scatter-mean operations (recovered from a blame-view capture;
# committed by rusty1s).
from .scatter import scatter
from .utils import gen_filled_tensor, gen_output


def scatter_mean_(output, index, input, dim=0):
    """In-place scatter-mean: if multiple indices reference the same
    location, their **contributions average**.

    Args:
        output: Tensor that receives the result. It is passed to the
            ``'mean'`` scatter kernel, then divided in place by the
            per-location contribution counts, and returned.
        index: Index tensor forwarded unchanged to the ``'mean'`` kernel.
        input: Values forwarded unchanged to the ``'mean'`` kernel.
        dim: Dimension forwarded to the ``'mean'`` kernel. Defaults to 0.

    Returns:
        ``output`` (the same object that was passed in).
    """
    # Counter tensor with the same size as ``output``, starting at zero.
    num_output = gen_filled_tensor(output, output.size(), fill_value=0)
    # NOTE(review): presumably the 'mean' kernel accumulates the scattered
    # values into ``output`` and the number of contributions per location
    # into ``num_output`` -- confirm against ``.scatter``.
    scatter('mean', dim, output, index, input, num_output)
    # Locations that received no contribution have a count of 0. Replacing
    # that with 1 avoids dividing by zero and leaves their value as-is.
    num_output[num_output == 0] = 1
    # NOTE(review): for integer ``output`` tensors, recent PyTorch versions
    # reject in-place true division -- confirm the supported dtypes.
    output /= num_output
    return output


def scatter_mean(index, input, dim=0, size=None, fill_value=0):
    """Out-of-place scatter-mean: allocate a fresh output and average
    ``input`` into it with :func:`scatter_mean_`.

    Args:
        index: Index tensor, forwarded to ``gen_output`` and
            ``scatter_mean_``.
        input: Values to scatter, forwarded to ``gen_output`` and
            ``scatter_mean_``.
        dim: Dimension forwarded to ``gen_output`` and ``scatter_mean_``.
            Defaults to 0.
        size: Forwarded to ``gen_output``. Presumably the output extent
            along ``dim``, with ``None`` letting ``gen_output`` infer it.
            TODO confirm.
        fill_value: Forwarded to ``gen_output``. Presumably the initial
            value of the newly allocated output. Defaults to 0.

    Returns:
        The tensor created by ``gen_output``, after ``scatter_mean_`` has
        filled it.
    """
    output = gen_output(index, input, dim, size, fill_value)
    # NOTE(review): locations with no contributions are divided by 1 in
    # ``scatter_mean_`` and so keep their initial value. Whether a non-zero
    # ``fill_value`` also biases averaged locations depends on how the
    # 'mean' kernel accumulates into ``output`` -- confirm.
    return scatter_mean_(output, index, input, dim)