Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
d3aabdf3
"git@developer.sourcefind.cn:OpenDAS/dlib.git" did not exist on "67e6957b1e3934bf542afd81b061ade8460ae6f2"
Commit
d3aabdf3
authored
Feb 02, 2020
by
rusty1s
Browse files
fix negative dim in scatter_mean
parent
ff3be8e3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
6 deletions
+14
-6
csrc/scatter.cpp
csrc/scatter.cpp
+1
-0
test/test_broadcasting.py
test/test_broadcasting.py
+13
-6
No files found.
csrc/scatter.cpp
View file @
d3aabdf3
...
@@ -71,6 +71,7 @@ public:
...
@@ -71,6 +71,7 @@ public:
Variable
index
,
int64_t
dim
,
Variable
index
,
int64_t
dim
,
torch
::
optional
<
Variable
>
optional_out
,
torch
::
optional
<
Variable
>
optional_out
,
torch
::
optional
<
int64_t
>
dim_size
)
{
torch
::
optional
<
int64_t
>
dim_size
)
{
dim
=
dim
<
0
?
src
.
dim
()
+
dim
:
dim
;
ctx
->
saved_data
[
"dim"
]
=
dim
;
ctx
->
saved_data
[
"dim"
]
=
dim
;
ctx
->
saved_data
[
"src_shape"
]
=
src
.
sizes
();
ctx
->
saved_data
[
"src_shape"
]
=
src
.
sizes
();
...
...
test/test_broadcasting.py
View file @
d3aabdf3
from
itertools
import
product
import
pytest
import
pytest
import
torch
import
torch
from
torch_scatter
import
scatter
_add
from
torch_scatter
import
scatter
from
.utils
import
devices
from
.utils
import
reductions
,
devices
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
@
pytest
.
mark
.
parametrize
(
'
reduce,
device'
,
product
(
reductions
,
devices
)
)
def
test_broadcasting
(
device
):
def
test_broadcasting
(
reduce
,
device
):
B
,
C
,
H
,
W
=
(
4
,
3
,
8
,
8
)
B
,
C
,
H
,
W
=
(
4
,
3
,
8
,
8
)
src
=
torch
.
randn
((
B
,
C
,
H
,
W
),
device
=
device
)
index
=
torch
.
randint
(
0
,
H
,
(
H
,
)).
to
(
device
,
torch
.
long
)
out
=
scatter
(
src
,
index
,
dim
=
2
,
dim_size
=
H
,
reduce
=
reduce
)
assert
out
.
size
()
==
(
B
,
C
,
H
,
W
)
src
=
torch
.
randn
((
B
,
C
,
H
,
W
),
device
=
device
)
src
=
torch
.
randn
((
B
,
C
,
H
,
W
),
device
=
device
)
index
=
torch
.
randint
(
0
,
H
,
(
B
,
1
,
H
,
W
)).
to
(
device
,
torch
.
long
)
index
=
torch
.
randint
(
0
,
H
,
(
B
,
1
,
H
,
W
)).
to
(
device
,
torch
.
long
)
out
=
scatter
_add
(
src
,
index
,
dim
=
2
,
dim_size
=
H
)
out
=
scatter
(
src
,
index
,
dim
=
2
,
dim_size
=
H
,
reduce
=
reduce
)
assert
out
.
size
()
==
(
B
,
C
,
H
,
W
)
assert
out
.
size
()
==
(
B
,
C
,
H
,
W
)
src
=
torch
.
randn
((
B
,
C
,
H
,
W
),
device
=
device
)
src
=
torch
.
randn
((
B
,
C
,
H
,
W
),
device
=
device
)
index
=
torch
.
randint
(
0
,
H
,
(
H
,
)).
to
(
device
,
torch
.
long
)
index
=
torch
.
randint
(
0
,
H
,
(
H
,
)).
to
(
device
,
torch
.
long
)
out
=
scatter
_add
(
src
,
index
,
dim
=
2
,
dim_size
=
H
)
out
=
scatter
(
src
,
index
,
dim
=
2
,
dim_size
=
H
,
reduce
=
reduce
)
assert
out
.
size
()
==
(
B
,
C
,
H
,
W
)
assert
out
.
size
()
==
(
B
,
C
,
H
,
W
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment