Commit e68dcf3b authored by rusty1s's avatar rusty1s
Browse files

parallel_for

parent 8b30d1c7
#include "rw_cpu.h"
#include <ATen/ParallelOpenMP.h>
#include "utils.h"
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
......@@ -24,24 +26,27 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
auto rand_data = rand.data_ptr<float>();
auto out_data = out.data_ptr<int64_t>();
for (auto n = 0; n < start.size(0); n++) {
auto cur = start_data[n];
auto offset = n * (walk_length + 1);
out_data[offset] = cur;
int64_t row_start, row_end, rnd;
for (auto l = 1; l <= walk_length; l++) {
row_start = rowptr_data[cur], row_end = rowptr_data[cur + 1];
if (row_end - row_start == 0) {
cur = n;
} else {
rnd = int64_t(rand_data[n * walk_length + (l - 1)] *
(row_end - row_start));
cur = col_data[row_start + rnd];
int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, start.numel(), grain_size, [&](int64_t b, int64_t e) {
for (auto n = b; n < e; n++) {
auto cur = start_data[n];
auto offset = n * (walk_length + 1);
out_data[offset] = cur;
int64_t row_start, row_end, rnd;
for (auto l = 1; l <= walk_length; l++) {
row_start = rowptr_data[cur], row_end = rowptr_data[cur + 1];
if (row_end - row_start == 0) {
cur = n;
} else {
rnd = int64_t(rand_data[n * walk_length + (l - 1)] *
(row_end - row_start));
cur = col_data[row_start + rnd];
}
out_data[offset + l] = cur;
}
out_data[offset + l] = cur;
}
}
});
return out;
}
......@@ -19,7 +19,8 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions():
Extension = CppExtension
define_macros = []
extra_compile_args = {'cxx': []}
extra_compile_args = {'cxx': ['-fopenmp']}
extra_link_args = ['-lgomp']
if WITH_CUDA:
Extension = CUDAExtension
......@@ -51,6 +52,7 @@ def get_extensions():
include_dirs=[extensions_dir],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
extensions += [extension]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment