Commit e68dcf3b authored by rusty1s's avatar rusty1s
Browse files

parallel_for

parent 8b30d1c7
#include "rw_cpu.h" #include "rw_cpu.h"
#include <ATen/ParallelOpenMP.h>
#include "utils.h" #include "utils.h"
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
...@@ -24,24 +26,27 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -24,24 +26,27 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
auto rand_data = rand.data_ptr<float>(); auto rand_data = rand.data_ptr<float>();
auto out_data = out.data_ptr<int64_t>(); auto out_data = out.data_ptr<int64_t>();
for (auto n = 0; n < start.size(0); n++) { int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
auto cur = start_data[n]; at::parallel_for(0, start.numel(), grain_size, [&](int64_t b, int64_t e) {
auto offset = n * (walk_length + 1); for (auto n = b; n < e; n++) {
out_data[offset] = cur; auto cur = start_data[n];
auto offset = n * (walk_length + 1);
int64_t row_start, row_end, rnd; out_data[offset] = cur;
for (auto l = 1; l <= walk_length; l++) {
row_start = rowptr_data[cur], row_end = rowptr_data[cur + 1]; int64_t row_start, row_end, rnd;
if (row_end - row_start == 0) { for (auto l = 1; l <= walk_length; l++) {
cur = n; row_start = rowptr_data[cur], row_end = rowptr_data[cur + 1];
} else { if (row_end - row_start == 0) {
rnd = int64_t(rand_data[n * walk_length + (l - 1)] * cur = n;
(row_end - row_start)); } else {
cur = col_data[row_start + rnd]; rnd = int64_t(rand_data[n * walk_length + (l - 1)] *
(row_end - row_start));
cur = col_data[row_start + rnd];
}
out_data[offset + l] = cur;
} }
out_data[offset + l] = cur;
} }
} });
return out; return out;
} }
...@@ -19,7 +19,8 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' ...@@ -19,7 +19,8 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions(): def get_extensions():
Extension = CppExtension Extension = CppExtension
define_macros = [] define_macros = []
extra_compile_args = {'cxx': []} extra_compile_args = {'cxx': ['-fopenmp']}
extra_link_args = ['-lgomp']
if WITH_CUDA: if WITH_CUDA:
Extension = CUDAExtension Extension = CUDAExtension
...@@ -51,6 +52,7 @@ def get_extensions(): ...@@ -51,6 +52,7 @@ def get_extensions():
include_dirs=[extensions_dir], include_dirs=[extensions_dir],
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
) )
extensions += [extension] extensions += [extension]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment