Commit 32a60c2d authored by rusty1s's avatar rusty1s
Browse files

doc

parent a4214f3f
......@@ -47,6 +47,8 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
def scatter_mean_(output, index, input, dim=0):
"""If multiple indices reference the same location, their
contributions average."""
output_count = gen_filled_tensor(output, output.size(), fill_value=0)
scatter('mean', dim, output, index, input, output_count)
output_count[output_count == 0] = 1
......@@ -60,6 +62,8 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
def scatter_max_(output, index, input, dim=0):
"""If multiple indices reference the same location, the maximal
contribution gets taken."""
output_arg = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('max', dim, output, index, input, output_arg)
......@@ -70,6 +74,8 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
def scatter_min_(output, index, input, dim=0):
"""If multiple indices reference the same location, the minimal
contribution gets taken."""
output_arg = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('min', dim, output, index, input, output_arg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment