Commit 1c4ef780 authored by rusty1s's avatar rusty1s
Browse files

added max min functions

parent fa305127
...@@ -61,8 +61,31 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0): ...@@ -61,8 +61,31 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
return scatter_mean_(output, index, input, dim) return scatter_mean_(output, index, input, dim)
def scatter_max_(output, index, input, dim=0):
output_index = index.new(output.size()).fill_(-1)
scatter('max', dim, output, index, input, output_index)
return output, output_index
def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value)
return scatter_max_(output, index, input, dim)
def scatter_min_(output, index, input, dim=0):
output_index = index.new(output.size()).fill_(-1)
scatter('min', dim, output, index, input, output_index)
return output, output_index
def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value)
return scatter_min_(output, index, input, dim)
__all__ = [ __all__ = [
'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub', 'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div', 'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div',
'scatter_mean_', 'scatter_mean' 'scatter_mean_', 'scatter_mean', 'scatter_max_', 'scatter_max',
'scatter_min_', 'scatter_min'
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment