Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
88dd792e
Commit
88dd792e
authored
Feb 04, 2020
by
rusty1s
Browse files
fix zero element tensors
parent
bf1f1014
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
132 additions
and
51 deletions
+132
-51
csrc/cpu/scatter_cpu.cpp
csrc/cpu/scatter_cpu.cpp
+5
-0
csrc/cpu/segment_coo_cpu.cpp
csrc/cpu/segment_coo_cpu.cpp
+12
-7
csrc/cpu/segment_csr_cpu.cpp
csrc/cpu/segment_csr_cpu.cpp
+13
-4
csrc/cuda/scatter_cuda.cu
csrc/cuda/scatter_cuda.cu
+5
-0
csrc/cuda/segment_coo_cuda.cu
csrc/cuda/segment_coo_cuda.cu
+16
-6
csrc/cuda/segment_csr_cuda.cu
csrc/cuda/segment_csr_cuda.cu
+16
-7
csrc/segment_csr.cpp
csrc/segment_csr.cpp
+10
-8
test/test_zero_tensors.py
test/test_zero_tensors.py
+33
-7
torch_scatter/scatter.py
torch_scatter/scatter.py
+22
-12
No files found.
csrc/cpu/scatter_cpu.cpp
View file @
88dd792e
...
...
@@ -29,6 +29,8 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
auto
sizes
=
src
.
sizes
().
vec
();
if
(
dim_size
.
has_value
())
sizes
[
dim
]
=
dim_size
.
value
();
else
if
(
index
.
numel
()
==
0
)
sizes
[
dim
]
=
0
;
else
sizes
[
dim
]
=
1
+
*
index
.
max
().
data_ptr
<
int64_t
>
();
out
=
torch
::
empty
(
sizes
,
src
.
options
());
...
...
@@ -41,6 +43,9 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
arg_out_data
=
arg_out
.
value
().
data_ptr
<
int64_t
>
();
}
if
(
index
.
numel
()
==
0
)
return
std
::
make_tuple
(
out
,
arg_out
);
auto
B
=
1
;
for
(
auto
i
=
0
;
i
<
dim
;
i
++
)
B
*=
src
.
size
(
i
);
...
...
csrc/cpu/segment_coo_cpu.cpp
View file @
88dd792e
...
...
@@ -34,6 +34,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
sizes
=
src
.
sizes
().
vec
();
if
(
dim_size
.
has_value
())
sizes
[
dim
]
=
dim_size
.
value
();
else
if
(
index
.
numel
()
==
0
)
sizes
[
dim
]
=
0
;
else
sizes
[
dim
]
=
1
+
*
index
.
max
().
data_ptr
<
int64_t
>
();
out
=
torch
::
empty
(
sizes
,
src
.
options
());
...
...
@@ -44,15 +46,15 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
arg_out
=
torch
::
full_like
(
out
,
src
.
size
(
dim
),
index
.
options
());
arg_out_data
=
arg_out
.
value
().
data_ptr
<
int64_t
>
();
}
torch
::
optional
<
torch
::
Tensor
>
count
=
torch
::
nullopt
;
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MEAN
)
{
}
else
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MEAN
)
{
auto
sizes
=
index
.
sizes
().
vec
();
sizes
[
dim
]
=
out
.
size
(
dim
);
c
ou
n
t
=
torch
::
zeros
(
sizes
,
out
.
options
());
arg_
out
=
torch
::
zeros
(
sizes
,
out
.
options
());
}
if
(
index
.
numel
()
==
0
)
return
std
::
make_tuple
(
out
,
arg_out
);
auto
B
=
index
.
numel
()
/
src
.
size
(
dim
);
auto
E
=
src
.
size
(
dim
);
auto
K
=
src
.
numel
()
/
index
.
numel
();
...
...
@@ -72,7 +74,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
if
(
!
optional_out
.
has_value
())
out
.
fill_
(
Reducer
<
scalar_t
,
REDUCE
>::
init
());
if
(
REDUCE
==
MEAN
)
count_data
=
c
ou
n
t
.
value
().
data_ptr
<
scalar_t
>
();
count_data
=
arg_
out
.
value
().
data_ptr
<
scalar_t
>
();
for
(
auto
b
=
0
;
b
<
B
;
b
++
)
{
auto
offset
=
IndexToOffset
<
int64_t
>::
get
(
b
*
E
,
index_info
);
...
...
@@ -122,7 +124,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
out
.
masked_fill_
(
out
==
Reducer
<
scalar_t
,
REDUCE
>::
init
(),
(
scalar_t
)
0
);
if
(
REDUCE
==
MEAN
)
arg_out
=
count
;
arg_out
.
value
().
clamp_
(
1
)
;
});
});
...
...
@@ -156,6 +158,9 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
out
=
torch
::
empty
(
sizes
,
src
.
options
());
}
if
(
index
.
numel
()
==
0
)
return
out
;
auto
B
=
index
.
numel
()
/
out
.
size
(
dim
);
auto
E
=
index
.
size
(
dim
);
auto
K
=
out
.
numel
()
/
index
.
numel
();
...
...
csrc/cpu/segment_csr_cpu.cpp
View file @
88dd792e
...
...
@@ -30,10 +30,10 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
for
(
auto
i
=
0
;
i
<
out
.
dim
();
i
++
)
if
(
i
!=
dim
)
CHECK_INPUT
(
src
.
size
(
i
)
==
out
.
size
(
i
));
CHECK_INPUT
(
out
.
size
(
dim
)
==
indptr
.
size
(
dim
)
-
1
);
CHECK_INPUT
(
src
.
numel
()
==
0
||
out
.
size
(
dim
)
==
indptr
.
size
(
dim
)
-
1
);
}
else
{
sizes
=
src
.
sizes
().
vec
();
sizes
[
dim
]
=
indptr
.
size
(
dim
)
-
1
;
sizes
[
dim
]
=
std
::
max
<
int64_t
>
(
indptr
.
size
(
dim
)
-
1
,
0
)
;
out
=
torch
::
empty
(
sizes
,
src
.
options
());
}
...
...
@@ -44,6 +44,9 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
arg_out_data
=
arg_out
.
value
().
data_ptr
<
int64_t
>
();
}
if
(
src
.
numel
()
==
0
)
return
std
::
make_tuple
(
out
,
arg_out
);
auto
N
=
out
.
size
(
dim
)
*
(
indptr
.
numel
()
/
indptr
.
size
(
-
1
));
auto
K
=
out
.
numel
()
/
N
;
auto
E
=
src
.
size
(
dim
);
...
...
@@ -98,7 +101,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
indptr
=
indptr
.
expand
(
sizes
);
auto
dim
=
indptr
.
dim
()
-
1
;
CHECK_INPUT
(
src
.
size
(
dim
)
==
indptr
.
size
(
dim
)
-
1
);
CHECK_INPUT
(
src
.
size
(
dim
)
==
0
||
src
.
size
(
dim
)
==
indptr
.
size
(
dim
)
-
1
);
src
=
src
.
contiguous
();
...
...
@@ -110,10 +113,16 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
CHECK_INPUT
(
src
.
size
(
i
)
==
out
.
size
(
i
));
}
else
{
auto
sizes
=
src
.
sizes
().
vec
();
if
(
src
.
numel
()
>
0
)
sizes
[
dim
]
=
*
indptr
.
flatten
()[
-
1
].
data_ptr
<
int64_t
>
();
else
sizes
[
dim
]
=
0
;
out
=
torch
::
empty
(
sizes
,
src
.
options
());
}
if
(
src
.
numel
()
==
0
)
return
out
;
auto
N
=
src
.
size
(
dim
)
*
(
indptr
.
numel
()
/
indptr
.
size
(
-
1
));
auto
K
=
src
.
numel
()
/
N
;
auto
E
=
out
.
size
(
dim
);
...
...
csrc/cuda/scatter_cuda.cu
View file @
88dd792e
...
...
@@ -81,6 +81,8 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
auto
sizes
=
src
.
sizes
().
vec
();
if
(
dim_size
.
has_value
())
sizes
[
dim
]
=
dim_size
.
value
();
else
if
(
index
.
numel
()
==
0
)
sizes
[
dim
]
=
0
;
else
{
auto
d_size
=
index
.
max
().
data_ptr
<
int64_t
>
();
auto
h_size
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
...
...
@@ -97,6 +99,9 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
arg_out_data
=
arg_out
.
value
().
data_ptr
<
int64_t
>
();
}
if
(
index
.
numel
()
==
0
)
return
std
::
make_tuple
(
out
,
arg_out
);
auto
B
=
1
;
for
(
auto
i
=
0
;
i
<
dim
;
i
++
)
B
*=
src
.
size
(
i
);
...
...
csrc/cuda/segment_coo_cuda.cu
View file @
88dd792e
...
...
@@ -181,6 +181,8 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
sizes
=
src
.
sizes
().
vec
();
if
(
dim_size
.
has_value
())
sizes
[
dim
]
=
dim_size
.
value
();
else
if
(
index
.
numel
()
==
0
)
sizes
[
dim
]
=
0
;
else
{
auto
d_size
=
index
.
max
().
data_ptr
<
int64_t
>
();
auto
h_size
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
...
...
@@ -195,8 +197,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
arg_out
=
torch
::
full_like
(
out
,
src
.
size
(
dim
),
index
.
options
());
arg_out_data
=
arg_out
.
value
().
data_ptr
<
int64_t
>
();
}
else
if
(
reduce2REDUCE
.
at
(
reduce
)
==
MEAN
)
{
auto
sizes
=
index
.
sizes
().
vec
();
sizes
[
dim
]
=
out
.
size
(
dim
);
arg_out
=
torch
::
zeros
(
sizes
,
out
.
options
());
}
if
(
index
.
numel
()
==
0
)
return
std
::
make_tuple
(
out
,
arg_out
);
auto
E
=
index
.
numel
();
auto
E_2
=
index
.
size
(
dim
);
auto
E_1
=
index
.
numel
()
/
E_2
;
...
...
@@ -254,17 +263,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
}
if
(
REDUCE
==
MEAN
)
{
auto
sizes
=
index
.
sizes
().
vec
();
sizes
[
dim
]
=
out
.
size
(
dim
);
auto
count
=
torch
::
zeros
(
sizes
,
out
.
options
());
auto
count_data
=
count
.
data_ptr
<
scalar_t
>
();
auto
count_data
=
arg_out
.
value
().
data_ptr
<
scalar_t
>
();
segment_coo_kernel
<
scalar_t
,
SUM
,
false
>
<<<
BLOCKS
(
1
,
E
),
THREADS
,
0
,
stream
>>>
(
nullptr
,
index_info
,
count_data
,
E
,
N
);
arg_out
=
count
;
arg_out
.
value
().
clamp_
(
1
);
auto
count
=
arg_out
.
value
();
for
(
int
i
=
dim
+
1
;
i
<
out
.
dim
();
i
++
)
count
=
count
.
unsqueeze
(
-
1
);
out
.
div_
(
count
.
clamp_
(
1
)
);
out
.
div_
(
count
);
}
});
});
...
...
@@ -346,6 +353,9 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
out
=
torch
::
empty
(
sizes
,
src
.
options
());
}
if
(
index
.
numel
()
==
0
)
return
out
;
auto
E
=
index
.
numel
();
auto
K
=
out
.
numel
()
/
E
;
auto
N
=
src
.
size
(
dim
);
...
...
csrc/cuda/segment_csr_cuda.cu
View file @
88dd792e
...
...
@@ -121,10 +121,10 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
for
(
int
i
=
0
;
i
<
out
.
dim
();
i
++
)
if
(
i
!=
dim
)
CHECK_INPUT
(
src
.
size
(
i
)
==
out
.
size
(
i
));
CHECK_INPUT
(
out
.
size
(
dim
)
==
indptr
.
size
(
dim
)
-
1
);
CHECK_INPUT
(
src
.
numel
()
==
0
||
out
.
size
(
dim
)
==
indptr
.
size
(
dim
)
-
1
);
}
else
{
sizes
=
src
.
sizes
().
vec
();
sizes
[
dim
]
=
indptr
.
size
(
dim
)
-
1
;
sizes
[
dim
]
=
std
::
max
<
int64_t
>
(
indptr
.
size
(
dim
)
-
1
,
0
)
;
out
=
torch
::
empty
(
sizes
,
src
.
options
());
}
...
...
@@ -135,6 +135,9 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
arg_out_data
=
arg_out
.
value
().
data_ptr
<
int64_t
>
();
}
if
(
src
.
numel
()
==
0
)
return
std
::
make_tuple
(
out
,
arg_out
);
auto
N
=
out
.
size
(
dim
)
*
(
indptr
.
numel
()
/
indptr
.
size
(
-
1
));
auto
K
=
out
.
numel
()
/
N
;
auto
E
=
src
.
size
(
dim
);
...
...
@@ -226,7 +229,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
indptr
=
indptr
.
expand
(
sizes
);
auto
dim
=
indptr
.
dim
()
-
1
;
CHECK_INPUT
(
src
.
size
(
dim
)
==
indptr
.
size
(
dim
)
-
1
);
CHECK_INPUT
(
src
.
size
(
dim
)
==
0
||
src
.
size
(
dim
)
==
indptr
.
size
(
dim
)
-
1
);
src
=
src
.
contiguous
();
...
...
@@ -237,14 +240,20 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
if
(
i
!=
dim
)
CHECK_INPUT
(
src
.
size
(
i
)
==
out
.
size
(
i
));
}
else
{
auto
sizes
=
src
.
sizes
().
vec
();
if
(
src
.
numel
()
>
0
)
{
auto
d_size
=
indptr
.
flatten
()[
-
1
].
data_ptr
<
int64_t
>
();
auto
h_size
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
cudaMemcpy
(
h_size
,
d_size
,
sizeof
(
int64_t
),
cudaMemcpyDeviceToHost
);
auto
sizes
=
src
.
sizes
().
vec
();
sizes
[
dim
]
=
*
h_size
;
}
else
sizes
[
dim
]
=
0
;
out
=
torch
::
empty
(
sizes
,
src
.
options
());
}
if
(
src
.
numel
()
==
0
)
return
out
;
auto
N
=
src
.
size
(
dim
)
*
(
indptr
.
numel
()
/
indptr
.
size
(
-
1
));
auto
K
=
src
.
numel
()
/
N
;
auto
E
=
out
.
size
(
dim
);
...
...
csrc/segment_csr.cpp
View file @
88dd792e
...
...
@@ -82,6 +82,7 @@ public:
auto
indptr
=
saved
[
0
];
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toIntList
());
auto
grad_in
=
torch
::
empty
(
src_shape
,
grad_out
.
options
());
if
(
grad_in
.
numel
()
>
0
)
{
gather_csr_fw
(
grad_out
,
indptr
,
grad_in
);
auto
indptr1
=
indptr
.
narrow
(
-
1
,
0
,
indptr
.
size
(
-
1
)
-
1
);
auto
indptr2
=
indptr
.
narrow
(
-
1
,
1
,
indptr
.
size
(
-
1
)
-
1
);
...
...
@@ -90,6 +91,7 @@ public:
for
(
auto
i
=
0
;
i
<
grad_out
.
dim
()
-
indptr
.
dim
();
i
++
)
count
=
count
.
unsqueeze
(
-
1
);
grad_in
.
div_
(
count
);
}
return
{
grad_in
,
Variable
(),
Variable
()};
}
};
...
...
test/test_zero_tensors.py
View file @
88dd792e
from
itertools
import
product
import
pytest
import
torch
from
torch_scatter
import
scatter
from
torch_scatter
import
scatter
,
segment_coo
,
gather_coo
from
torch_scatter
import
segment_csr
,
gather_csr
from
.utils
import
reductions
,
tensor
,
grad_dtypes
,
devices
@
pytest
.
mark
.
parametrize
(
'reduce,dtype,device'
,
product
(
reductions
,
grad_dtypes
,
devices
))
def
test_zero_elements
(
reduce
,
dtype
,
device
):
x
=
torch
.
randn
(
0
,
0
,
0
,
16
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
True
)
index
=
tensor
([],
torch
.
long
,
device
)
indptr
=
tensor
([],
torch
.
long
,
device
)
out
=
scatter
(
x
,
index
,
dim
=
0
,
dim_size
=
0
,
reduce
=
reduce
)
out
.
backward
(
torch
.
randn_like
(
out
))
assert
out
.
size
()
==
(
0
,
0
,
0
,
16
)
out
=
segment_coo
(
x
,
index
,
dim_size
=
0
,
reduce
=
reduce
)
out
.
backward
(
torch
.
randn_like
(
out
))
assert
out
.
size
()
==
(
0
,
0
,
0
,
16
)
out
=
gather_coo
(
x
,
index
)
out
.
backward
(
torch
.
randn_like
(
out
))
assert
out
.
size
()
==
(
0
,
0
,
0
,
16
)
def
test_zero_elements
():
x
=
torch
.
randn
(
0
,
16
)
index
=
torch
.
tensor
([]).
view
(
0
,
16
)
print
(
x
)
print
(
index
)
out
=
segment_csr
(
x
,
indptr
,
reduce
=
reduce
)
out
.
backward
(
torch
.
randn_like
(
out
))
assert
out
.
size
()
==
(
0
,
0
,
0
,
16
)
scatter
(
x
,
index
,
dim
=
0
,
dim_size
=
0
,
reduce
=
"add"
)
out
=
gather_csr
(
x
,
indptr
)
out
.
backward
(
torch
.
randn_like
(
out
))
assert
out
.
size
()
==
(
0
,
0
,
0
,
16
)
torch_scatter/scatter.py
View file @
88dd792e
...
...
@@ -12,12 +12,6 @@ try:
except
OSError
:
warnings
.
warn
(
'Failed to load `scatter` binaries.'
)
def
scatter_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
])
->
torch
.
Tensor
:
raise
ImportError
return
src
def
scatter_with_arg_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
]
...
...
@@ -25,7 +19,6 @@ except OSError:
raise
ImportError
return
src
,
index
torch
.
ops
.
torch_scatter
.
scatter_mean
=
scatter_placeholder
torch
.
ops
.
torch_scatter
.
scatter_min
=
scatter_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
scatter_max
=
scatter_with_arg_placeholder
...
...
@@ -37,11 +30,13 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
index
=
broadcast
(
index
,
src
,
dim
)
if
out
is
None
:
size
=
src
.
size
()
if
dim_size
is
None
:
size
[
dim
]
=
int
(
index
.
max
())
+
1
else
:
if
dim_size
is
not
None
:
size
[
dim
]
=
dim_size
out
=
src
.
new_zeros
(
size
)
elif
index
.
numel
()
==
0
:
size
[
dim
]
=
0
else
:
size
[
dim
]
=
int
(
index
.
max
())
+
1
out
=
torch
.
zeros
(
size
,
dtype
=
src
.
dtype
,
device
=
src
.
device
)
return
out
.
scatter_add_
(
dim
,
index
,
src
)
else
:
return
out
.
scatter_add_
(
dim
,
index
,
src
)
...
...
@@ -58,7 +53,22 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
def
scatter_mean
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
scatter_mean
(
src
,
index
,
dim
,
out
,
dim_size
)
out
=
scatter_sum
(
src
,
index
,
dim
,
out
,
dim_size
)
dim_size
=
out
.
size
(
dim
)
index_dim
=
dim
if
index_dim
<
0
:
index_dim
=
index_dim
+
src
.
dim
()
if
index
.
dim
()
<=
dim
:
index_dim
=
index
.
dim
()
-
1
ones
=
torch
.
ones
(
index
.
size
(),
dtype
=
src
.
dtype
,
device
=
src
.
device
)
count
=
scatter_sum
(
ones
,
index
,
index_dim
,
None
,
dim_size
)
count
.
clamp_
(
1
)
count
=
broadcast
(
count
,
out
,
dim
)
out
.
div_
(
count
)
return
out
@
torch
.
jit
.
script
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment