Commit 906c97e4 authored by rusty1s

multi-threading in SPMM (CPU)

parent 25700066
#include "spmm_cpu.h" #include "spmm_cpu.h"
#include <ATen/Parallel.h>
#include "reducer.h" #include "reducer.h"
#include "utils.h" #include "utils.h"
...@@ -47,19 +49,22 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -47,19 +49,22 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
auto mat_data = mat.data_ptr<scalar_t>(); auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>(); auto out_data = out.data_ptr<scalar_t>();
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, c;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
AT_DISPATCH_HAS_VALUE(optional_value, [&] { AT_DISPATCH_HAS_VALUE(optional_value, [&] {
if (HAS_VALUE) { if (HAS_VALUE) {
value_data = optional_value.value().data_ptr<scalar_t>(); value_data = optional_value.value().data_ptr<scalar_t>();
} }
for (auto b = 0; b < B; b++) { int64_t grain_size = at::internal::GRAIN_SIZE / (K * (col.numel() / M));
for (auto m = 0; m < M; m++) { at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, b, m, c;
std::vector<int64_t> args(K);
for (auto i = begin; i < end; i++) {
b = i / M, m = i % M;
row_start = rowptr_data[m], row_end = rowptr_data[m + 1]; row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (auto k = 0; k < K; k++) for (auto k = 0; k < K; k++)
...@@ -86,7 +91,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -86,7 +91,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
arg_out_data + offset + k, arg_out_data + offset + k,
args[k], row_end - row_start); args[k], row_end - row_start);
} }
} });
}); });
}); });
}); });
......
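For readers unfamiliar with ATen's CPU parallelism, the pattern introduced above can be summarized in a standalone sketch: the two nested loops over batches and rows are flattened into a single index range [0, B * M), at::parallel_for splits that range across threads, and each worker recovers (b, m) from the flat index. The snippet below is an illustrative reduction over a dense (B, M, K) tensor, not the library code; the tensor names, the plain row sum, and the grain-size heuristic are assumptions made for the example.

// Illustrative sketch only -- simplified from the pattern above, not the
// library code. Sums each row of a (B, M, K) float tensor in parallel.
#include <algorithm>
#include <ATen/Parallel.h>
#include <torch/torch.h>

void parallel_row_sum(const torch::Tensor &mat, torch::Tensor &out) {
  int64_t B = mat.size(0), M = mat.size(1), K = mat.size(2);
  auto mat_data = mat.data_ptr<float>();
  auto out_data = out.data_ptr<float>();

  // Grain size: minimum number of (b, m) work items handed to one thread.
  // Dividing ATen's default grain by the per-item cost (K here) keeps the
  // amount of work per chunk roughly constant.
  int64_t grain_size = std::max<int64_t>(at::internal::GRAIN_SIZE / K, 1);

  at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; i++) {
      int64_t b = i / M, m = i % M;  // recover (batch, row) from the flat index
      float acc = 0;
      for (int64_t k = 0; k < K; k++)
        acc += mat_data[b * M * K + m * K + k];
      out_data[b * M + m] = acc;
    }
  });
}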
@@ -5,6 +5,7 @@ import glob
 from setuptools import setup, find_packages
 
 import torch
+from torch.__config__ import parallel_info
 from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

@@ -31,6 +32,16 @@ def get_extensions():
     extra_compile_args = {'cxx': []}
     extra_link_args = []
 
+    info = parallel_info()
+    if 'parallel backend: OpenMP' in info and 'OpenMP not found' not in info:
+        extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP']
+        if sys.platform == 'win32':
+            extra_compile_args['cxx'] += ['/openmp']
+        else:
+            extra_compile_args['cxx'] += ['-fopenmp']
+    else:
+        print('Compiling without OpenMP...')
+
     if WITH_CUDA:
         Extension = CUDAExtension
         define_macros += [('WITH_CUDA', None)]
......
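The OpenMP detection above keys off the text returned by torch.__config__.parallel_info(). As a quick local check (illustrative snippet, not part of the commit), the same string test can be run by hand:

# Print the parallel backend the installed PyTorch was built with and apply
# the same string test as the setup.py hunk above. Purely a local sanity check.
from torch.__config__ import parallel_info

info = parallel_info()
print(info)

if 'parallel backend: OpenMP' in info and 'OpenMP not found' not in info:
    print('OpenMP backend detected: the extension would be built with OpenMP flags')
else:
    print('No usable OpenMP backend: the extension would be built without OpenMP')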