Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
3472bc29
Commit
3472bc29
authored
Dec 18, 2017
by
rusty1s
Browse files
debug
parent
143a57ec
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
13 deletions
+15
-13
test/test_max.py
test/test_max.py
+12
-11
torch_scatter/src/generic/cpu.c
torch_scatter/src/generic/cpu.c
+3
-2
No files found.
test/test_max.py
View file @
3472bc29
...
@@ -7,7 +7,7 @@ from .utils import tensor_strs, Tensor
...
@@ -7,7 +7,7 @@ from .utils import tensor_strs, Tensor
# @pytest.mark.parametrize('str', tensor_strs)
# @pytest.mark.parametrize('str', tensor_strs)
@
pytest
.
mark
.
parametrize
(
'str'
,
[
'
Double
Tensor'
])
@
pytest
.
mark
.
parametrize
(
'str'
,
[
'
Int
Tensor'
])
def
test_scatter_mean
(
str
):
def
test_scatter_mean
(
str
):
input
=
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]]
input
=
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]]
index
=
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]]
index
=
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]]
...
@@ -25,22 +25,23 @@ def test_scatter_mean(str):
...
@@ -25,22 +25,23 @@ def test_scatter_mean(str):
assert
output
.
tolist
()
==
expected_output
assert
output
.
tolist
()
==
expected_output
assert
output_index
.
tolist
()
==
expected_output_index
assert
output_index
.
tolist
()
==
expected_output_index
output
=
Variable
(
output
).
fill_
(
0
)
#
output = Variable(output).fill_(0)
index
=
Variable
(
index
)
#
index = Variable(index)
input
=
Variable
(
input
,
requires_grad
=
True
)
#
input = Variable(input, requires_grad=True)
scatter_max_
(
output
,
index
,
input
,
dim
=
1
)
#
scatter_max_(output, index, input, dim=1)
grad_output
=
[[
10
,
20
,
30
,
40
,
50
,
60
],
[
15
,
25
,
35
,
45
,
55
,
65
]]
#
grad_output = [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]]
grad_output
=
Tensor
(
str
,
grad_output
)
#
grad_output = Tensor(str, grad_output)
output
.
backward
(
grad_output
)
#
output.backward(grad_output)
# assert index.data.tolist() == input.grad.data.tolist()
# assert index.data.tolist() == input.grad.data.tolist()
# output = Variable(torch.FloatTensor([0, 0, 0, 0, 0]))
# output = Variable(torch.FloatTensor([0, 0, 0, 0, 0]))
index
=
Variable
(
torch
.
LongTensor
([
3
,
4
,
4
,
2
,
1
]))
index
=
Variable
(
torch
.
LongTensor
([
3
,
4
,
4
,
2
,
1
]))
input
=
Variable
(
torch
.
Floa
tTensor
([
1
,
2
,
3
,
4
,
5
]),
requires_grad
=
True
)
input
=
Variable
(
torch
.
In
tTensor
([
1
,
2
,
3
,
4
,
5
]),
requires_grad
=
True
)
output
,
output_index
=
scatter_max
(
index
,
input
)
output
,
output_index
=
scatter_max
(
index
,
input
)
# print(output, output_index)
# print(output_index)
# print(output_index)
output
.
backward
(
torch
.
Floa
tTensor
([
10
,
20
,
30
,
40
]))
output
.
backward
(
torch
.
In
tTensor
([
10
,
20
,
30
,
40
]))
print
(
input
.
grad
)
#
print(input.grad)
torch_scatter/src/generic/cpu.c
View file @
3472bc29
...
@@ -66,11 +66,12 @@ void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
...
@@ -66,11 +66,12 @@ void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
}
}
void
index_backward
(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
grad
,
THLongTensor
*
grad_index
)
{
void
index_backward
(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
grad
,
THLongTensor
*
grad_index
)
{
int64_t
idx
;
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
grad
,
int64_t
,
grad_index
,
dim
,
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
grad
,
int64_t
,
grad_index
,
dim
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
int64_t
idx
=
*
(
index_data
+
i
*
index_stride
);
idx
=
*
(
index_data
+
i
*
index_stride
);
/* if (grad_index_data[index_data[i]] == i) { */
/* if (grad_index_data[index_data[i]] == i) { */
/*
printf("i: %i,
", i); */
printf
(
"i: %
ll
i,
idx: %lli grad_index: %i grad: %i
\n
"
,
i
,
idx
,
*
(
grad_index_data
+
idx
*
grad_index_stride
),
*
(
grad_data
+
idx
*
grad_stride
));
/* output_data[i] = grad_data[idx]; */
/* output_data[i] = grad_data[idx]; */
/* } */
/* } */
})
})
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment