Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
cd84568b
Commit
cd84568b
authored
Jan 15, 2020
by
Koch
Browse files
fix: fix errors regarding Reducer functionalities in segment.cpp
parent
1eabf7f1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
15 deletions
+16
-15
cpu/segment.cpp
cpu/segment.cpp
+16
-15
No files found.
cpu/segment.cpp
View file @
cd84568b
...
@@ -11,23 +11,24 @@ enum ReductionType { ADD, MEAN, MIN, MAX };
...
@@ -11,23 +11,24 @@ enum ReductionType { ADD, MEAN, MIN, MAX };
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
[&] { \
ReductionType REDUCE = ADD; \
if (reduce == "add") { \
if (reduce == "add") { \
const ReductionType REDUCE = ADD;
\
REDUCE = ADD;
\
return __VA_ARGS__(); \
return __VA_ARGS__(); \
} else if (reduce == "mean") { \
} else if (reduce == "mean") { \
const ReductionType REDUCE = MEAN;
\
REDUCE = MEAN;
\
return __VA_ARGS__(); \
return __VA_ARGS__(); \
} else if (reduce == "min") { \
} else if (reduce == "min") { \
const ReductionType REDUCE = MIN;
\
REDUCE = MIN;
\
return __VA_ARGS__(); \
return __VA_ARGS__(); \
} else if (reduce == "max") { \
} else if (reduce == "max") { \
const ReductionType REDUCE = MAX;
\
REDUCE = MAX;
\
return __VA_ARGS__(); \
return __VA_ARGS__(); \
} \
} \
}()
}()
template
<
typename
scalar_t
,
ReductionType
REDUCE
>
struct
Reducer
{
template
<
typename
scalar_t
>
struct
Reducer
{
static
inline
scalar_t
init
()
{
static
inline
scalar_t
init
(
ReductionType
REDUCE
)
{
if
(
REDUCE
==
MIN
)
{
if
(
REDUCE
==
MIN
)
{
return
std
::
numeric_limits
<
scalar_t
>::
max
();
return
std
::
numeric_limits
<
scalar_t
>::
max
();
}
else
if
(
REDUCE
==
MAX
)
{
}
else
if
(
REDUCE
==
MAX
)
{
...
@@ -37,7 +38,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
...
@@ -37,7 +38,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
}
}
static
inline
void
update
(
scalar_t
*
val
,
scalar_t
new_val
)
{
static
inline
void
update
(
ReductionType
REDUCE
,
scalar_t
*
val
,
scalar_t
new_val
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
*
val
=
*
val
+
new_val
;
*
val
=
*
val
+
new_val
;
}
else
if
((
REDUCE
==
MIN
&&
new_val
<
*
val
)
||
}
else
if
((
REDUCE
==
MIN
&&
new_val
<
*
val
)
||
...
@@ -46,7 +47,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
...
@@ -46,7 +47,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
}
}
static
inline
void
update
(
scalar_t
*
val
,
scalar_t
new_val
,
int64_t
*
arg
,
static
inline
void
update
(
ReductionType
REDUCE
,
scalar_t
*
val
,
scalar_t
new_val
,
int64_t
*
arg
,
int64_t
new_arg
)
{
int64_t
new_arg
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
if
(
REDUCE
==
ADD
||
REDUCE
==
MEAN
)
{
*
val
=
*
val
+
new_val
;
*
val
=
*
val
+
new_val
;
...
@@ -57,7 +58,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
...
@@ -57,7 +58,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
}
}
static
inline
void
write
(
scalar_t
*
address
,
scalar_t
val
,
static
inline
void
write
(
ReductionType
REDUCE
,
scalar_t
*
address
,
scalar_t
val
,
int64_t
*
arg_address
,
int64_t
arg
,
int
count
)
{
int64_t
*
arg_address
,
int64_t
arg
,
int
count
)
{
if
(
REDUCE
==
ADD
)
{
if
(
REDUCE
==
ADD
)
{
*
address
=
val
;
*
address
=
val
;
...
@@ -136,16 +137,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
...
@@ -136,16 +137,16 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
offset
=
(
n
/
(
indptr
.
size
(
-
1
)
-
1
))
*
E
*
K
;
offset
=
(
n
/
(
indptr
.
size
(
-
1
)
-
1
))
*
E
*
K
;
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
vals
[
k
]
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
vals
[
k
]
=
Reducer
<
scalar_t
>::
init
(
REDUCE
);
}
}
for
(
int64_t
e
=
row_start
;
e
<
row_end
;
e
++
)
{
for
(
int64_t
e
=
row_start
;
e
<
row_end
;
e
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
update
(
Reducer
<
scalar_t
>::
update
(
REDUCE
,
&
vals
[
k
],
src_data
[
offset
+
e
*
K
+
k
],
&
args
[
k
],
e
);
&
vals
[
k
],
src_data
[
offset
+
e
*
K
+
k
],
&
args
[
k
],
e
);
}
}
}
}
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
n
*
K
+
k
,
vals
[
k
],
Reducer
<
scalar_t
>::
write
(
REDUCE
,
out_data
+
n
*
K
+
k
,
vals
[
k
],
arg_out_data
+
n
*
K
+
k
,
args
[
k
],
arg_out_data
+
n
*
K
+
k
,
args
[
k
],
row_end
-
row_start
);
row_end
-
row_start
);
}
}
...
@@ -214,13 +215,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -214,13 +215,13 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
for
(
int
e_2
=
0
;
e_2
<
E_2
;
e_2
++
)
{
for
(
int
e_2
=
0
;
e_2
<
E_2
;
e_2
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
update
(
Reducer
<
scalar_t
>::
update
(
REDUCE
,
&
vals
[
k
],
src_data
[
e_1
*
E_2
*
K
+
e_2
*
K
+
k
],
&
args
[
k
],
e_2
);
&
vals
[
k
],
src_data
[
e_1
*
E_2
*
K
+
e_2
*
K
+
k
],
&
args
[
k
],
e_2
);
}
}
if
(
e_2
==
E_2
-
1
)
{
if
(
e_2
==
E_2
-
1
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
write
(
Reducer
<
scalar_t
>::
write
(
REDUCE
,
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
args
[
k
],
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
args
[
k
],
e_2
+
1
-
row_start
);
e_2
+
1
-
row_start
);
...
@@ -231,7 +232,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -231,7 +232,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
if
(
idx
!=
next_idx
)
{
if
(
idx
!=
next_idx
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
write
(
Reducer
<
scalar_t
>::
write
(
REDUCE
,
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
args
[
k
],
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
args
[
k
],
e_2
+
1
-
row_start
);
e_2
+
1
-
row_start
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment