Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
cf0f8920
Commit
cf0f8920
authored
Dec 19, 2017
by
rusty1s
Browse files
rename
parent
aeb47792
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
106 additions
and
59 deletions
+106
-59
test/test_max.py
test/test_max.py
+5
-5
torch_scatter/functions/__init__.py
torch_scatter/functions/__init__.py
+8
-8
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+6
-6
torch_scatter/src/cpu.h
torch_scatter/src/cpu.h
+28
-28
torch_scatter/src/cuda.c
torch_scatter/src/cuda.c
+0
-0
torch_scatter/src/cuda.h
torch_scatter/src/cuda.h
+47
-0
torch_scatter/src/generic/cpu.c
torch_scatter/src/generic/cpu.c
+12
-12
torch_scatter/src/generic/cuda.c
torch_scatter/src/generic/cuda.c
+0
-0
No files found.
test/test_max.py
View file @
cf0f8920
...
@@ -14,15 +14,15 @@ def test_scatter_mean(str):
...
@@ -14,15 +14,15 @@ def test_scatter_mean(str):
index
=
torch
.
LongTensor
(
index
)
index
=
torch
.
LongTensor
(
index
)
output
=
input
.
new
(
2
,
6
).
fill_
(
0
)
output
=
input
.
new
(
2
,
6
).
fill_
(
0
)
expected_output
=
[[
0
,
0
,
4
,
3
,
2
,
0
],
[
2
,
4
,
3
,
0
,
0
,
0
]]
expected_output
=
[[
0
,
0
,
4
,
3
,
2
,
0
],
[
2
,
4
,
3
,
0
,
0
,
0
]]
expected_output
_arg
=
[[
-
1
,
-
1
,
3
,
4
,
0
,
1
],
[
1
,
4
,
3
,
-
1
,
-
1
,
-
1
]]
expected_
arg_
output
=
[[
-
1
,
-
1
,
3
,
4
,
0
,
1
],
[
1
,
4
,
3
,
-
1
,
-
1
,
-
1
]]
_
,
output
_arg
=
scatter_max_
(
output
,
index
,
input
,
dim
=
1
)
_
,
arg_
output
=
scatter_max_
(
output
,
index
,
input
,
dim
=
1
)
assert
output
.
tolist
()
==
expected_output
assert
output
.
tolist
()
==
expected_output
assert
output
_arg
.
tolist
()
==
expected_output
_arg
assert
arg_
output
.
tolist
()
==
expected_
arg_
output
output
,
output
_arg
=
scatter_max
(
index
,
input
,
dim
=
1
)
output
,
arg_
output
=
scatter_max
(
index
,
input
,
dim
=
1
)
assert
output
.
tolist
()
==
expected_output
assert
output
.
tolist
()
==
expected_output
assert
output
_arg
.
tolist
()
==
expected_output
_arg
assert
arg_
output
.
tolist
()
==
expected_
arg_
output
output
=
Variable
(
output
).
fill_
(
0
)
output
=
Variable
(
output
).
fill_
(
0
)
index
=
Variable
(
index
)
index
=
Variable
(
index
)
...
...
torch_scatter/functions/__init__.py
View file @
cf0f8920
...
@@ -51,10 +51,10 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
...
@@ -51,10 +51,10 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
def
scatter_mean_
(
output
,
index
,
input
,
dim
=
0
):
def
scatter_mean_
(
output
,
index
,
input
,
dim
=
0
):
"""If multiple indices reference the same location, their
"""If multiple indices reference the same location, their
contributions average."""
contributions average."""
output
_count
=
gen_filled_tensor
(
output
,
output
.
size
(),
fill_value
=
0
)
num_
output
=
gen_filled_tensor
(
output
,
output
.
size
(),
fill_value
=
0
)
scatter
(
'mean'
,
dim
,
output
,
index
,
input
,
output
_count
)
scatter
(
'mean'
,
dim
,
output
,
index
,
input
,
num_
output
)
output
_count
[
output_coun
t
==
0
]
=
1
num_
output
[
num_outpu
t
==
0
]
=
1
output
/=
output
_count
output
/=
num_
output
return
output
return
output
...
@@ -66,8 +66,8 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
...
@@ -66,8 +66,8 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
def
scatter_max_
(
output
,
index
,
input
,
dim
=
0
):
def
scatter_max_
(
output
,
index
,
input
,
dim
=
0
):
"""If multiple indices reference the same location, the maximal
"""If multiple indices reference the same location, the maximal
contribution gets taken."""
contribution gets taken."""
output
_arg
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
arg_
output
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
return
scatter
(
'max'
,
dim
,
output
,
index
,
input
,
output
_arg
)
return
scatter
(
'max'
,
dim
,
output
,
index
,
input
,
arg_
output
)
def
scatter_max
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
def
scatter_max
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
...
@@ -78,8 +78,8 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
...
@@ -78,8 +78,8 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
def
scatter_min_
(
output
,
index
,
input
,
dim
=
0
):
def
scatter_min_
(
output
,
index
,
input
,
dim
=
0
):
"""If multiple indices reference the same location, the minimal
"""If multiple indices reference the same location, the minimal
contribution gets taken."""
contribution gets taken."""
output
_arg
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
arg_
output
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
return
scatter
(
'min'
,
dim
,
output
,
index
,
input
,
output
_arg
)
return
scatter
(
'min'
,
dim
,
output
,
index
,
input
,
arg_
output
)
def
scatter_min
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
def
scatter_min
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
...
...
torch_scatter/functions/scatter.py
View file @
cf0f8920
...
@@ -6,7 +6,7 @@ from torch.autograd import Function
...
@@ -6,7 +6,7 @@ from torch.autograd import Function
from
.._ext
import
ffi
from
.._ext
import
ffi
def
has_output
_arg
(
name
):
def
has_
arg_
output
(
name
):
return
name
in
[
'max'
,
'min'
]
return
name
in
[
'max'
,
'min'
]
...
@@ -35,7 +35,7 @@ def _scatter(name, dim, *data):
...
@@ -35,7 +35,7 @@ def _scatter(name, dim, *data):
typename
=
type
(
data
[
0
]).
__name__
.
replace
(
'Tensor'
,
''
)
typename
=
type
(
data
[
0
]).
__name__
.
replace
(
'Tensor'
,
''
)
func
=
getattr
(
ffi
,
'scatter_{}_{}'
.
format
(
name
,
typename
))
func
=
getattr
(
ffi
,
'scatter_{}_{}'
.
format
(
name
,
typename
))
func
(
dim
,
*
data
)
func
(
dim
,
*
data
)
return
(
data
[
0
],
data
[
3
])
if
has_output
_arg
(
name
)
else
data
[
0
]
return
(
data
[
0
],
data
[
3
])
if
has_
arg_
output
(
name
)
else
data
[
0
]
def
index_backward
(
dim
,
index
,
grad
,
grad_arg
):
def
index_backward
(
dim
,
index
,
grad
,
grad_arg
):
...
@@ -62,8 +62,8 @@ class _Scatter(Function):
...
@@ -62,8 +62,8 @@ class _Scatter(Function):
# `scatter_min` and `scatter_max` additionally return the `argmax`
# `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. In addition, we need to save the
# respectively `argmin`. In addition, we need to save the
# `output
_arg
` for the backward pass.
# `
arg_
output` for the backward pass.
if
has_output
_arg
(
self
.
name
):
if
has_
arg_
output
(
self
.
name
):
self
.
save_for_backward
(
data
[
1
],
data
[
3
])
self
.
save_for_backward
(
data
[
1
],
data
[
3
])
return
data
[
0
],
data
[
3
]
return
data
[
0
],
data
[
3
]
else
:
else
:
...
@@ -78,11 +78,11 @@ class _Scatter(Function):
...
@@ -78,11 +78,11 @@ class _Scatter(Function):
# Different grad computation of `input` if `scatter_max` or
# Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used.
# `scatter_min` was used.
if
self
.
needs_input_grad
[
2
]
and
not
has_output
_arg
(
self
.
name
):
if
self
.
needs_input_grad
[
2
]
and
not
has_
arg_
output
(
self
.
name
):
index
,
=
self
.
saved_variables
index
,
=
self
.
saved_variables
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
if
self
.
needs_input_grad
[
2
]
and
has_output
_arg
(
self
.
name
):
if
self
.
needs_input_grad
[
2
]
and
has_
arg_
output
(
self
.
name
):
index
,
grad_arg
=
self
.
saved_variables
index
,
grad_arg
=
self
.
saved_variables
data
=
(
index
.
data
,
data
[
0
],
grad_arg
.
data
)
data
=
(
index
.
data
,
data
[
0
],
grad_arg
.
data
)
grad_input
=
index_backward
(
self
.
dim
,
*
data
)
grad_input
=
index_backward
(
self
.
dim
,
*
data
)
...
...
torch_scatter/src/cpu.h
View file @
cf0f8920
...
@@ -14,34 +14,34 @@ void scatter_div_Short (int dim, THShortTensor *output, THLongTensor *index, TH
...
@@ -14,34 +14,34 @@ void scatter_div_Short (int dim, THShortTensor *output, THLongTensor *index, TH
void
scatter_div_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
);
void
scatter_div_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
);
void
scatter_div_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
);
void
scatter_div_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
);
void
scatter_mean_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
THFloatTensor
*
output
_count
);
void
scatter_mean_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
THFloatTensor
*
num_
output
);
void
scatter_mean_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
THDoubleTensor
*
output
_count
);
void
scatter_mean_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
THDoubleTensor
*
num_
output
);
void
scatter_mean_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
THByteTensor
*
output
_count
);
void
scatter_mean_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
THByteTensor
*
num_
output
);
void
scatter_mean_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
THCharTensor
*
output
_count
);
void
scatter_mean_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
THCharTensor
*
num_
output
);
void
scatter_mean_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
THShortTensor
*
output
_count
);
void
scatter_mean_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
THShortTensor
*
num_
output
);
void
scatter_mean_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
THIntTensor
*
output
_count
);
void
scatter_mean_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
THIntTensor
*
num_
output
);
void
scatter_mean_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
THLongTensor
*
output
_count
);
void
scatter_mean_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
THLongTensor
*
num_
output
);
void
scatter_max_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_max_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_max_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_max_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_max_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_max_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_max_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_max_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_max_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_max_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_max_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_max_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_max_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_max_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_min_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_min_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_min_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_min_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_min_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_min_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_min_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_min_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_min_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_min_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_min_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_min_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
THLongTensor
*
arg_
output
);
void
scatter_min_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
THLongTensor
*
output
_arg
);
void
scatter_min_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
THLongTensor
*
arg_
output
);
void
index_backward_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
grad
,
THLongTensor
*
grad
_arg
);
void
index_backward_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
grad
,
THLongTensor
*
arg_
grad
);
void
index_backward_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
grad
,
THLongTensor
*
grad
_arg
);
void
index_backward_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
grad
,
THLongTensor
*
arg_
grad
);
void
index_backward_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
grad
,
THLongTensor
*
grad
_arg
);
void
index_backward_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
grad
,
THLongTensor
*
arg_
grad
);
void
index_backward_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
grad
,
THLongTensor
*
grad
_arg
);
void
index_backward_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
grad
,
THLongTensor
*
arg_
grad
);
void
index_backward_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
grad
,
THLongTensor
*
grad
_arg
);
void
index_backward_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
grad
,
THLongTensor
*
arg_
grad
);
void
index_backward_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
grad
,
THLongTensor
*
grad
_arg
);
void
index_backward_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
grad
,
THLongTensor
*
arg_
grad
);
void
index_backward_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
grad
,
THLongTensor
*
grad
_arg
);
void
index_backward_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
grad
,
THLongTensor
*
arg_
grad
);
torch_scatter/src/cuda.c
0 → 100644
View file @
cf0f8920
torch_scatter/src/cuda.h
0 → 100644
View file @
cf0f8920
void
scatter_mul_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
input
);
void
scatter_mul_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
input
);
void
scatter_mul_cuda_Byte
(
int
dim
,
THCudaByteTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaByteTensor
*
input
);
void
scatter_mul_cuda_Char
(
int
dim
,
THCudaCharTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaCharTensor
*
input
);
void
scatter_mul_cuda_Short
(
int
dim
,
THCudaShortTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaShortTensor
*
input
);
void
scatter_mul_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
input
);
void
scatter_mul_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
input
);
void
scatter_div_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
input
);
void
scatter_div_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
input
);
void
scatter_div_cuda_Byte
(
int
dim
,
THCudaByteTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaByteTensor
*
input
);
void
scatter_div_cuda_Char
(
int
dim
,
THCudaCharTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaCharTensor
*
input
);
void
scatter_div_cuda_Short
(
int
dim
,
THCudaShortTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaShortTensor
*
input
);
void
scatter_div_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
input
);
void
scatter_div_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
input
);
void
scatter_mean_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
input
,
THCudaTensor
*
output_count
);
void
scatter_mean_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
input
,
THCudaDoubleTensor
*
output_count
);
void
scatter_mean_cuda_Byte
(
int
dim
,
THCudaByteTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaByteTensor
*
input
,
THCudaByteTensor
*
output_count
);
void
scatter_mean_cuda_Char
(
int
dim
,
THCudaCharTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaCharTensor
*
input
,
THCudaCharTensor
*
output_count
);
void
scatter_mean_cuda_Short
(
int
dim
,
THCudaShortTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaShortTensor
*
input
,
THCudaShortTensor
*
output_count
);
void
scatter_mean_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
input
,
THCudaIntTensor
*
output_count
);
void
scatter_mean_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
input
,
THCudaLongTensor
*
output_count
);
void
scatter_max_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_max_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_max_cuda_Byte
(
int
dim
,
THCudaByteTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaByteTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_max_cuda_Char
(
int
dim
,
THCudaCharTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaCharTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_max_cuda_Short
(
int
dim
,
THCudaShortTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaShortTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_max_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_max_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_min_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_min_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_min_cuda_Byte
(
int
dim
,
THCudaByteTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaByteTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_min_cuda_Char
(
int
dim
,
THCudaCharTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaCharTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_min_cuda_Short
(
int
dim
,
THCudaShortTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaShortTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_min_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_min_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
index_backward_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
grad
,
THCudaLongTensor
*
arg_grad
);
void
index_backward_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
grad
,
THCudaLongTensor
*
arg_grad
);
void
index_backward_cuda_Byte
(
int
dim
,
THCudaByteTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaByteTensor
*
grad
,
THCudaLongTensor
*
arg_grad
);
void
index_backward_cuda_Char
(
int
dim
,
THCudaCharTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaCharTensor
*
grad
,
THCudaLongTensor
*
arg_grad
);
void
index_backward_cuda_Short
(
int
dim
,
THCudaShortTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaShortTensor
*
grad
,
THCudaLongTensor
*
arg_grad
);
void
index_backward_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
grad
,
THCudaLongTensor
*
arg_grad
);
void
index_backward_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
grad
,
THCudaLongTensor
*
arg_grad
);
torch_scatter/src/generic/cpu.c
View file @
cf0f8920
...
@@ -18,41 +18,41 @@ void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
...
@@ -18,41 +18,41 @@ void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
})
})
}
}
void
scatter_
(
mean
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THTensor
*
output
_count
)
{
void
scatter_
(
mean
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THTensor
*
num_
output
)
{
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
real
,
output
_count
,
dim
,
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
real
,
num_
output
,
dim
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
index_data
[
i
]]
+=
input_data
[
i
];
output_data
[
index_data
[
i
]]
+=
input_data
[
i
];
output_
count_
data
[
index_data
[
i
]]
++
;
num_
output_data
[
index_data
[
i
]]
++
;
})
})
}
}
void
scatter_
(
max
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
output
_arg
)
{
void
scatter_
(
max
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
arg_
output
)
{
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
output
_arg
,
dim
,
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
arg_
output
,
dim
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
if
(
input_data
[
i
]
>=
output_data
[
index_data
[
i
]])
{
if
(
input_data
[
i
]
>=
output_data
[
index_data
[
i
]])
{
output_data
[
index_data
[
i
]]
=
input_data
[
i
];
output_data
[
index_data
[
i
]]
=
input_data
[
i
];
output_
arg_
data
[
index_data
[
i
]]
=
i
;
arg_
output_data
[
index_data
[
i
]]
=
i
;
}
}
})
})
}
}
void
scatter_
(
min
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
output
_arg
)
{
void
scatter_
(
min
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
arg_
output
)
{
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
output
_arg
,
dim
,
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
arg_
output
,
dim
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
if
(
input_data
[
i
]
<=
output_data
[
index_data
[
i
]])
{
if
(
input_data
[
i
]
<=
output_data
[
index_data
[
i
]])
{
output_data
[
index_data
[
i
]]
=
input_data
[
i
];
output_data
[
index_data
[
i
]]
=
input_data
[
i
];
output_
arg_
data
[
index_data
[
i
]]
=
i
;
arg_
output_data
[
index_data
[
i
]]
=
i
;
}
}
})
})
}
}
void
index_backward
(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
grad
,
THLongTensor
*
grad
_arg
)
{
void
index_backward
(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
grad
,
THLongTensor
*
arg_
grad
)
{
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
grad
,
int64_t
,
grad
_arg
,
dim
,
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
grad
,
int64_t
,
arg_
grad
,
dim
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
if
(
grad
_arg
_data
[
index_data
[
i
]]
==
i
)
output_data
[
i
]
=
grad_data
[
index_data
[
i
]];
if
(
arg_
grad_data
[
index_data
[
i
]]
==
i
)
output_data
[
i
]
=
grad_data
[
index_data
[
i
]];
})
})
}
}
...
...
torch_scatter/src/generic/cuda.c
0 → 100644
View file @
cf0f8920
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment