Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
5db00866
Commit
5db00866
authored
Jan 12, 2020
by
rusty1s
Browse files
faster segment csr cpu implementation
parent
3994f3ab
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
7 deletions
+11
-7
cpu/segment.cpp
cpu/segment.cpp
+11
-7
No files found.
cpu/segment.cpp
View file @
5db00866
...
@@ -123,8 +123,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
...
@@ -123,8 +123,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
scalar_t
val
;
scalar_t
val
s
[
K
]
;
int64_t
row_start
,
row_end
,
arg
;
int64_t
row_start
,
row_end
,
arg
s
[
K
]
;
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
n
,
indptr_info
);
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
n
,
indptr_info
);
...
@@ -133,13 +133,17 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
...
@@ -133,13 +133,17 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
offset
=
(
n
/
(
indptr
.
size
(
-
1
)
-
1
))
*
E
*
K
;
offset
=
(
n
/
(
indptr
.
size
(
-
1
)
-
1
))
*
E
*
K
;
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
val
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
vals
[
k
]
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
}
for
(
int64_t
e
=
row_start
;
e
<
row_end
;
e
++
)
{
for
(
int64_t
e
=
row_start
;
e
<
row_end
;
e
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
update
(
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
val
,
src_data
[
offset
+
e
*
K
+
k
],
&
arg
,
e
);
&
val
s
[
k
]
,
src_data
[
offset
+
e
*
K
+
k
],
&
arg
s
[
k
]
,
e
);
}
}
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
n
*
K
+
k
,
val
,
}
arg_out_data
+
n
*
K
+
k
,
arg
,
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
n
*
K
+
k
,
vals
[
k
],
arg_out_data
+
n
*
K
+
k
,
args
[
k
],
row_end
-
row_start
);
row_end
-
row_start
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment