Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
0be33ffa
Commit
0be33ffa
authored
Feb 09, 2020
by
rusty1s
Browse files
potential windows fix
parent
feca30d1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
24 deletions
+25
-24
csrc/cpu/reducer.h
csrc/cpu/reducer.h
+7
-6
csrc/cpu/scatter_cpu.cpp
csrc/cpu/scatter_cpu.cpp
+4
-4
csrc/cpu/segment_coo_cpu.cpp
csrc/cpu/segment_coo_cpu.cpp
+8
-8
csrc/cpu/segment_csr_cpu.cpp
csrc/cpu/segment_csr_cpu.cpp
+6
-6
No files found.
csrc/cpu/reducer.h
View file @
0be33ffa
...
...
@@ -40,8 +40,8 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
} \
}()
template
<
typename
scalar_t
,
ReductionType
REDUCE
>
struct
Reducer
{
static
inline
scalar_t
init
()
{
template
<
typename
scalar_t
>
struct
Reducer
{
static
inline
scalar_t
init
(
ReductionType
REDUCE
)
{
if
(
REDUCE
==
MUL
||
REDUCE
==
DIV
)
return
(
scalar_t
)
1
;
else
if
(
REDUCE
==
MIN
)
...
...
@@ -52,8 +52,8 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
return
(
scalar_t
)
0
;
}
static
inline
void
update
(
scalar_t
*
val
,
scalar_t
new_val
,
int64_t
*
arg
,
int64_t
new_arg
)
{
static
inline
void
update
(
ReductionType
REDUCE
,
scalar_t
*
val
,
scalar_t
new_val
,
int64_t
*
arg
,
int64_t
new_arg
)
{
if
(
REDUCE
==
SUM
||
REDUCE
==
MEAN
)
*
val
=
*
val
+
new_val
;
else
if
(
REDUCE
==
MUL
)
...
...
@@ -67,8 +67,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static
inline
void
write
(
scalar_t
*
address
,
scalar_t
val
,
int64_t
*
arg_address
,
int64_t
arg
,
int
count
)
{
static
inline
void
write
(
ReductionType
REDUCE
,
scalar_t
*
address
,
scalar_t
val
,
int64_t
*
arg_address
,
int64_t
arg
,
int
count
)
{
if
(
REDUCE
==
SUM
||
REDUCE
==
MUL
||
REDUCE
==
DIV
)
*
address
=
val
;
else
if
(
REDUCE
==
MEAN
)
...
...
csrc/cpu/scatter_cpu.cpp
View file @
0be33ffa
...
...
@@ -61,22 +61,22 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
int64_t
i
,
idx
;
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
if
(
!
optional_out
.
has_value
())
out
.
fill_
(
Reducer
<
scalar_t
,
REDUCE
>::
init
());
out
.
fill_
(
Reducer
<
scalar_t
>::
init
(
REDUCE
));
for
(
auto
b
=
0
;
b
<
B
;
b
++
)
{
for
(
auto
e
=
0
;
e
<
E
;
e
++
)
{
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
{
i
=
b
*
E
*
K
+
e
*
K
+
k
;
idx
=
index_info
.
data
[
IndexToOffset
<
int64_t
>::
get
(
i
,
index_info
)];
Reducer
<
scalar_t
,
REDUCE
>::
update
(
out_data
+
b
*
N
*
K
+
idx
*
K
+
k
,
src_data
[
i
],
Reducer
<
scalar_t
>::
update
(
REDUCE
,
out_data
+
b
*
N
*
K
+
idx
*
K
+
k
,
src_data
[
i
],
arg_out_data
+
b
*
N
*
K
+
idx
*
K
+
k
,
e
);
}
}
}
if
(
!
optional_out
.
has_value
()
&&
(
REDUCE
==
MIN
||
REDUCE
==
MAX
))
out
.
masked_fill_
(
out
==
Reducer
<
scalar_t
,
REDUCE
>::
init
(),
(
scalar_t
)
0
);
out
.
masked_fill_
(
out
==
Reducer
<
scalar_t
>::
init
(
REDUCE
),
(
scalar_t
)
0
);
});
});
...
...
csrc/cpu/segment_coo_cpu.cpp
View file @
0be33ffa
...
...
@@ -72,7 +72,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
int64_t
idx
,
next_idx
,
row_start
;
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
if
(
!
optional_out
.
has_value
())
out
.
fill_
(
Reducer
<
scalar_t
,
REDUCE
>::
init
());
out
.
fill_
(
Reducer
<
scalar_t
>::
init
(
REDUCE
));
if
(
REDUCE
==
MEAN
)
count_data
=
arg_out
.
value
().
data_ptr
<
scalar_t
>
();
...
...
@@ -87,13 +87,13 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
for
(
auto
e
=
0
;
e
<
E
;
e
++
)
{
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
vals
[
k
],
src_data
[
b
*
E
*
K
+
e
*
K
+
k
],
&
args
[
k
],
e
);
Reducer
<
scalar_t
>::
update
(
REDUCE
,
&
vals
[
k
],
src_data
[
b
*
E
*
K
+
e
*
K
+
k
],
&
args
[
k
],
e
);
if
(
e
==
E
-
1
)
{
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
b
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
Reducer
<
scalar_t
>::
write
(
REDUCE
,
out_data
+
b
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
arg_out_data
+
b
*
N
*
K
+
idx
*
K
+
k
,
args
[
k
],
e
+
1
-
row_start
);
if
(
REDUCE
==
MEAN
)
...
...
@@ -104,8 +104,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
if
(
idx
!=
next_idx
)
{
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
b
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
Reducer
<
scalar_t
>::
write
(
REDUCE
,
out_data
+
b
*
N
*
K
+
idx
*
K
+
k
,
vals
[
k
],
arg_out_data
+
b
*
N
*
K
+
idx
*
K
+
k
,
args
[
k
],
e
+
1
-
row_start
);
...
...
@@ -121,7 +121,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
}
}
if
(
!
optional_out
.
has_value
()
&&
(
REDUCE
==
MIN
||
REDUCE
==
MAX
))
out
.
masked_fill_
(
out
==
Reducer
<
scalar_t
,
REDUCE
>::
init
(),
(
scalar_t
)
0
);
out
.
masked_fill_
(
out
==
Reducer
<
scalar_t
>::
init
(
REDUCE
),
(
scalar_t
)
0
);
if
(
REDUCE
==
MEAN
)
arg_out
.
value
().
clamp_
(
1
);
...
...
csrc/cpu/segment_csr_cpu.cpp
View file @
0be33ffa
...
...
@@ -68,17 +68,17 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
offset
=
(
n
/
(
indptr
.
size
(
-
1
)
-
1
))
*
E
*
K
;
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
vals
[
k
]
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
vals
[
k
]
=
Reducer
<
scalar_t
>::
init
(
REDUCE
);
for
(
auto
e
=
row_start
;
e
<
row_end
;
e
++
)
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
vals
[
k
],
src_data
[
offset
+
e
*
K
+
k
],
&
args
[
k
],
e
);
Reducer
<
scalar_t
>::
update
(
REDUCE
,
&
vals
[
k
],
src_data
[
offset
+
e
*
K
+
k
],
&
args
[
k
],
e
);
for
(
auto
k
=
0
;
k
<
K
;
k
++
)
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
n
*
K
+
k
,
vals
[
k
],
arg_out_data
+
n
*
K
+
k
,
args
[
k
],
row_end
-
row_start
);
Reducer
<
scalar_t
>::
write
(
REDUCE
,
out_data
+
n
*
K
+
k
,
vals
[
k
],
arg_out_data
+
n
*
K
+
k
,
args
[
k
],
row_end
-
row_start
);
}
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment