Commit e68dcf3b authored by rusty1s
Browse files

parallel_for

parent 8b30d1c7
#include "rw_cpu.h"
#include <ATen/ParallelOpenMP.h>
#include "utils.h"
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
......@@ -24,7 +26,9 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
auto rand_data = rand.data_ptr<float>();
auto out_data = out.data_ptr<int64_t>();
for (auto n = 0; n < start.size(0); n++) {
int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, start.numel(), grain_size, [&](int64_t b, int64_t e) {
for (auto n = b; n < e; n++) {
auto cur = start_data[n];
auto offset = n * (walk_length + 1);
out_data[offset] = cur;
......@@ -42,6 +46,7 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
out_data[offset + l] = cur;
}
}
});
return out;
}
......@@ -19,7 +19,8 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions():
Extension = CppExtension
define_macros = []
extra_compile_args = {'cxx': []}
extra_compile_args = {'cxx': ['-fopenmp']}
extra_link_args = ['-lgomp']
if WITH_CUDA:
Extension = CUDAExtension
......@@ -51,6 +52,7 @@ def get_extensions():
include_dirs=[extensions_dir],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
extensions += [extension]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment