Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
3cf59da2
Commit
3cf59da2
authored
Jan 21, 2020
by
rusty1s
Browse files
add to sum, REDUCE to template
parent
7aa701b1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
88 additions
and
83 deletions
+88
-83
benchmark/scatter_segment.py
benchmark/scatter_segment.py
+7
-10
cpu/segment.cpp
cpu/segment.cpp
+32
-33
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+23
-14
test/test_segment.py
test/test_segment.py
+9
-9
torch_scatter/gather.py
torch_scatter/gather.py
+2
-2
torch_scatter/segment.py
torch_scatter/segment.py
+15
-15
No files found.
benchmark/scatter_segment.py
View file @
3cf59da2
...
...
@@ -122,11 +122,11 @@ def timing(dataset):
avg_row_len
=
row
.
size
(
0
)
/
dim_size
def
sca_row
(
x
):
op
=
getattr
(
torch_scatter
,
f
'scatter_
{
args
.
reduce
}
'
)
op
=
getattr
(
torch_scatter
,
f
'scatter_
{
args
.
scatter_
reduce
}
'
)
return
op
(
x
,
row
,
dim
=
0
,
dim_size
=
dim_size
)
def
sca_col
(
x
):
op
=
getattr
(
torch_scatter
,
f
'scatter_
{
args
.
reduce
}
'
)
op
=
getattr
(
torch_scatter
,
f
'scatter_
{
args
.
scatter_
reduce
}
'
)
return
op
(
x
,
row_perm
,
dim
=
0
,
dim_size
=
dim_size
)
def
seg_coo
(
x
):
...
...
@@ -136,10 +136,10 @@ def timing(dataset):
return
segment_csr
(
x
,
rowptr
,
reduce
=
args
.
reduce
)
def
dense1
(
x
):
return
getattr
(
torch
,
args
.
dense_
reduce
)(
x
,
dim
=-
2
)
return
getattr
(
torch
,
args
.
reduce
)(
x
,
dim
=-
2
)
def
dense2
(
x
):
return
getattr
(
torch
,
args
.
dense_
reduce
)(
x
,
dim
=-
1
)
return
getattr
(
torch
,
args
.
reduce
)(
x
,
dim
=-
1
)
t1
,
t2
,
t3
,
t4
,
t5
,
t6
=
[],
[],
[],
[],
[],
[]
...
...
@@ -204,15 +204,12 @@ def timing(dataset):
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--reduce'
,
type
=
str
,
required
=
True
,
choices
=
[
'add'
,
'mean'
,
'min'
,
'max'
])
parser
.
add_argument
(
'--reduce'
,
type
=
str
,
required
=
True
,
choices
=
[
'sum'
,
'mean'
,
'min'
,
'max'
])
parser
.
add_argument
(
'--with_backward'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cuda'
)
args
=
parser
.
parse_args
()
args
.
dense
_reduce
=
'
sum
'
if
args
.
reduce
==
'
add
'
else
args
.
reduce
args
.
scatter
_reduce
=
'
add
'
if
args
.
reduce
==
'
sum
'
else
args
.
reduce
iters
=
1
if
args
.
device
==
'cpu'
else
20
sizes
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
sizes
=
sizes
[:
3
]
if
args
.
device
==
'cpu'
else
sizes
...
...
cpu/segment.cpp
View file @
3cf59da2
...
...
@@ -7,28 +7,36 @@
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
enum
ReductionType
{
ADD
,
MEAN
,
MIN
,
MAX
};
enum
ReductionType
{
SUM
,
MEAN
,
MIN
,
MAX
};
const
std
::
map
<
std
::
string
,
ReductionType
>
reduce2REDUCE
=
{
{
"sum"
,
SUM
},
{
"add"
,
SUM
},
{
"mean"
,
MEAN
},
{
"min"
,
MIN
},
{
"max"
,
MAX
},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
ReductionType REDUCE = ADD;
\
if (reduce == "add") {
\
REDUCE = ADD;
\
switch (reduce2REDUCE.at(reduce)) {
\
case SUM: {
\
const ReductionType REDUCE = SUM;
\
return __VA_ARGS__(); \
} else if (reduce == "mean") { \
REDUCE = MEAN; \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} else if (reduce == "min") { \
REDUCE = MIN; \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} else if (reduce == "max") { \
REDUCE = MAX; \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template
<
typename
scalar_t
>
struct
Reducer
{
static
inline
scalar_t
init
(
ReductionType
REDUCE
)
{
template
<
typename
scalar_t
,
ReductionType
REDUCE
>
struct
Reducer
{
static
inline
scalar_t
init
()
{
if
(
REDUCE
==
MIN
)
{
return
std
::
numeric_limits
<
scalar_t
>::
max
();
}
else
if
(
REDUCE
==
MAX
)
{
...
...
@@ -38,18 +46,9 @@ template <typename scalar_t> struct Reducer {
}
}
static
inline
void
update
(
ReductionType
REDUCE
,
scalar_t
*
val
,
scalar_t
new_val
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
*
val
=
*
val
+
new_val
;
}
else
if
((
REDUCE
==
MIN
&&
new_val
<
*
val
)
||
(
REDUCE
==
MAX
&&
new_val
>
*
val
))
{
*
val
=
new_val
;
}
}
static
inline
void
update
(
ReductionType
REDUCE
,
scalar_t
*
val
,
scalar_t
new_val
,
int64_t
*
arg
,
static
inline
void
update
(
scalar_t
*
val
,
scalar_t
new_val
,
int64_t
*
arg
,
int64_t
new_arg
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
if
(
REDUCE
==
SUM
||
REDUCE
==
MEAN
)
{
*
val
=
*
val
+
new_val
;
}
else
if
((
REDUCE
==
MIN
&&
new_val
<
*
val
)
||
(
REDUCE
==
MAX
&&
new_val
>
*
val
))
{
...
...
@@ -58,9 +57,9 @@ template <typename scalar_t> struct Reducer {
}
}
static
inline
void
write
(
ReductionType
REDUCE
,
scalar_t
*
address
,
scalar_t
val
,
static
inline
void
write
(
scalar_t
*
address
,
scalar_t
val
,
int64_t
*
arg_address
,
int64_t
arg
,
int
count
)
{
if
(
REDUCE
==
ADD
)
{
if
(
REDUCE
==
SUM
)
{
*
address
=
val
;
}
else
if
(
REDUCE
==
MEAN
)
{
*
address
=
val
/
(
count
>
0
?
count
:
(
scalar_t
)
1
);
...
...
@@ -111,7 +110,7 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
int64_t
*
arg_out_data
=
nullptr
;
if
(
reduce
==
"min"
||
reduce
==
"max"
)
{
if
(
reduce
2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
arg_out
=
at
::
full_like
(
out
,
src
.
size
(
reduce_dim
),
indptr
.
options
());
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
}
...
...
@@ -137,16 +136,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
offset
=
(
n
/
(
indptr
.
size
(
-
1
)
-
1
))
*
E
*
K
;
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
vals
[
k
]
=
Reducer
<
scalar_t
>::
init
(
REDUCE
);
vals
[
k
]
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
}
for
(
int64_t
e
=
row_start
;
e
<
row_end
;
e
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
>::
update
(
REDUCE
,
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
vals
[
k
],
src_data
[
offset
+
e
*
K
+
k
],
&
args
[
k
],
e
);
}
}
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
>::
write
(
REDUCE
,
out_data
+
n
*
K
+
k
,
vals
[
k
],
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
n
*
K
+
k
,
vals
[
k
],
arg_out_data
+
n
*
K
+
k
,
args
[
k
],
row_end
-
row_start
);
}
...
...
@@ -183,7 +182,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
int64_t
*
arg_out_data
=
nullptr
;
if
(
reduce
==
"min"
||
reduce
==
"max"
)
{
if
(
reduce
2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
arg_out
=
at
::
full_like
(
out
,
src
.
size
(
reduce_dim
),
index
.
options
());
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
}
...
...
@@ -215,13 +214,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
for
(
int
e_2
=
0
;
e_2
<
E_2
;
e_2
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
>::
update
(
REDUCE
,
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
vals
[
k
],
src_data
[
e_1
*
E_2
*
K
+
e_2
*
K
+
k
],
&
args
[
k
],
e_2
);
}
if
(
e_2
==
E_2
-
1
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
>::
write
(
REDUCE
,
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
args
[
k
],
e_2
+
1
-
row_start
);
...
...
@@ -232,7 +231,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
if
(
idx
!=
next_idx
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
>::
write
(
REDUCE
,
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
args
[
k
],
e_2
+
1
-
row_start
);
...
...
cuda/segment_kernel.cu
View file @
3cf59da2
...
...
@@ -11,23 +11,32 @@
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
enum
ReductionType
{
ADD
,
MEAN
,
MIN
,
MAX
};
enum
ReductionType
{
SUM
,
MEAN
,
MIN
,
MAX
};
const
std
::
map
<
std
::
string
,
ReductionType
>
reduce2REDUCE
=
{
{
"sum"
,
SUM
},
{
"add"
,
SUM
},
{
"mean"
,
MEAN
},
{
"min"
,
MIN
},
{
"max"
,
MAX
},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
if (reduce == "add") { \
const ReductionType REDUCE = ADD; \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} else if (reduce == "mean") { \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} else if (reduce == "min") { \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} else if (reduce == "max") { \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template
<
typename
scalar_t
,
ReductionType
REDUCE
>
struct
Reducer
{
...
...
@@ -43,7 +52,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static
inline
__host__
__device__
void
update
(
scalar_t
*
val
,
scalar_t
new_val
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
if
(
REDUCE
==
SUM
||
REDUCE
==
MEAN
)
{
*
val
=
*
val
+
new_val
;
}
else
if
((
REDUCE
==
MIN
&&
new_val
<
*
val
)
||
(
REDUCE
==
MAX
&&
new_val
>
*
val
))
{
...
...
@@ -53,7 +62,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static
inline
__host__
__device__
void
update
(
scalar_t
*
val
,
scalar_t
new_val
,
int64_t
*
arg
,
int64_t
new_arg
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
if
(
REDUCE
==
SUM
||
REDUCE
==
MEAN
)
{
*
val
=
*
val
+
new_val
;
}
else
if
((
REDUCE
==
MIN
&&
new_val
<
*
val
)
||
(
REDUCE
==
MAX
&&
new_val
>
*
val
))
{
...
...
@@ -65,7 +74,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static
inline
__host__
__device__
void
write
(
scalar_t
*
address
,
scalar_t
val
,
int64_t
*
arg_address
,
int64_t
arg
,
int
count
)
{
if
(
REDUCE
==
ADD
)
{
if
(
REDUCE
==
SUM
)
{
*
address
=
val
;
}
else
if
(
REDUCE
==
MEAN
)
{
*
address
=
val
/
(
scalar_t
)
max
(
count
,
1
);
...
...
@@ -80,7 +89,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
static
inline
__device__
void
atomic_write
(
scalar_t
*
address
,
scalar_t
val
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
if
(
REDUCE
==
SUM
||
REDUCE
==
MEAN
)
{
atomAdd
(
address
,
val
);
}
else
if
(
REDUCE
==
MIN
&&
val
<
*
address
)
{
atomMin
(
address
,
val
);
...
...
@@ -204,7 +213,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
int64_t
*
arg_out_data
=
nullptr
;
if
(
reduce
==
"min"
||
reduce
==
"max"
)
{
if
(
reduce
2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
arg_out
=
at
::
full_like
(
out
,
src
.
size
(
reduce_dim
),
indptr
.
options
());
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
}
...
...
@@ -396,7 +405,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
int64_t
*
arg_out_data
=
nullptr
;
if
(
reduce
==
"min"
||
reduce
==
"max"
)
{
if
(
reduce
2REDUCE
.
at
(
reduce
)
==
MIN
||
reduce2REDUCE
.
at
(
reduce
)
==
MAX
)
{
arg_out
=
at
::
full_like
(
out
,
src
.
size
(
reduce_dim
),
index
.
options
());
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
}
...
...
@@ -455,14 +464,14 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
});
});
if
(
reduce
==
"mean"
)
{
if
(
reduce
2REDUCE
.
at
(
reduce
)
==
MEAN
)
{
auto
sizes
=
index
.
sizes
().
vec
();
sizes
[
reduce_dim
]
=
out
.
size
(
reduce_dim
);
auto
count
=
at
::
zeros
(
sizes
,
out
.
options
());
AT_DISPATCH_ALL_TYPES
(
out
.
scalar_type
(),
"count_kernel"
,
[
&
]
{
auto
count_data
=
count
.
DATA_PTR
<
scalar_t
>
();
segment_coo_kernel
<
scalar_t
,
ADD
,
false
>
segment_coo_kernel
<
scalar_t
,
SUM
,
false
>
<<<
BLOCKS
(
1
,
E
),
THREADS
,
0
,
stream
>>>
(
nullptr
,
index_info
,
count_data
,
E
,
N
);
});
...
...
test/test_segment.py
View file @
3cf59da2
...
...
@@ -7,15 +7,15 @@ from torch_scatter import segment_coo, segment_csr
from
.utils
import
tensor
,
dtypes
,
devices
reductions
=
[
'
add
'
,
'mean'
,
'min'
,
'max'
]
grad_reductions
=
[
'
add
'
,
'mean'
]
reductions
=
[
'
sum
'
,
'mean'
,
'min'
,
'max'
]
grad_reductions
=
[
'
sum
'
,
'mean'
]
tests
=
[
{
'src'
:
[
1
,
2
,
3
,
4
,
5
,
6
],
'index'
:
[
0
,
0
,
1
,
1
,
1
,
3
],
'indptr'
:
[
0
,
2
,
5
,
5
,
6
],
'
add
'
:
[
3
,
12
,
0
,
6
],
'
sum
'
:
[
3
,
12
,
0
,
6
],
'mean'
:
[
1.5
,
4
,
0
,
6
],
'min'
:
[
1
,
3
,
0
,
6
],
'arg_min'
:
[
0
,
2
,
6
,
5
],
...
...
@@ -26,7 +26,7 @@ tests = [
'src'
:
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
],
[
9
,
10
],
[
11
,
12
]],
'index'
:
[
0
,
0
,
1
,
1
,
1
,
3
],
'indptr'
:
[
0
,
2
,
5
,
5
,
6
],
'
add
'
:
[[
4
,
6
],
[
21
,
24
],
[
0
,
0
],
[
11
,
12
]],
'
sum
'
:
[[
4
,
6
],
[
21
,
24
],
[
0
,
0
],
[
11
,
12
]],
'mean'
:
[[
2
,
3
],
[
7
,
8
],
[
0
,
0
],
[
11
,
12
]],
'min'
:
[[
1
,
2
],
[
5
,
6
],
[
0
,
0
],
[
11
,
12
]],
'arg_min'
:
[[
0
,
0
],
[
2
,
2
],
[
6
,
6
],
[
5
,
5
]],
...
...
@@ -37,7 +37,7 @@ tests = [
'src'
:
[[
1
,
3
,
5
,
7
,
9
,
11
],
[
2
,
4
,
6
,
8
,
10
,
12
]],
'index'
:
[[
0
,
0
,
1
,
1
,
1
,
3
],
[
0
,
0
,
0
,
1
,
1
,
2
]],
'indptr'
:
[[
0
,
2
,
5
,
5
,
6
],
[
0
,
3
,
5
,
6
,
6
]],
'
add
'
:
[[
4
,
21
,
0
,
11
],
[
12
,
18
,
12
,
0
]],
'
sum
'
:
[[
4
,
21
,
0
,
11
],
[
12
,
18
,
12
,
0
]],
'mean'
:
[[
2
,
7
,
0
,
11
],
[
4
,
9
,
12
,
0
]],
'min'
:
[[
1
,
5
,
0
,
11
],
[
2
,
8
,
12
,
0
]],
'arg_min'
:
[[
0
,
2
,
6
,
5
],
[
0
,
3
,
5
,
6
]],
...
...
@@ -48,7 +48,7 @@ tests = [
'src'
:
[[[
1
,
2
],
[
3
,
4
],
[
5
,
6
]],
[[
7
,
9
],
[
10
,
11
],
[
12
,
13
]]],
'index'
:
[[
0
,
0
,
1
],
[
0
,
2
,
2
]],
'indptr'
:
[[
0
,
2
,
3
,
3
],
[
0
,
1
,
1
,
3
]],
'
add
'
:
[[[
4
,
6
],
[
5
,
6
],
[
0
,
0
]],
[[
7
,
9
],
[
0
,
0
],
[
22
,
24
]]],
'
sum
'
:
[[[
4
,
6
],
[
5
,
6
],
[
0
,
0
]],
[[
7
,
9
],
[
0
,
0
],
[
22
,
24
]]],
'mean'
:
[[[
2
,
3
],
[
5
,
6
],
[
0
,
0
]],
[[
7
,
9
],
[
0
,
0
],
[
11
,
12
]]],
'min'
:
[[[
1
,
2
],
[
5
,
6
],
[
0
,
0
]],
[[
7
,
9
],
[
0
,
0
],
[
10
,
11
]]],
'arg_min'
:
[[[
0
,
0
],
[
2
,
2
],
[
3
,
3
]],
[[
0
,
0
],
[
3
,
3
],
[
1
,
1
]]],
...
...
@@ -59,7 +59,7 @@ tests = [
'src'
:
[[
1
,
3
],
[
2
,
4
]],
'index'
:
[[
0
,
0
],
[
0
,
0
]],
'indptr'
:
[[
0
,
2
],
[
0
,
2
]],
'
add
'
:
[[
4
],
[
6
]],
'
sum
'
:
[[
4
],
[
6
]],
'mean'
:
[[
2
],
[
3
]],
'min'
:
[[
1
],
[
2
]],
'arg_min'
:
[[
0
],
[
0
]],
...
...
@@ -70,7 +70,7 @@ tests = [
'src'
:
[[[
1
,
1
],
[
3
,
3
]],
[[
2
,
2
],
[
4
,
4
]]],
'index'
:
[[
0
,
0
],
[
0
,
0
]],
'indptr'
:
[[
0
,
2
],
[
0
,
2
]],
'
add
'
:
[[[
4
,
4
]],
[[
6
,
6
]]],
'
sum
'
:
[[[
4
,
4
]],
[[
6
,
6
]]],
'mean'
:
[[[
2
,
2
]],
[[
3
,
3
]]],
'min'
:
[[[
1
,
1
]],
[[
2
,
2
]]],
'arg_min'
:
[[[
0
,
0
]],
[[
0
,
0
]]],
...
...
@@ -134,7 +134,7 @@ def test_segment_out(test, reduce, dtype, device):
segment_coo
(
src
,
index
,
out
,
reduce
=
reduce
)
if
reduce
==
'
add
'
:
if
reduce
==
'
sum
'
:
expected
=
expected
-
2
elif
reduce
==
'mean'
:
expected
=
out
# We can not really test this here.
...
...
torch_scatter/gather.py
View file @
3cf59da2
...
...
@@ -31,7 +31,7 @@ class GatherCOO(torch.autograd.Function):
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
grad_src
,
_
=
seg
(
grad_out
.
is_cuda
).
segment_coo
(
grad_out
,
index
,
grad_out
.
new_zeros
(
src_size
),
'
add
'
)
grad_out
,
index
,
grad_out
.
new_zeros
(
src_size
),
'
sum
'
)
return
grad_src
,
None
,
None
...
...
@@ -53,7 +53,7 @@ class GatherCSR(torch.autograd.Function):
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
grad_src
,
_
=
seg
(
grad_out
.
is_cuda
).
segment_csr
(
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
),
'
add
'
)
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
),
'
sum
'
)
return
grad_src
,
None
,
None
...
...
torch_scatter/segment.py
View file @
3cf59da2
...
...
@@ -18,7 +18,7 @@ def gat(is_cuda):
class
SegmentCOO
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
src
,
index
,
out
,
dim_size
,
reduce
):
assert
reduce
in
[
'add'
,
'mean'
,
'min'
,
'max'
]
assert
reduce
in
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
if
out
is
not
None
:
ctx
.
mark_dirty
(
out
)
ctx
.
reduce
=
reduce
...
...
@@ -55,7 +55,7 @@ class SegmentCOO(torch.autograd.Function):
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
reduce
==
'add'
:
if
ctx
.
reduce
==
'sum'
or
ctx
.
reduce
==
'add'
:
grad_src
=
gat
(
grad_out
.
is_cuda
).
gather_coo
(
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
elif
ctx
.
reduce
==
'mean'
:
...
...
@@ -68,7 +68,7 @@ class SegmentCOO(torch.autograd.Function):
size
[
-
1
]
=
grad_out
.
size
(
index
.
dim
()
-
1
)
count
=
segment_cpu
.
segment_coo
(
torch
.
ones_like
(
index
,
dtype
=
grad_out
.
dtype
),
index
,
grad_out
.
new_zeros
(
size
),
'
add
'
)[
0
].
clamp_
(
min
=
1
)
grad_out
.
new_zeros
(
size
),
'
sum
'
)[
0
].
clamp_
(
min
=
1
)
count
=
gat
(
grad_out
.
is_cuda
).
gather_coo
(
count
,
index
,
count
.
new_empty
(
src_size
[:
index
.
dim
()]))
...
...
@@ -88,7 +88,7 @@ class SegmentCOO(torch.autograd.Function):
class
SegmentCSR
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
src
,
indptr
,
out
,
reduce
):
assert
reduce
in
[
'add'
,
'mean'
,
'min'
,
'max'
]
assert
reduce
in
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
if
out
is
not
None
:
ctx
.
mark_dirty
(
out
)
...
...
@@ -105,7 +105,7 @@ class SegmentCSR(torch.autograd.Function):
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
reduce
==
'add'
:
if
ctx
.
reduce
==
'sum'
or
ctx
.
reduce
==
'add'
:
grad_src
=
gat
(
grad_out
.
is_cuda
).
gather_csr
(
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
))
elif
ctx
.
reduce
==
'mean'
:
...
...
@@ -129,7 +129,7 @@ class SegmentCSR(torch.autograd.Function):
return
grad_src
,
None
,
None
,
None
def
segment_coo
(
src
,
index
,
out
=
None
,
dim_size
=
None
,
reduce
=
"
add
"
):
def
segment_coo
(
src
,
index
,
out
=
None
,
dim_size
=
None
,
reduce
=
"
sum
"
):
r
"""
|
...
...
@@ -158,7 +158,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
:math:`y - 1` in ascending order.
The :attr:`index` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="
add
"`, the operation
For one-dimensional tensors with :obj:`reduce="
sum
"`, the operation
computes
.. math::
...
...
@@ -196,9 +196,9 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned.
(default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"
add
"`,
reduce (string, optional): The reduce operation (:obj:`"
sum
"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"
add
"`)
(default: :obj:`"
sum
"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
...
...
@@ -210,7 +210,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1) # Broadcasting in the first and last dim.
out = segment_coo(src, index, reduce="
add
")
out = segment_coo(src, index, reduce="
sum
")
print(out.size())
...
...
@@ -221,7 +221,7 @@ def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
return
SegmentCOO
.
apply
(
src
,
index
,
out
,
dim_size
,
reduce
)
def
segment_csr
(
src
,
indptr
,
out
=
None
,
reduce
=
"
add
"
):
def
segment_csr
(
src
,
indptr
,
out
=
None
,
reduce
=
"
sum
"
):
r
"""
Reduces all values from the :attr:`src` tensor into :attr:`out` within the
ranges specified in the :attr:`indptr` tensor along the last dimension of
...
...
@@ -242,7 +242,7 @@ def segment_csr(src, indptr, out=None, reduce="add"):
:math:`x_m` in ascending order.
The :attr:`indptr` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="
add
"`, the operation
For one-dimensional tensors with :obj:`reduce="
sum
"`, the operation
computes
.. math::
...
...
@@ -267,9 +267,9 @@ def segment_csr(src, indptr, out=None, reduce="add"):
The number of dimensions of :attr:`index` needs to be less than or
equal to :attr:`src`.
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"
add
"`,
reduce (string, optional): The reduce operation (:obj:`"
sum
"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"
add
"`)
(default: :obj:`"
sum
"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
...
...
@@ -281,7 +281,7 @@ def segment_csr(src, indptr, out=None, reduce="add"):
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1) # Broadcasting in the first and last dim.
out = segment_csr(src, indptr, reduce="
add
")
out = segment_csr(src, indptr, reduce="
sum
")
print(out.size())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment