Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
0807f87f
Commit
0807f87f
authored
Jan 10, 2020
by
rusty1s
Browse files
all tests + segment_coo fixes
parent
a9f9266b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
146 additions
and
61 deletions
+146
-61
benchmark/scatter_segment.py
benchmark/scatter_segment.py
+1
-1
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+16
-15
test/test_gather.py
test/test_gather.py
+101
-27
test/test_segment.py
test/test_segment.py
+27
-18
torch_scatter/segment.py
torch_scatter/segment.py
+1
-0
No files found.
benchmark/scatter_segment.py
View file @
0807f87f
...
@@ -79,7 +79,7 @@ def correctness(dataset):
...
@@ -79,7 +79,7 @@ def correctness(dataset):
out2
,
_
=
segment_coo
(
x
,
row
,
reduce
=
'max'
)
out2
,
_
=
segment_coo
(
x
,
row
,
reduce
=
'max'
)
out3
,
_
=
segment_csr
(
x
,
rowptr
,
reduce
=
'max'
)
out3
,
_
=
segment_csr
(
x
,
rowptr
,
reduce
=
'max'
)
#
assert torch.allclose(out1, out2, atol=1e-4)
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
except
RuntimeError
:
except
RuntimeError
:
...
...
cuda/segment_kernel.cu
View file @
0807f87f
...
@@ -80,9 +80,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
...
@@ -80,9 +80,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static
inline
__device__
void
atomic_write
(
scalar_t
*
address
,
scalar_t
val
)
{
static
inline
__device__
void
atomic_write
(
scalar_t
*
address
,
scalar_t
val
)
{
if
(
REDUCE
==
ADD
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
atomAdd
(
address
,
val
);
}
else
if
(
REDUCE
==
MEAN
)
{
atomAdd
(
address
,
val
);
atomAdd
(
address
,
val
);
}
else
if
(
REDUCE
==
MIN
&&
val
<
*
address
)
{
}
else
if
(
REDUCE
==
MIN
&&
val
<
*
address
)
{
atomMin
(
address
,
val
);
atomMin
(
address
,
val
);
...
@@ -108,15 +106,16 @@ segment_csr_kernel(const scalar_t *src_data,
...
@@ -108,15 +106,16 @@ segment_csr_kernel(const scalar_t *src_data,
if
(
row_idx
<
N
)
{
if
(
row_idx
<
N
)
{
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
row_idx
,
indptr_info
);
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
row_idx
,
indptr_info
);
int
row_start
=
__ldg
(
indptr_info
.
data
+
offset
);
int
64_t
row_start
=
__ldg
(
indptr_info
.
data
+
offset
);
int
row_end
=
__ldg
(
indptr_info
.
data
+
offset
+
int
64_t
row_end
=
__ldg
(
indptr_info
.
data
+
offset
+
indptr_info
.
strides
[
indptr_info
.
dims
-
1
]);
indptr_info
.
strides
[
indptr_info
.
dims
-
1
]);
scalar_t
val
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
scalar_t
val
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
int64_t
arg
,
arg_tmp
;
int64_t
arg
,
arg_tmp
;
offset
=
(
row_idx
/
(
indptr_info
.
sizes
[
indptr_info
.
dims
-
1
]
-
1
))
*
E
;
offset
=
(
row_idx
/
(
indptr_info
.
sizes
[
indptr_info
.
dims
-
1
]
-
1
))
*
E
;
for
(
int
src_idx
=
row_start
+
lane_idx
;
src_idx
<
row_end
;
src_idx
+=
TB
)
{
for
(
int64_t
src_idx
=
row_start
+
lane_idx
;
src_idx
<
row_end
;
src_idx
+=
TB
)
{
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
val
,
src_data
[
offset
+
src_idx
],
&
arg
,
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
val
,
src_data
[
offset
+
src_idx
],
&
arg
,
src_idx
);
src_idx
);
}
}
...
@@ -154,15 +153,15 @@ __global__ void segment_csr_broadcast_kernel(
...
@@ -154,15 +153,15 @@ __global__ void segment_csr_broadcast_kernel(
if
(
thread_idx
<
N
*
K
)
{
if
(
thread_idx
<
N
*
K
)
{
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
row_idx
,
indptr_info
);
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
row_idx
,
indptr_info
);
int
row_start
=
__ldg
(
indptr_info
.
data
+
offset
);
int
64_t
row_start
=
__ldg
(
indptr_info
.
data
+
offset
);
int
row_end
=
__ldg
(
indptr_info
.
data
+
offset
+
int
64_t
row_end
=
__ldg
(
indptr_info
.
data
+
offset
+
indptr_info
.
strides
[
indptr_info
.
dims
-
1
]);
indptr_info
.
strides
[
indptr_info
.
dims
-
1
]);
scalar_t
val
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
scalar_t
val
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
int64_t
arg
;
int64_t
arg
;
offset
=
(
row_idx
/
(
indptr_info
.
sizes
[
indptr_info
.
dims
-
1
]
-
1
))
*
E
*
K
;
offset
=
(
row_idx
/
(
indptr_info
.
sizes
[
indptr_info
.
dims
-
1
]
-
1
))
*
E
*
K
;
for
(
int
src_idx
=
row_start
;
src_idx
<
row_end
;
src_idx
++
)
{
for
(
int
64_t
src_idx
=
row_start
;
src_idx
<
row_end
;
src_idx
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
update
(
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
val
,
src_data
[
offset
+
K
*
src_idx
+
lane_idx
],
&
arg
,
src_idx
);
&
val
,
src_data
[
offset
+
K
*
src_idx
+
lane_idx
],
&
arg
,
src_idx
);
}
}
...
@@ -253,7 +252,7 @@ segment_coo_kernel(const scalar_t *src_data,
...
@@ -253,7 +252,7 @@ segment_coo_kernel(const scalar_t *src_data,
if
(
row_idx
<
E
)
{
if
(
row_idx
<
E
)
{
int
offset
=
at
::
cuda
::
detail
::
IndexToOffset
<
int64_t
,
int
,
-
1
>::
get
(
int
offset
=
at
::
cuda
::
detail
::
IndexToOffset
<
int64_t
,
int
,
-
1
>::
get
(
row_idx
,
index_info
);
row_idx
,
index_info
);
int
idx
=
index_info
.
data
[
offset
],
next_idx
;
int
64_t
idx
=
index_info
.
data
[
offset
],
next_idx
;
int
out_idx
=
(
row_idx
/
D
)
*
N
+
idx
;
int
out_idx
=
(
row_idx
/
D
)
*
N
+
idx
;
scalar_t
val
=
HAS_VAL
?
src_data
[
row_idx
]
:
(
scalar_t
)
1
,
tmp
;
scalar_t
val
=
HAS_VAL
?
src_data
[
row_idx
]
:
(
scalar_t
)
1
,
tmp
;
...
@@ -289,7 +288,7 @@ __global__ void segment_coo_arg_kernel(
...
@@ -289,7 +288,7 @@ __global__ void segment_coo_arg_kernel(
if
(
row_idx
<
E
)
{
if
(
row_idx
<
E
)
{
int
offset
=
at
::
cuda
::
detail
::
IndexToOffset
<
int64_t
,
int
,
-
1
>::
get
(
int
offset
=
at
::
cuda
::
detail
::
IndexToOffset
<
int64_t
,
int
,
-
1
>::
get
(
row_idx
,
index_info
);
row_idx
,
index_info
);
int
idx
=
index_info
.
data
[
offset
];
int
64_t
idx
=
index_info
.
data
[
offset
];
int
out_idx
=
(
row_idx
/
D
)
*
N
+
idx
;
int
out_idx
=
(
row_idx
/
D
)
*
N
+
idx
;
scalar_t
val
=
__ldg
(
out_data
+
out_idx
);
scalar_t
val
=
__ldg
(
out_data
+
out_idx
);
...
@@ -310,7 +309,7 @@ __global__ void segment_coo_broadcast_kernel(
...
@@ -310,7 +309,7 @@ __global__ void segment_coo_broadcast_kernel(
int
D
=
index_info
.
sizes
[
index_info
.
dims
-
1
];
int
D
=
index_info
.
sizes
[
index_info
.
dims
-
1
];
int
E_1
=
E
/
D
;
int
E_1
=
E
/
D
;
int
E_2
=
D
+
D
%
TB
;
int
E_2
=
D
+
TB
-
(
D
%
TB
)
;
int
row_idx
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
int
row_idx
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
int
col_idx
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
int
col_idx
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
...
@@ -319,6 +318,7 @@ __global__ void segment_coo_broadcast_kernel(
...
@@ -319,6 +318,7 @@ __global__ void segment_coo_broadcast_kernel(
int
row_start
=
(
row_idx
*
TB
)
%
E_2
;
int
row_start
=
(
row_idx
*
TB
)
%
E_2
;
if
(
dim_start
<
E_1
&&
col_idx
<
K
)
{
if
(
dim_start
<
E_1
&&
col_idx
<
K
)
{
int
offset
=
at
::
cuda
::
detail
::
IndexToOffset
<
int64_t
,
int
,
-
1
>::
get
(
int
offset
=
at
::
cuda
::
detail
::
IndexToOffset
<
int64_t
,
int
,
-
1
>::
get
(
dim_start
*
D
+
row_start
,
index_info
);
dim_start
*
D
+
row_start
,
index_info
);
int
idx1
=
__ldg
(
index_info
.
data
+
offset
),
idx2
;
int
idx1
=
__ldg
(
index_info
.
data
+
offset
),
idx2
;
...
@@ -341,6 +341,7 @@ __global__ void segment_coo_broadcast_kernel(
...
@@ -341,6 +341,7 @@ __global__ void segment_coo_broadcast_kernel(
out_data
+
(
dim_start
*
N
+
idx1
)
*
K
+
col_idx
,
val
);
out_data
+
(
dim_start
*
N
+
idx1
)
*
K
+
col_idx
,
val
);
val
=
src_data
[
K
*
(
dim_start
*
D
+
row_start
+
i
)
+
col_idx
];
val
=
src_data
[
K
*
(
dim_start
*
D
+
row_start
+
i
)
+
col_idx
];
}
}
idx1
=
idx2
;
idx1
=
idx2
;
}
}
...
@@ -405,7 +406,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -405,7 +406,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
auto
E_1
=
index
.
numel
()
/
E_2
;
auto
E_1
=
index
.
numel
()
/
E_2
;
auto
K
=
src
.
numel
()
/
E
;
auto
K
=
src
.
numel
()
/
E
;
auto
N
=
out
.
size
(
reduce_dim
);
auto
N
=
out
.
size
(
reduce_dim
);
auto
avg_len
=
(
float
)
src
.
size
(
reduce_dim
)
/
(
float
)
out
.
size
(
reduce_dim
)
;
auto
avg_len
=
(
float
)
E_2
/
(
float
)
N
;
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int
>
(
index
);
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int
>
(
index
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
test/test_gather.py
View file @
0807f87f
...
@@ -2,6 +2,7 @@ from itertools import product
...
@@ -2,6 +2,7 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch.autograd
import
gradcheck
from
torch_scatter
import
gather_coo
,
gather_csr
from
torch_scatter
import
gather_coo
,
gather_csr
from
.utils
import
tensor
from
.utils
import
tensor
...
@@ -9,38 +10,111 @@ from .utils import tensor
...
@@ -9,38 +10,111 @@ from .utils import tensor
dtypes
=
[
torch
.
float
]
dtypes
=
[
torch
.
float
]
devices
=
[
torch
.
device
(
'cuda'
)]
devices
=
[
torch
.
device
(
'cuda'
)]
tests
=
[
{
'src'
:
[
1
,
2
,
3
,
4
],
'index'
:
[
0
,
0
,
1
,
1
,
1
,
3
],
'indptr'
:
[
0
,
2
,
5
,
5
,
6
],
'expected'
:
[
1
,
1
,
2
,
2
,
2
,
4
],
},
{
'src'
:
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
]],
'index'
:
[
0
,
0
,
1
,
1
,
1
,
3
],
'indptr'
:
[
0
,
2
,
5
,
5
,
6
],
'expected'
:
[[
1
,
2
],
[
1
,
2
],
[
3
,
4
],
[
3
,
4
],
[
3
,
4
],
[
7
,
8
]]
},
{
'src'
:
[[
1
,
3
,
5
,
7
],
[
2
,
4
,
6
,
8
]],
'index'
:
[[
0
,
0
,
1
,
1
,
1
,
3
],
[
0
,
0
,
0
,
1
,
1
,
2
]],
'indptr'
:
[[
0
,
2
,
5
,
5
,
6
],
[
0
,
3
,
5
,
6
,
6
]],
'expected'
:
[[
1
,
1
,
3
,
3
,
3
,
7
],
[
2
,
2
,
2
,
4
,
4
,
6
]],
},
{
'src'
:
[[[
1
,
2
],
[
3
,
4
],
[
5
,
6
]],
[[
7
,
9
],
[
10
,
11
],
[
12
,
13
]]],
'index'
:
[[
0
,
0
,
1
],
[
0
,
2
,
2
]],
'indptr'
:
[[
0
,
2
,
3
,
3
],
[
0
,
1
,
1
,
3
]],
'expected'
:
[[[
1
,
2
],
[
1
,
2
],
[
3
,
4
]],
[[
7
,
9
],
[
12
,
13
],
[
12
,
13
]]],
},
{
'src'
:
[[
1
],
[
2
]],
'index'
:
[[
0
,
0
],
[
0
,
0
]],
'indptr'
:
[[
0
,
2
],
[
0
,
2
]],
'expected'
:
[[
1
,
1
],
[
2
,
2
]],
},
{
'src'
:
[[[
1
,
1
]],
[[
2
,
2
]]],
'index'
:
[[
0
,
0
],
[
0
,
0
]],
'indptr'
:
[[
0
,
2
],
[
0
,
2
]],
'expected'
:
[[[
1
,
1
],
[
1
,
1
]],
[[
2
,
2
],
[
2
,
2
]]],
},
]
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_forward
(
dtype
,
device
):
def
test_forward
(
test
,
dtype
,
device
):
src
=
tensor
([[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
]],
dtype
,
device
)
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
src
=
tensor
([
1
,
2
,
3
,
4
],
dtype
,
device
)
index
=
tensor
(
test
[
'index'
],
torch
.
long
,
device
)
src
.
requires_grad_
()
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
indptr
=
tensor
([
0
,
2
,
5
,
5
,
6
],
torch
.
long
,
device
)
expected
=
tensor
(
test
[
'expected'
],
dtype
,
device
)
index
=
tensor
([
0
,
0
,
1
,
1
,
1
,
3
],
torch
.
long
,
device
)
out
=
src
.
index_select
(
0
,
index
)
out
=
gather_coo
(
src
,
index
)
grad_out
=
torch
.
randn_like
(
out
)
assert
torch
.
all
(
out
==
expected
)
out
.
backward
(
grad_out
)
print
(
'EXPECTED'
)
print
(
out
)
print
(
src
.
grad
)
src
.
grad
=
None
out
=
gather_csr
(
src
,
indptr
)
out
=
gather_csr
(
src
,
indptr
)
out
.
backward
(
grad_out
)
assert
torch
.
all
(
out
==
expected
)
print
(
'CSR'
)
print
(
out
)
print
(
src
.
grad
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
# print('CSR', out)
@
pytest
.
mark
.
parametrize
(
'test,device'
,
product
(
tests
,
devices
))
def
test_backward
(
test
,
device
):
src
=
tensor
(
test
[
'src'
],
torch
.
double
,
device
)
src
.
requires_grad_
()
index
=
tensor
(
test
[
'index'
],
torch
.
long
,
device
)
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
assert
gradcheck
(
gather_coo
,
(
src
,
index
,
None
))
is
True
assert
gradcheck
(
gather_csr
,
(
src
,
indptr
,
None
))
is
True
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_segment_out
(
test
,
dtype
,
device
):
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
index
=
tensor
(
test
[
'index'
],
torch
.
long
,
device
)
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
expected
=
tensor
(
test
[
'expected'
],
dtype
,
device
)
size
=
list
(
src
.
size
())
size
[
index
.
dim
()
-
1
]
=
index
.
size
(
-
1
)
out
=
src
.
new_full
(
size
,
-
2
)
gather_coo
(
src
,
index
,
out
)
assert
torch
.
all
(
out
==
expected
)
out
.
fill_
(
-
2
)
gather_csr
(
src
,
indptr
,
out
)
assert
torch
.
all
(
out
==
expected
)
# out = gather_coo(src, index)
# print('COO', out)
# print('Expected', out)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
src
.
grad
=
None
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_non_contiguous_segment
(
test
,
dtype
,
device
):
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
index
=
tensor
(
test
[
'index'
],
torch
.
long
,
device
)
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
expected
=
tensor
(
test
[
'expected'
],
dtype
,
device
)
if
src
.
dim
()
>
1
:
src
=
src
.
transpose
(
0
,
1
).
contiguous
().
transpose
(
0
,
1
)
if
index
.
dim
()
>
1
:
index
=
index
.
transpose
(
0
,
1
).
contiguous
().
transpose
(
0
,
1
)
if
indptr
.
dim
()
>
1
:
indptr
=
indptr
.
transpose
(
0
,
1
).
contiguous
().
transpose
(
0
,
1
)
out
=
gather_coo
(
src
,
index
)
out
=
gather_coo
(
src
,
index
)
out
.
backward
(
grad_out
)
assert
torch
.
all
(
out
==
expected
)
print
(
'COO'
)
print
(
out
)
out
=
gather_csr
(
src
,
indptr
)
print
(
src
.
gra
d
)
assert
torch
.
all
(
out
==
expecte
d
)
test/test_segment.py
View file @
0807f87f
...
@@ -2,11 +2,13 @@ from itertools import product
...
@@ -2,11 +2,13 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch.autograd
import
gradcheck
from
torch_scatter
import
segment_coo
,
segment_csr
from
torch_scatter
import
segment_coo
,
segment_csr
from
.utils
import
tensor
from
.utils
import
tensor
reductions
=
[
'add'
,
'mean'
,
'min'
,
'max'
]
reductions
=
[
'add'
,
'mean'
,
'min'
,
'max'
]
grad_reductions
=
[
'add'
,
'mean'
]
dtypes
=
[
torch
.
float
]
dtypes
=
[
torch
.
float
]
devices
=
[
torch
.
device
(
'cuda'
)]
devices
=
[
torch
.
device
(
'cuda'
)]
...
@@ -46,15 +48,15 @@ tests = [
...
@@ -46,15 +48,15 @@ tests = [
'arg_max'
:
[[
1
,
4
,
6
,
5
],
[
2
,
4
,
5
,
6
]],
'arg_max'
:
[[
1
,
4
,
6
,
5
],
[
2
,
4
,
5
,
6
]],
},
},
{
{
'src'
:
[[[
1
,
3
,
5
],
[
2
,
4
,
6
]],
[[
7
,
9
,
11
],
[
8
,
10
,
1
2
]]],
'src'
:
[[[
1
,
2
],
[
3
,
4
],
[
5
,
6
]],
[[
7
,
9
],
[
10
,
11
],
[
12
,
1
3
]]],
'index'
:
[[
[
0
,
0
,
1
],
[
0
,
2
,
2
]],
[[
0
,
0
,
1
],
[
0
,
2
,
2
]]],
'index'
:
[[
0
,
0
,
1
],
[
0
,
2
,
2
]],
'indptr'
:
[[
[
0
,
2
,
3
,
3
],
[
0
,
1
,
1
,
3
]],
[[
0
,
2
,
3
,
3
],
[
0
,
1
,
1
,
3
]]],
'indptr'
:
[[
0
,
2
,
3
,
3
],
[
0
,
1
,
1
,
3
]],
'add'
:
[[[
4
,
5
,
0
],
[
2
,
0
,
1
0
]],
[[
16
,
11
,
0
],
[
8
,
0
,
22
]]],
'add'
:
[[[
4
,
6
],
[
5
,
6
],
[
0
,
0
]],
[[
7
,
9
],
[
0
,
0
]
,
[
22
,
24
]]],
'mean'
:
[[[
2
,
5
,
0
],
[
2
,
0
,
5
]],
[[
8
,
11
,
0
],
[
8
,
0
,
11
]]],
'mean'
:
[[[
2
,
3
],
[
5
,
6
],
[
0
,
0
]],
[[
7
,
9
],
[
0
,
0
]
,
[
11
,
12
]]],
'min'
:
[[[
1
,
5
,
0
],
[
2
,
0
,
4
]],
[[
7
,
11
,
0
],
[
8
,
0
,
1
0
]]],
'min'
:
[[[
1
,
2
],
[
5
,
6
],
[
0
,
0
]],
[[
7
,
9
],
[
0
,
0
],
[
1
0
,
1
1
]]],
'arg_min'
:
[[[
0
,
2
,
3
],
[
0
,
3
,
1
]],
[[
0
,
2
,
3
],
[
0
,
3
,
1
]]],
'arg_min'
:
[[[
0
,
0
],
[
2
,
2
],
[
3
,
3
]],
[[
0
,
0
],
[
3
,
3
],
[
1
,
1
]]],
'max'
:
[[[
3
,
5
,
0
],
[
2
,
0
,
6
]],
[[
9
,
11
,
0
],
[
8
,
0
,
12
]]],
'max'
:
[[[
3
,
4
],
[
5
,
6
],
[
0
,
0
]],
[[
7
,
9
],
[
0
,
0
]
,
[
12
,
13
]]],
'arg_max'
:
[[[
1
,
2
,
3
],
[
0
,
3
,
2
]],
[[
1
,
2
,
3
],
[
0
,
3
,
2
]]],
'arg_max'
:
[[[
1
,
1
],
[
2
,
2
],
[
3
,
3
]],
[[
0
,
0
],
[
3
,
3
],
[
2
,
2
]]],
},
},
{
{
'src'
:
[[
1
,
3
],
[
2
,
4
]],
'src'
:
[[
1
,
3
],
[
2
,
4
]],
...
@@ -84,7 +86,7 @@ tests = [
...
@@ -84,7 +86,7 @@ tests = [
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'test,reduce,dtype,device'
,
@
pytest
.
mark
.
parametrize
(
'test,reduce,dtype,device'
,
product
(
tests
,
reductions
,
dtypes
,
devices
))
product
(
tests
,
reductions
,
dtypes
,
devices
))
def
test_
segment
(
test
,
reduce
,
dtype
,
device
):
def
test_
forward
(
test
,
reduce
,
dtype
,
device
):
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
index
=
tensor
(
test
[
'index'
],
torch
.
long
,
device
)
index
=
tensor
(
test
[
'index'
],
torch
.
long
,
device
)
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
...
@@ -105,6 +107,19 @@ def test_segment(test, reduce, dtype, device):
...
@@ -105,6 +107,19 @@ def test_segment(test, reduce, dtype, device):
assert
torch
.
all
(
out
==
expected
)
assert
torch
.
all
(
out
==
expected
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'test,reduce,device'
,
product
(
tests
,
grad_reductions
,
devices
))
def
test_backward
(
test
,
reduce
,
device
):
src
=
tensor
(
test
[
'src'
],
torch
.
double
,
device
)
src
.
requires_grad_
()
index
=
tensor
(
test
[
'index'
],
torch
.
long
,
device
)
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
assert
gradcheck
(
segment_coo
,
(
src
,
index
,
None
,
None
,
reduce
))
is
True
assert
gradcheck
(
segment_csr
,
(
src
,
indptr
,
None
,
reduce
))
is
True
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'test,reduce,dtype,device'
,
@
pytest
.
mark
.
parametrize
(
'test,reduce,dtype,device'
,
product
(
tests
,
reductions
,
dtypes
,
devices
))
product
(
tests
,
reductions
,
dtypes
,
devices
))
...
@@ -118,18 +133,12 @@ def test_segment_out(test, reduce, dtype, device):
...
@@ -118,18 +133,12 @@ def test_segment_out(test, reduce, dtype, device):
size
[
indptr
.
dim
()
-
1
]
=
indptr
.
size
(
-
1
)
-
1
size
[
indptr
.
dim
()
-
1
]
=
indptr
.
size
(
-
1
)
-
1
out
=
src
.
new_full
(
size
,
-
2
)
out
=
src
.
new_full
(
size
,
-
2
)
# Pre-defined `out` values shouldn't do anything.
segment_csr
(
src
,
indptr
,
out
,
reduce
=
reduce
)
out
=
segment_csr
(
src
,
indptr
,
out
,
reduce
=
reduce
)
if
isinstance
(
out
,
tuple
):
out
,
arg_out
=
out
arg_expected
=
tensor
(
test
[
f
'arg_
{
reduce
}
'
],
torch
.
long
,
device
)
assert
torch
.
all
(
arg_out
==
arg_expected
)
assert
torch
.
all
(
out
==
expected
)
assert
torch
.
all
(
out
==
expected
)
out
.
fill_
(
-
2
)
out
.
fill_
(
-
2
)
out
=
segment_coo
(
src
,
index
,
out
,
reduce
=
reduce
)
segment_coo
(
src
,
index
,
out
,
reduce
=
reduce
)
out
=
out
[
0
]
if
isinstance
(
out
,
tuple
)
else
out
if
reduce
==
'add'
:
if
reduce
==
'add'
:
expected
=
expected
-
2
expected
=
expected
-
2
...
...
torch_scatter/segment.py
View file @
0807f87f
...
@@ -64,6 +64,7 @@ class SegmentCOO(torch.autograd.Function):
...
@@ -64,6 +64,7 @@ class SegmentCOO(torch.autograd.Function):
index
.
dim
()
-
1
,
arg_out
,
grad_out
)
index
.
dim
()
-
1
,
arg_out
,
grad_out
)
grad_src
=
grad_src
.
narrow
(
index
.
dim
()
-
1
,
0
,
grad_src
=
grad_src
.
narrow
(
index
.
dim
()
-
1
,
0
,
src_size
[
index
.
dim
()
-
1
]
-
1
)
src_size
[
index
.
dim
()
-
1
]
-
1
)
return
grad_src
,
None
,
None
,
None
,
None
return
grad_src
,
None
,
None
,
None
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment