Commit 906c97e4 authored by rusty1s's avatar rusty1s
Browse files

multi-threading in SPMM (CPU)

parent 25700066
#include "spmm_cpu.h"
#include <ATen/Parallel.h>
#include "reducer.h"
#include "utils.h"
......@@ -47,19 +49,22 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, c;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
AT_DISPATCH_HAS_VALUE(optional_value, [&] {
if (HAS_VALUE) {
value_data = optional_value.value().data_ptr<scalar_t>();
}
for (auto b = 0; b < B; b++) {
for (auto m = 0; m < M; m++) {
int64_t grain_size = at::internal::GRAIN_SIZE / (K * (col.numel() / M));
at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, b, m, c;
std::vector<int64_t> args(K);
for (auto i = begin; i < end; i++) {
b = i / M, m = i % M;
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (auto k = 0; k < K; k++)
......@@ -86,7 +91,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
arg_out_data + offset + k,
args[k], row_end - row_start);
}
}
});
});
});
});
......
......@@ -5,6 +5,7 @@ import glob
from setuptools import setup, find_packages
import torch
from torch.__config__ import parallel_info
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
......@@ -31,6 +32,16 @@ def get_extensions():
extra_compile_args = {'cxx': []}
extra_link_args = []
info = parallel_info()
if 'parallel backend: OpenMP' in info and 'OpenMP not found' not in info:
extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP']
if sys.platform == 'win32':
extra_compile_args['cxx'] += ['/openmp']
else:
extra_compile_args['cxx'] += ['-fopenmp']
else:
print('Compiling without OpenMP...')
if WITH_CUDA:
Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment