Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
573443b6
Commit
573443b6
authored
Dec 19, 2017
by
rusty1s
Browse files
rename
parent
cf0f8920
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
11 deletions
+11
-11
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+4
-4
torch_scatter/src/cuda.h
torch_scatter/src/cuda.h
+7
-7
No files found.
torch_scatter/functions/scatter.py
View file @
573443b6
...
@@ -38,11 +38,11 @@ def _scatter(name, dim, *data):
...
@@ -38,11 +38,11 @@ def _scatter(name, dim, *data):
return
(
data
[
0
],
data
[
3
])
if
has_arg_output
(
name
)
else
data
[
0
]
return
(
data
[
0
],
data
[
3
])
if
has_arg_output
(
name
)
else
data
[
0
]
def
index_backward
(
dim
,
index
,
grad
,
grad
_arg
):
def
index_backward
(
dim
,
index
,
grad
,
arg_
grad
):
typename
=
type
(
grad
).
__name__
.
replace
(
'Tensor'
,
''
)
typename
=
type
(
grad
).
__name__
.
replace
(
'Tensor'
,
''
)
func
=
getattr
(
ffi
,
'index_backward_{}'
.
format
(
typename
))
func
=
getattr
(
ffi
,
'index_backward_{}'
.
format
(
typename
))
output
=
grad
.
new
(
index
.
size
()).
fill_
(
0
)
output
=
grad
.
new
(
index
.
size
()).
fill_
(
0
)
func
(
dim
,
output
,
index
,
grad
,
grad
_arg
)
func
(
dim
,
output
,
index
,
grad
,
arg_
grad
)
return
output
return
output
...
@@ -83,8 +83,8 @@ class _Scatter(Function):
...
@@ -83,8 +83,8 @@ class _Scatter(Function):
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
if
self
.
needs_input_grad
[
2
]
and
has_arg_output
(
self
.
name
):
if
self
.
needs_input_grad
[
2
]
and
has_arg_output
(
self
.
name
):
index
,
grad
_arg
=
self
.
saved_variables
index
,
arg_
grad
=
self
.
saved_variables
data
=
(
index
.
data
,
data
[
0
],
grad
_arg
.
data
)
data
=
(
index
.
data
,
data
[
0
],
arg_
grad
.
data
)
grad_input
=
index_backward
(
self
.
dim
,
*
data
)
grad_input
=
index_backward
(
self
.
dim
,
*
data
)
# Return and fill with empty grads for none-differentiable passed
# Return and fill with empty grads for none-differentiable passed
...
...
torch_scatter/src/cuda.h
View file @
573443b6
...
@@ -14,13 +14,13 @@ void scatter_div_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTens
...
@@ -14,13 +14,13 @@ void scatter_div_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTens
void
scatter_div_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
input
);
void
scatter_div_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
input
);
void
scatter_div_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
input
);
void
scatter_div_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
input
);
void
scatter_mean_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
input
,
THCudaTensor
*
output
_count
);
void
scatter_mean_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
input
,
THCudaTensor
*
num_
output
);
void
scatter_mean_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
input
,
THCudaDoubleTensor
*
output
_count
);
void
scatter_mean_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
input
,
THCudaDoubleTensor
*
num_
output
);
void
scatter_mean_cuda_Byte
(
int
dim
,
THCudaByteTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaByteTensor
*
input
,
THCudaByteTensor
*
output
_count
);
void
scatter_mean_cuda_Byte
(
int
dim
,
THCudaByteTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaByteTensor
*
input
,
THCudaByteTensor
*
num_
output
);
void
scatter_mean_cuda_Char
(
int
dim
,
THCudaCharTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaCharTensor
*
input
,
THCudaCharTensor
*
output
_count
);
void
scatter_mean_cuda_Char
(
int
dim
,
THCudaCharTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaCharTensor
*
input
,
THCudaCharTensor
*
num_
output
);
void
scatter_mean_cuda_Short
(
int
dim
,
THCudaShortTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaShortTensor
*
input
,
THCudaShortTensor
*
output
_count
);
void
scatter_mean_cuda_Short
(
int
dim
,
THCudaShortTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaShortTensor
*
input
,
THCudaShortTensor
*
num_
output
);
void
scatter_mean_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
input
,
THCudaIntTensor
*
output
_count
);
void
scatter_mean_cuda_Int
(
int
dim
,
THCudaIntTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaIntTensor
*
input
,
THCudaIntTensor
*
num_
output
);
void
scatter_mean_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
input
,
THCudaLongTensor
*
output
_count
);
void
scatter_mean_cuda_Long
(
int
dim
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaLongTensor
*
input
,
THCudaLongTensor
*
num_
output
);
void
scatter_max_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_max_cuda_Float
(
int
dim
,
THCudaTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_max_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
input
,
THCudaLongTensor
*
arg_output
);
void
scatter_max_cuda_Double
(
int
dim
,
THCudaDoubleTensor
*
output
,
THCudaLongTensor
*
index
,
THCudaDoubleTensor
*
input
,
THCudaLongTensor
*
arg_output
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment