Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
f013eb53
Commit
f013eb53
authored
Dec 21, 2017
by
rusty1s
Browse files
index backward impl
parent
4134714f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
7 deletions
+10
-7
test/test_max.py
test/test_max.py
+7
-4
torch_scatter/kernel/kernel.cu
torch_scatter/kernel/kernel.cu
+3
-3
No files found.
test/test_max.py
View file @
f013eb53
...
...
@@ -49,6 +49,12 @@ def test_scatter_cuda_max(str):
output
,
index
,
input
=
output
.
cuda
(),
index
.
cuda
(),
input
.
cuda
()
_
,
arg_output
=
scatter_max_
(
output
,
index
,
input
,
dim
=
1
)
assert
output
.
cpu
().
tolist
()
==
expected_output
assert
arg_output
.
cpu
().
tolist
()
==
expected_arg_output
output
,
arg_output
=
scatter_max
(
index
,
input
,
dim
=
1
)
assert
output
.
cpu
().
tolist
()
==
expected_output
assert
arg_output
.
cpu
().
tolist
()
==
expected_arg_output
output
=
Variable
(
output
).
fill_
(
0
)
index
=
Variable
(
index
)
...
...
@@ -60,7 +66,4 @@ def test_scatter_cuda_max(str):
expected_grad_input
=
[[
50
,
60
,
0
,
30
,
40
],
[
0
,
15
,
0
,
35
,
25
]]
output
.
backward
(
grad_output
)
print
(
input
.
grad
.
data
)
# print(input)
# assert input.grad.data.tolist() == expected_grad_input
assert
input
.
grad
.
data
.
cpu
().
tolist
()
==
expected_grad_input
torch_scatter/kernel/kernel.cu
View file @
f013eb53
...
...
@@ -71,9 +71,9 @@ __global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
template
<
typename
Real
,
int
Dims
>
__global__
void
indexBackwardKernel
(
TensorInfo
<
Real
>
output
,
TensorInfo
<
int64_t
>
index
,
TensorInfo
<
Real
>
grad
,
TensorInfo
<
int64_t
>
arg
,
const
int
dim
,
const
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
/*
int outputOffset = 0; int indexOffset = 0; int gradOffset = 0; int argOffset = 0;
*/
/*
IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset,
grad, &gradOffset,
output, &outputOffset, arg, &argOffset);
*/
/*
if (
eq(input.data[inputOffset],
output.data[outputOffset]
)) arg
.data[
arg
Offset]
= inputOffset % input.size[dim]; */
int
outputOffset
=
0
;
int
indexOffset
=
0
;
int
gradOffset
=
0
;
int
argOffset
=
0
;
IndexToScatterOffsets4
<
Real
,
Real
,
int64_t
,
Dims
>::
compute
(
i
,
dim
,
index
,
&
indexOffset
,
output
,
&
outputOffset
,
grad
,
&
gradOffset
,
arg
,
&
argOffset
);
if
(
arg
.
data
[
argOffset
]
==
outputOffset
%
output
.
size
[
dim
])
output
.
data
[
outputOffset
]
=
grad
.
data
[
grad
Offset
]
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment