Commit 3f610bcf authored by rusty1s

own file for every function

parent 0b0d099e
__init__.py

New contents (every function is re-exported from its own module):

from .add import scatter_add_, scatter_add
from .sub import scatter_sub_, scatter_sub
from .mul import scatter_mul_, scatter_mul
from .div import scatter_div_, scatter_div
from .mean import scatter_mean_, scatter_mean
from .max import scatter_max_, scatter_max
from .min import scatter_min_, scatter_min

Removed (the helper imports and the inline implementations that now live in the per-function modules):

from .scatter import scatter
from .utils import gen_filled_tensor, gen_output


def scatter_sub_(output, index, input, dim=0):
    """If multiple indices reference the same location, their negated
    contributions add."""
    return output.scatter_add_(dim, index, -input)

def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_sub_(output, index, input, dim)


def scatter_mul_(output, index, input, dim=0):
    """If multiple indices reference the same location, their
    contributions multiply."""
    return scatter('mul', dim, output, index, input)


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mul_(output, index, input, dim)


def scatter_div_(output, index, input, dim=0):
    """If multiple indices reference the same location, their
    contributions divide."""
    return scatter('div', dim, output, index, input)


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """If multiple indices reference the same location, their
    contributions average."""
    num_output = gen_filled_tensor(output, output.size(), fill_value=0)
    scatter('mean', dim, output, index, input, num_output)
    num_output[num_output == 0] = 1
    output /= num_output
    return output


def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mean_(output, index, input, dim)


def scatter_max_(output, index, input, dim=0):
    """If multiple indices reference the same location, the maximal
    contribution gets taken.

    :rtype: (:class:`Tensor`, :class:`LongTensor`)
    """
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('max', dim, output, index, input, arg_output)


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_max_(output, index, input, dim)


def scatter_min_(output, index, input, dim=0):
    """If multiple indices reference the same location, the minimal
    contribution gets taken."""
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('min', dim, output, index, input, arg_output)


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_min_(output, index, input, dim)

__all__ = [
    'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
......
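
For orientation, a minimal usage sketch of the re-exported API. The import path torch_scatter and the output length (taken here as index.max() + 1 when max_index is None) are assumptions, not something this diff shows.

import torch
from torch_scatter import scatter_add  # assumed import path

index = torch.LongTensor([0, 0, 1, 2])
input = torch.FloatTensor([1, 2, 3, 4])

# Contributions that hit the same output position are summed.
out = scatter_add(index, input)
# Expected: tensor([3., 3., 4.])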

div.py (new file)

from .scatter import scatter
from .utils import gen_output


def scatter_div_(output, index, input, dim=0):
    """If multiple indices reference the same location, their
    contributions divide."""
    return scatter('div', dim, output, index, input)


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_div_(output, index, input, dim)
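
A hedged usage sketch for scatter_div; it relies on the return value restored above, and the import path torch_scatter is an assumption.

import torch
from torch_scatter import scatter_div  # assumed import path

index = torch.LongTensor([0, 0, 1])
input = torch.FloatTensor([2, 4, 5])

out = scatter_div(index, input)
# The default fill_value=1 is the starting point of every position, so
# expected: position 0 -> 1 / (2 * 4) = 0.125, position 1 -> 1 / 5 = 0.2.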

max.py (new file)

from .scatter import scatter
from .utils import gen_filled_tensor, gen_output


def scatter_max_(output, index, input, dim=0):
    """If multiple indices reference the same location, the maximal
    contribution gets taken.

    :rtype: (:class:`Tensor`, :class:`LongTensor`)
    """
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('max', dim, output, index, input, arg_output)


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_max_(output, index, input, dim)
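
A hedged usage sketch for scatter_max, showing the (values, argument) pair promised by the :rtype: line; the import path torch_scatter is an assumption.

import torch
from torch_scatter import scatter_max  # assumed import path

index = torch.LongTensor([0, 0, 1])
input = torch.FloatTensor([2, 5, 3])

out, arg = scatter_max(index, input)
# Expected: out == tensor([5., 3.]).  arg starts out filled with -1 (see
# gen_filled_tensor above) and records which input entry supplied each maximum.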

mean.py (new file)

from .scatter import scatter
from .utils import gen_filled_tensor, gen_output


def scatter_mean_(output, index, input, dim=0):
    """If multiple indices reference the same location, their
    contributions average."""
    num_output = gen_filled_tensor(output, output.size(), fill_value=0)
    scatter('mean', dim, output, index, input, num_output)
    num_output[num_output == 0] = 1
    output /= num_output
    return output


def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mean_(output, index, input, dim)
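
A hedged usage sketch for scatter_mean; the import path torch_scatter and the output length of index.max() + 1 are assumptions.

import torch
from torch_scatter import scatter_mean  # assumed import path

index = torch.LongTensor([0, 0, 1, 1, 2])
input = torch.FloatTensor([1, 3, 2, 4, 5])

out = scatter_mean(index, input)
# Expected: tensor([2., 3., 5.]).  The num_output == 0 clamp above exists so
# that positions receiving no contribution are divided by 1 and keep fill_value.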

min.py (new file)

from .scatter import scatter
from .utils import gen_filled_tensor, gen_output


def scatter_min_(output, index, input, dim=0):
    """If multiple indices reference the same location, the minimal
    contribution gets taken."""
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('min', dim, output, index, input, arg_output)


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_min_(output, index, input, dim)
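
A hedged usage sketch for scatter_min. It passes an explicit fill_value on the assumption that the initial output values take part in the minimum, in which case the default of 0 would always win against positive inputs; the import path torch_scatter is likewise an assumption.

import torch
from torch_scatter import scatter_min  # assumed import path

index = torch.LongTensor([0, 0, 1])
input = torch.FloatTensor([2, 5, 3])

out, arg = scatter_min(index, input, fill_value=10)
# Expected: out == tensor([2., 3.]); arg again maps each result back to the
# input entry that produced it, with -1 where nothing was scattered.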

mul.py (new file)

from .scatter import scatter
from .utils import gen_output


def scatter_mul_(output, index, input, dim=0):
    """If multiple indices reference the same location, their
    contributions multiply."""
    return scatter('mul', dim, output, index, input)


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mul_(output, index, input, dim)
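
A hedged usage sketch for scatter_mul; the import path torch_scatter is an assumption.

import torch
from torch_scatter import scatter_mul  # assumed import path

index = torch.LongTensor([0, 0, 1])
input = torch.FloatTensor([2, 3, 4])

out = scatter_mul(index, input)
# Expected: tensor([6., 4.]).  The default fill_value=1 is the multiplicative
# identity, so positions that receive nothing stay at 1.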

sub.py (new file)

from .utils import gen_output


def scatter_sub_(output, index, input, dim=0):
    """If multiple indices reference the same location, their negated
    contributions add."""
    return output.scatter_add_(dim, index, -input)


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_sub_(output, index, input, dim)
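
A hedged usage sketch for scatter_sub; the import path torch_scatter is an assumption.

import torch
from torch_scatter import scatter_sub  # assumed import path

index = torch.LongTensor([0, 0, 1])
input = torch.FloatTensor([1, 2, 3])

out = scatter_sub(index, input)
# Expected: tensor([-3., -3.]).  scatter_sub_ simply scatter-adds the negated
# input onto the zero-filled output, so each position ends up as minus the sum
# of its contributions.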