"vscode:/vscode.git/clone" did not exist on "cc92a4b47dc45a6badb384ce2c68e43940e380fa"
Commit e68dcf3b authored by rusty1s's avatar rusty1s
Browse files

parallel_for

parent 8b30d1c7
#include "rw_cpu.h" #include "rw_cpu.h"
#include <ATen/ParallelOpenMP.h>
#include "utils.h" #include "utils.h"
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
...@@ -24,7 +26,9 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -24,7 +26,9 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
auto rand_data = rand.data_ptr<float>(); auto rand_data = rand.data_ptr<float>();
auto out_data = out.data_ptr<int64_t>(); auto out_data = out.data_ptr<int64_t>();
for (auto n = 0; n < start.size(0); n++) { int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, start.numel(), grain_size, [&](int64_t b, int64_t e) {
for (auto n = b; n < e; n++) {
auto cur = start_data[n]; auto cur = start_data[n];
auto offset = n * (walk_length + 1); auto offset = n * (walk_length + 1);
out_data[offset] = cur; out_data[offset] = cur;
...@@ -42,6 +46,7 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -42,6 +46,7 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
out_data[offset + l] = cur; out_data[offset + l] = cur;
} }
} }
});
return out; return out;
} }
...@@ -19,7 +19,8 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' ...@@ -19,7 +19,8 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions(): def get_extensions():
Extension = CppExtension Extension = CppExtension
define_macros = [] define_macros = []
extra_compile_args = {'cxx': []} extra_compile_args = {'cxx': ['-fopenmp']}
extra_link_args = ['-lgomp']
if WITH_CUDA: if WITH_CUDA:
Extension = CUDAExtension Extension = CUDAExtension
...@@ -51,6 +52,7 @@ def get_extensions(): ...@@ -51,6 +52,7 @@ def get_extensions():
include_dirs=[extensions_dir], include_dirs=[extensions_dir],
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
) )
extensions += [extension] extensions += [extension]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment