OpenDAS / torch-scatter · Commit 2743b291

basic segment tests

authored Jan 10, 2020 by rusty1s
parent 34045b9a

Showing 4 changed files with 213 additions and 103 deletions:
  benchmark/scatter_segment.py    +1   -1
  cuda/segment_kernel.cu         +45  -55
  test/test_segment.py          +163  -43
  torch_scatter/segment.py        +4   -4
benchmark/scatter_segment.py
@@ -79,7 +79,7 @@ def correctness(dataset):
         out2, _ = segment_coo(x, row, reduce='max')
         out3, _ = segment_csr(x, rowptr, reduce='max')
-        assert torch.allclose(out1, out2, atol=1e-4)
+        # assert torch.allclose(out1, out2, atol=1e-4)
         assert torch.allclose(out1, out3, atol=1e-4)
     except RuntimeError:
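For reference, the equivalence this correctness check exercises can be reproduced with stock PyTorch ops alone. The snippet below is a minimal sketch (the shapes and the pure-torch reference computations are illustrative assumptions, not part of the benchmark): a CSR pointer rowptr and its expanded COO vector row describe the same segments, so a 'max' reduction over either layout must agree.

    import torch

    x = torch.randn(6, 4)                    # 6 rows, 4 features
    rowptr = torch.tensor([0, 2, 5, 6])      # CSR: segments of size 2, 3, 1
    row = torch.repeat_interleave(
        torch.arange(rowptr.numel() - 1),
        rowptr[1:] - rowptr[:-1])            # COO: [0, 0, 1, 1, 1, 2]

    # CSR-style segmented max: slice between consecutive pointer pairs.
    out_csr = torch.stack(
        [x[s:e].max(dim=0).values for s, e in zip(rowptr[:-1], rowptr[1:])])

    # COO-style segmented max: scatter each row into its target segment.
    out_coo = torch.full((3, 4), float('-inf')).scatter_reduce(
        0, row.view(-1, 1).expand_as(x), x, reduce='amax')

    assert torch.allclose(out_csr, out_coo, atol=1e-4)

segment_coo(x, row, reduce='max') and segment_csr(x, rowptr, reduce='max') are this library's fused equivalents of the two reference computations above.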
cuda/segment_kernel.cu
@@ -12,6 +12,7 @@
 #define FULL_MASK 0xffffffff

+enum ReductionType { ADD, MEAN, MIN, MAX };

 #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
   [&] { \
     if (reduce == "add") { \
@@ -204,22 +205,6 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }

-  if (reduce == "any") {
-    auto index = indptr.narrow(reduce_dim, 0, indptr.size(reduce_dim) - 1);
-    auto index2 = indptr.narrow(reduce_dim, 1, indptr.size(reduce_dim) - 1);
-    auto mask = (index2 - index) == 0;
-    for (int i = reduce_dim + 1; i < src.dim(); i++) {
-      index = index.unsqueeze(-1);
-      mask = mask.unsqueeze(-1);
-    }
-    at::gather_out(out, src, reduce_dim, index.expand(out.sizes()));
-    out.masked_fill_(mask.expand(out.sizes()), 0);
-    return std::make_tuple(out, arg_out);
-  }
-
   auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
   auto E = src.size(reduce_dim);
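The flattened sizes computed here drive all of the kernel's index math. A worked example under an assumed shape (src of shape (2, 6, 4) reduced over dim 1 with indptr of shape (2, 5), hence out of shape (2, 4, 4)):

    import math

    src_shape, out_shape = (2, 6, 4), (2, 4, 4)
    indptr_shape, reduce_dim = (2, 5), 1

    # N: total number of output segments across all leading index rows.
    N = out_shape[reduce_dim] * (math.prod(indptr_shape) // indptr_shape[-1])
    assert N == 4 * 2 == 8
    # K: trailing elements broadcast per segment entry.
    K = math.prod(out_shape) // N
    assert K == 4
    # E: extent of the reduced dimension in src.
    E = src_shape[reduce_dim]
    assert E == 6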
@@ -258,12 +243,13 @@ segment_coo_kernel(const scalar_t *src_data,
   int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
   int lane_idx = row_idx & (32 - 1);
+  int D = index_info.sizes[index_info.dims - 1];

   if (row_idx < E) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
     int idx = index_info.data[offset], next_idx;
-    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;
+    int out_idx = (row_idx / D) * N + idx;
     scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
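The new D shorthand (the last dimension of index) makes the output mapping easier to state: flattened element e of src belongs to index row e / D and lands in output slot (e / D) * N + index[e], so segments never cross index rows. A sequential Python sketch of that contract for an 'add' reduction, ignoring the trailing broadcast dimension K that the real kernels also handle (the helper name is illustrative; the data is the 2-D fixture from the new tests):

    import torch

    def segment_coo_add_ref(src, index, N):
        # index: (B, D); reduces src (B, D) into out (B, N).
        B, D = index.shape
        out = torch.zeros(B, N, dtype=src.dtype)
        for e in range(B * D):
            row, idx = e // D, int(index.view(-1)[e])
            out.view(-1)[row * N + idx] += src.view(-1)[e]
        return out

    src = torch.tensor([[1., 3., 5., 7., 9., 11.], [2., 4., 6., 8., 10., 12.]])
    index = torch.tensor([[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]])
    print(segment_coo_add_ref(src, index, N=4))
    # tensor([[ 4., 21.,  0., 11.],
    #         [12., 18., 12.,  0.]])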
@@ -272,16 +258,18 @@ segment_coo_kernel(const scalar_t *src_data,
       // Parallel reduction inside a single warp.
       tmp = __shfl_up_sync(FULL_MASK, val, i);
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
-      if (lane_idx >= i && idx == next_idx)
-        Reducer<scalar_t, REDUCE>::update(&val, tmp);
+      if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
+        assert(idx >= next_idx);
+        if (idx == next_idx)
+          Reducer<scalar_t, REDUCE>::update(&val, tmp);
+      }
     }

     next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
-    if (lane_idx == 32 - 1 || idx != next_idx) {
+    if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
+        idx != next_idx)
       Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
-    }
   }
 }

 template <typename scalar_t>
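The guarded shuffle loop is a segmented inclusive scan within each warp: lane l pulls the partial value from lane l - i only when both lanes carry the same segment index, and, with this commit, only when both flat positions fall in the same index row; the assert documents that indices must be sorted within a row. A sketch of the scan with a hypothetical 4-lane warp and an 'add' reduction (the row guard is elided since all positions here share one index row):

    def warp_segment_scan_add(vals, idxs, lanes=4):
        # Emulates __shfl_up_sync-based combining: each round, every lane
        # reads a snapshot (like a shuffle) and merges the value i lanes
        # back if it belongs to the same segment and the same warp.
        vals, i = list(vals), 1
        while i < lanes:
            snap = vals[:]
            for lane in range(len(vals)):
                if lane % lanes >= i and idxs[lane - i] == idxs[lane]:
                    vals[lane] += snap[lane - i]
            i *= 2
        return vals

    print(warp_segment_scan_add([1, 2, 3, 4, 5, 6, 7, 8],
                                [0, 0, 1, 1, 1, 3, 3, 3]))
    # [1, 3, 3, 7, 5, 6, 13, 21]

The last lane of each run inside a warp (here holding 3, 7, 5 and 21) then issues one atomic write; the two partial sums for index 1 (7 and 5), split across the warp boundary, are combined by those atomics.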
@@ -291,16 +279,17 @@ __global__ void segment_coo_arg_kernel(
     scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {

   int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int D = index_info.sizes[index_info.dims - 1];

   if (row_idx < E) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
     int idx = index_info.data[offset];
-    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;
+    int out_idx = (row_idx / D) * N + idx;

     scalar_t val = __ldg(out_data + out_idx);
     if (src_data[row_idx] == val)
-      arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
+      arg_out_data[out_idx] = row_idx % D;
   }
 }
@@ -314,38 +303,44 @@ __global__ void segment_coo_broadcast_kernel(
   // read and write is performed in column-major order. The intermediate
   // results are written via atomics.

-  int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
+  int D = index_info.sizes[index_info.dims - 1];
+  int E_1 = E / D;
+  int E_2 = D + D % TB;
+  int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
   int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

-  if (row_start < E && col_idx < K) {
+  int dim_start = (row_idx * TB) / E_2;
+  int row_start = (row_idx * TB) % E_2;
+
+  if (dim_start < E_1 && col_idx < K) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
-        row_start, index_info);
-    int out_idx = (row_start / index_info.sizes[index_info.dims - 1]) * N;
-    int idx1 = __ldg(index_info.data + offset);
-    scalar_t val = src_data[K * row_start + col_idx];
+        dim_start * D + row_start, index_info);
+    int idx1 = __ldg(index_info.data + offset), idx2;
+    scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx];

 #pragma unroll
     for (int i = 1; i < TB; i++) {
-      if (row_start + i >= E)
+      if (row_start + i >= D)
         break;

-      int idx2 = __ldg(index_info.data + offset +
-                       i * index_info.strides[index_info.dims - 1]);
+      idx2 = __ldg(index_info.data + offset +
+                   i * index_info.strides[index_info.dims - 1]);
+      assert(idx1 <= idx2);
       if (idx1 == idx2) {
         Reducer<scalar_t, REDUCE>::update(
-            &val, src_data[K * (row_start + i) + col_idx]);
+            &val, src_data[K * (dim_start * D + row_start + i) + col_idx]);
       } else {
         Reducer<scalar_t, REDUCE>::atomic_write(
-            out_data + (out_idx + idx1) * K + col_idx, val);
-        val = src_data[K * (row_start + i) + col_idx];
+            out_data + (dim_start * N + idx1) * K + col_idx, val);
+        val = src_data[K * (dim_start * D + row_start + i) + col_idx];
       }
       idx1 = idx2;
     }

     Reducer<scalar_t, REDUCE>::atomic_write(
-        out_data + (out_idx + idx1) * K + col_idx, val);
+        out_data + (dim_start * N + idx1) * K + col_idx, val);
   }
 }
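In the broadcast kernel each thread now owns TB consecutive entries of a single index row (of length D) rather than TB entries of the flat range 0..E, merging runs of equal indices locally and touching global memory only at run boundaries. A sketch of one row under an assumed TB = 4 and an 'add' reduction, with plain += standing in for the kernel's atomics (the helper and variable names are illustrative):

    def thread_tile_add(src_row, index_row, out_row, row_start, TB):
        D = len(index_row)
        if row_start >= D:
            return
        idx1, val = index_row[row_start], src_row[row_start]
        for i in range(1, TB):
            if row_start + i >= D:
                break
            idx2 = index_row[row_start + i]
            assert idx1 <= idx2               # indices sorted within the row
            if idx1 == idx2:
                val += src_row[row_start + i]         # merge locally
            else:
                out_row[idx1] += val                  # atomic in the kernel
                val = src_row[row_start + i]
            idx1 = idx2
        out_row[idx1] += val                          # flush the last run

    src, index, out = [1, 3, 5, 7, 9, 11], [0, 0, 1, 1, 1, 3], [0, 0, 0, 0]
    for start in range(0, len(index), 4):             # one "thread" per tile
        thread_tile_add(src, index, out, start, 4)
    print(out)  # [4, 21, 0, 11]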
@@ -358,18 +353,17 @@ __global__ void segment_coo_arg_broadcast_kernel(
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   int row_idx = thread_idx / K;
   int col_idx = thread_idx % K;
+  int D = index_info.sizes[index_info.dims - 1];

   if (row_idx < E && col_idx < K) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
     int idx = __ldg(index_info.data + offset);
-    int out_idx = ((row_idx / index_info.sizes[index_info.dims - 1]) * N +
-                   idx) * K + col_idx;
+    int out_idx = ((row_idx / D) * N + idx) * K + col_idx;

     scalar_t val = __ldg(out_data + out_idx);
     if (src_data[thread_idx] == val)
-      arg_out_data[out_idx] =
-          row_idx % index_info.sizes[index_info.dims - 1];
+      arg_out_data[out_idx] = row_idx % D;
   }
 }
@@ -395,15 +389,9 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }

-  if (reduce == "any") {
-    for (int i = reduce_dim + 1; i < src.dim(); i++) {
-      index = index.unsqueeze(-1);
-    }
-    out.scatter_(reduce_dim, index.expand(src.sizes()), src);
-    return std::make_tuple(out, arg_out);
-  }
-
   auto E = index.numel();
+  auto E_2 = index.size(reduce_dim);
+  auto E_1 = index.numel() / E_2;
   auto K = src.numel() / E;
   auto N = out.size(reduce_dim);
   auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
@@ -421,20 +409,22 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
             out_data, E, N);
   } else if (avg_len <= 8) {
     segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
-        <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8),
-           0, stream>>>(src_data, index_info, out_data, E, K, N);
+        <<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32), dim3(32, 8),
+           0, stream>>>(src_data, index_info, out_data, E, K, N);
   } else if (avg_len <= 16) {
     segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
-        <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8),
-           0, stream>>>(src_data, index_info, out_data, E, K, N);
+        <<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32), dim3(32, 8),
+           0, stream>>>(src_data, index_info, out_data, E, K, N);
   } else if (avg_len <= 32) {
     segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
-        <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
-           0, stream>>>(src_data, index_info, out_data, E, K, N);
+        <<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
+           dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K, N);
   } else {
     segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
-        <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
-           0, stream>>>(src_data, index_info, out_data, E, K, N);
+        <<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
+           dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K, N);
   }
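The launch bounds change accordingly: instead of tiling the flat range 0..E, where a tile could straddle two index rows, blocks now cover E_1 rows of E_2 entries each, tiled independently. A worked example for hypothetical sizes E_1 = 2, E_2 = 6 with TB = 4:

    E_1, E_2, TB = 2, 6, 4                      # 2 index rows of 6 entries

    tiles_per_row = (E_2 + TB - 1) // TB        # ceil(6 / 4) = 2 tiles/row
    blocks_x = (E_1 * tiles_per_row + 7) // 8   # 8 tiles per block (blockDim.y)
    assert (tiles_per_row, blocks_x) == (2, 1)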
test/test_segment.py

@@ -3,54 +3,174 @@ from itertools import product
import pytest
import torch
from torch_scatter import segment_coo, segment_csr
from torch_scatter import scatter_max

from .utils import tensor

reductions = ['add', 'mean', 'min', 'max']
dtypes = [torch.float]
devices = [torch.device('cuda')]
tests = [
    {
        'src': [1, 2, 3, 4, 5, 6],
        'index': [0, 0, 1, 1, 1, 3],
        'indptr': [0, 2, 5, 5, 6],
        'add': [3, 12, 0, 6],
        'mean': [1.5, 4, 0, 6],
        'min': [1, 3, 0, 6],
        'arg_min': [0, 2, 6, 5],
        'max': [2, 5, 0, 6],
        'arg_max': [1, 4, 6, 5],
    },
    {
        'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
        'index': [0, 0, 1, 1, 1, 3],
        'indptr': [0, 2, 5, 5, 6],
        'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
        'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
        'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
        'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]],
        'max': [[3, 4], [9, 10], [0, 0], [11, 12]],
        'arg_max': [[1, 1], [4, 4], [6, 6], [5, 5]],
    },
    {
        'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
        'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
        'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
        'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
        'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
        'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
        'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]],
        'max': [[3, 9, 0, 11], [6, 10, 12, 0]],
        'arg_max': [[1, 4, 6, 5], [2, 4, 5, 6]],
    },
    {
        'src': [[[1, 3, 5], [2, 4, 6]], [[7, 9, 11], [8, 10, 12]]],
        'index': [[[0, 0, 1], [0, 2, 2]], [[0, 0, 1], [0, 2, 2]]],
        'indptr': [[[0, 2, 3, 3], [0, 1, 1, 3]], [[0, 2, 3, 3], [0, 1, 1, 3]]],
        'add': [[[4, 5, 0], [2, 0, 10]], [[16, 11, 0], [8, 0, 22]]],
        'mean': [[[2, 5, 0], [2, 0, 5]], [[8, 11, 0], [8, 0, 11]]],
        'min': [[[1, 5, 0], [2, 0, 4]], [[7, 11, 0], [8, 0, 10]]],
        'arg_min': [[[0, 2, 3], [0, 3, 1]], [[0, 2, 3], [0, 3, 1]]],
        'max': [[[3, 5, 0], [2, 0, 6]], [[9, 11, 0], [8, 0, 12]]],
        'arg_max': [[[1, 2, 3], [0, 3, 2]], [[1, 2, 3], [0, 3, 2]]],
    },
    {
        'src': [[1, 3], [2, 4]],
        'index': [[0, 0], [0, 0]],
        'indptr': [[0, 2], [0, 2]],
        'add': [[4], [6]],
        'mean': [[2], [3]],
        'min': [[1], [2]],
        'arg_min': [[0], [0]],
        'max': [[3], [4]],
        'arg_max': [[1], [1]],
    },
    {
        'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
        'index': [[0, 0], [0, 0]],
        'indptr': [[0, 2], [0, 2]],
        'add': [[[4, 4]], [[6, 6]]],
        'mean': [[[2, 2]], [[3, 3]]],
        'min': [[[1, 1]], [[2, 2]]],
        'arg_min': [[[0, 0]], [[0, 0]]],
        'max': [[[3, 3]], [[4, 4]]],
        'arg_max': [[[1, 1]], [[1, 1]]],
    },
]
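The expected values encode two conventions worth noting: empty segments (consecutive equal indptr entries) reduce to 0, and their arg entries point one past the last src position along the reduced dimension. A quick check against the first fixture (its src values are distinct, so list.index suffices):

    src = [1, 2, 3, 4, 5, 6]
    indptr = [0, 2, 5, 5, 6]                  # segment 2 is empty (5 == 5)
    segments = [src[s:e] for s, e in zip(indptr[:-1], indptr[1:])]
    assert segments == [[1, 2], [3, 4, 5], [], [6]]
    # matches tests[0]['max'] and tests[0]['arg_max'] above:
    assert [max(s) if s else 0 for s in segments] == [2, 5, 0, 6]
    assert [src.index(max(s)) if s else len(src)
            for s in segments] == [1, 4, 6, 5]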
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_segment(test, reduce, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    out = segment_coo(src, index, reduce=reduce)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)

    out = segment_csr(src, indptr, reduce=reduce)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
                 device)
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    # src = tensor([-1, -2, -3, -4, -5, -6], dtype, device)
    src.requires_grad_()
    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

    out, arg = scatter_max(src, index, dim=0)
    print('SCA')
    print(out)
    print(arg)
    # print('SCA', out)
    # grad_out = torch.randn_like(out)
    # print(grad_out)
    # out.backward(grad_out)
    # print(src.grad)
    # src.grad = None

    out, arg = segment_coo(src, index, reduce='max')
    print('COO')
    print(out)
    print(arg)

    out, arg = segment_csr(src, indptr, reduce='max')
    print('CSR')
    print(out)
    print(arg)

    # out.backward(grad_out)
    # print(src.grad)
    # out = out[0] if isinstance(out, tuple) else out
    # out.backward(torch.randn_like(out))
    # out = segment_coo(src, index, reduce='max')[0]
    # print('COO', out)
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_segment_out(test, reduce, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    size = list(src.size())
    size[indptr.dim() - 1] = indptr.size(-1) - 1
    out = src.new_full(size, -2)

    # Pre-defined `out` values shouldn't do anything.
    out = segment_csr(src, indptr, out, reduce=reduce)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)

    out.fill_(-2)
    out = segment_coo(src, index, out, reduce=reduce)
    out = out[0] if isinstance(out, tuple) else out

    if reduce == 'add':
        expected = expected - 2
    elif reduce == 'mean':
        expected = out  # We can not really test this here.
    elif reduce == 'min':
        expected = expected.fill_(-2)
    elif reduce == 'max':
        expected[expected == 0] = -2
    else:
        raise ValueError

    assert torch.all(out == expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_non_contiguous_segment(test, reduce, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    if src.dim() > 1:
        src = src.transpose(0, 1).contiguous().transpose(0, 1)
    if index.dim() > 1:
        index = index.transpose(0, 1).contiguous().transpose(0, 1)
    if indptr.dim() > 1:
        indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)

    out = segment_coo(src, index, reduce=reduce)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)

    out = segment_csr(src, indptr, reduce=reduce)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)
torch_scatter/segment.py
@@ -9,7 +9,7 @@ if torch.cuda.is_available():
 class SegmentCOO(torch.autograd.Function):
     @staticmethod
     def forward(ctx, src, index, out, dim_size, reduce):
-        assert reduce in ['any', 'add', 'mean', 'min', 'max']
+        assert reduce in ['add', 'mean', 'min', 'max']
         if out is not None:
             ctx.mark_dirty(out)
         ctx.reduce = reduce
@@ -46,7 +46,7 @@ class SegmentCOO(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
-            if ctx.reduce == 'any' or ctx.reduce == 'add':
+            if ctx.reduce == 'add':
                 grad_src = gather_cuda.gather_coo(
                     grad_out, index, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
@@ -70,7 +70,7 @@ class SegmentCOO(torch.autograd.Function):
 class SegmentCSR(torch.autograd.Function):
     @staticmethod
     def forward(ctx, src, indptr, out, reduce):
-        assert reduce in ['any', 'add', 'mean', 'min', 'max']
+        assert reduce in ['add', 'mean', 'min', 'max']
         if out is not None:
             ctx.mark_dirty(out)
@@ -87,7 +87,7 @@ class SegmentCSR(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
-            if ctx.reduce == 'any' or ctx.reduce == 'add':
+            if ctx.reduce == 'add':
                 grad_src = gather_cuda.gather_csr(
                     grad_out, indptr, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
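These backward branches (from which this commit drops the now-unsupported 'any' case) rely on the fact that the gradient of a segmented sum is a plain gather: d(sum of segment j)/d(src[e]) = 1 whenever index[e] == j, so the source gradient simply copies each segment's output gradient back to its members. A sketch with stock torch ops standing in for the gather_coo/gather_csr CUDA kernels the real code calls:

    import torch

    index = torch.tensor([0, 0, 1, 1, 1, 3])
    grad_out = torch.tensor([1., 2., 3., 4.])
    grad_src = grad_out[index]      # == grad_out.gather(0, index)
    print(grad_src)                 # tensor([1., 1., 2., 2., 2., 4.])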