Commit 817b767e authored by rusty1s's avatar rusty1s
Browse files

parallelize CPU fps over batch dimension

parent 4b01cc80
#include "fps_cpu.h"
#include <ATen/Parallel.h>
#include "utils.h"
inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
......@@ -28,27 +30,29 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
auto out_ptr_data = out_ptr.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
int64_t src_start = 0, out_start = 0, src_end, out_end;
for (auto b = 0; b < batch_size; b++) {
src_end = ptr_data[b + 1], out_end = out_ptr_data[b];
auto y = src.narrow(0, src_start, src_end - src_start);
int64_t grain_size = 1; // Always parallelize over batch dimension.
at::parallel_for(0, batch_size, grain_size, [&](int64_t begin, int64_t end) {
int64_t src_start, src_end, out_start, out_end;
for (int64_t b = begin; b < end; b++) {
src_start = ptr_data[b], src_end = ptr_data[b + 1];
out_start = b == 0 ? 0 : out_ptr_data[b - 1], out_end = out_ptr_data[b];
int64_t start_idx = 0;
if (random_start) {
start_idx = rand() % y.size(0);
}
auto y = src.narrow(0, src_start, src_end - src_start);
out_data[out_start] = src_start + start_idx;
auto dist = get_dist(y, start_idx);
int64_t start_idx = 0;
if (random_start)
start_idx = rand() % y.size(0);
for (auto i = 1; i < out_end - out_start; i++) {
int64_t argmax = dist.argmax().data_ptr<int64_t>()[0];
out_data[out_start + i] = src_start + argmax;
dist = torch::min(dist, get_dist(y, argmax));
}
out_data[out_start] = src_start + start_idx;
auto dist = get_dist(y, start_idx);
src_start = src_end, out_start = out_end;
}
for (int64_t i = 1; i < out_end - out_start; i++) {
int64_t argmax = dist.argmax().data_ptr<int64_t>()[0];
out_data[out_start + i] = src_start + argmax;
dist = torch::min(dist, get_dist(y, argmax));
}
}
});
return out;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment