Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
817b767e
Commit
817b767e
authored
Dec 03, 2020
by
rusty1s
Browse files
parallelize CPU fps over batch dimension
parent
4b01cc80
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
17 deletions
+21
-17
csrc/cpu/fps_cpu.cpp
csrc/cpu/fps_cpu.cpp
+21
-17
No files found.
csrc/cpu/fps_cpu.cpp
View file @
817b767e
#include "fps_cpu.h"
#include <ATen/Parallel.h>
#include "utils.h"
inline
torch
::
Tensor
get_dist
(
torch
::
Tensor
x
,
int64_t
idx
)
{
...
...
@@ -28,27 +30,29 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
auto
out_ptr_data
=
out_ptr
.
data_ptr
<
int64_t
>
();
auto
out_data
=
out
.
data_ptr
<
int64_t
>
();
int64_t
src_start
=
0
,
out_start
=
0
,
src_end
,
out_end
;
for
(
auto
b
=
0
;
b
<
batch_size
;
b
++
)
{
src_end
=
ptr_data
[
b
+
1
],
out_end
=
out_ptr_data
[
b
];
auto
y
=
src
.
narrow
(
0
,
src_start
,
src_end
-
src_start
);
int64_t
grain_size
=
1
;
// Always parallelize over batch dimension.
at
::
parallel_for
(
0
,
batch_size
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
int64_t
src_start
,
src_end
,
out_start
,
out_end
;
for
(
int64_t
b
=
begin
;
b
<
end
;
b
++
)
{
src_start
=
ptr_data
[
b
],
src_end
=
ptr_data
[
b
+
1
];
out_start
=
b
==
0
?
0
:
out_ptr_data
[
b
-
1
],
out_end
=
out_ptr_data
[
b
];
int64_t
start_idx
=
0
;
if
(
random_start
)
{
start_idx
=
rand
()
%
y
.
size
(
0
);
}
auto
y
=
src
.
narrow
(
0
,
src_start
,
src_end
-
src_start
);
out_data
[
out_start
]
=
src_start
+
start_idx
;
auto
dist
=
get_dist
(
y
,
start_idx
);
int64_t
start_idx
=
0
;
if
(
random_start
)
start_idx
=
rand
()
%
y
.
size
(
0
);
for
(
auto
i
=
1
;
i
<
out_end
-
out_start
;
i
++
)
{
int64_t
argmax
=
dist
.
argmax
().
data_ptr
<
int64_t
>
()[
0
];
out_data
[
out_start
+
i
]
=
src_start
+
argmax
;
dist
=
torch
::
min
(
dist
,
get_dist
(
y
,
argmax
));
}
out_data
[
out_start
]
=
src_start
+
start_idx
;
auto
dist
=
get_dist
(
y
,
start_idx
);
src_start
=
src_end
,
out_start
=
out_end
;
}
for
(
int64_t
i
=
1
;
i
<
out_end
-
out_start
;
i
++
)
{
int64_t
argmax
=
dist
.
argmax
().
data_ptr
<
int64_t
>
()[
0
];
out_data
[
out_start
+
i
]
=
src_start
+
argmax
;
dist
=
torch
::
min
(
dist
,
get_dist
(
y
,
argmax
));
}
}
});
return
out
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment