Commit 0a06b39b authored by rusty1s

Add spmm CPU implementation and benchmark script

parent e8d349ee
import time
import os.path as osp
import itertools
import argparse
import wget
import torch
from scipy.io import loadmat
from torch_sparse import spmm_cpu
from torch_sparse.tensor import SparseTensor
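# Test matrices from the SuiteSparse Matrix Collection (sparse.tamu.edu),
# grouped by short vs. long average row length.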
short_rows = [
('DIMACS10', 'citationCiteseer'),
('SNAP', 'web-Stanford'),
]
long_rows = [
('Janna', 'StocF-1465'),
('GHS_psdef', 'ldoor'),
]
def download(dataset):
    group, name = dataset
    url = 'https://sparse.tamu.edu/mat/{}/{}.mat'
    if not osp.exists(f'{name}.mat'):
        print(f'Downloading {group}/{name}:')
        wget.download(url.format(group, name))
        print('')
def bold(text, flag=True):
return f'\033[1m{text}\033[0m' if flag else text
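# Correctness check is left as a stub in this commit.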
@torch.no_grad()
def correctness(dataset):
pass
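# Runs `func(x)` for `iters` iterations (optionally with a backward pass)
# and returns the elapsed wall-clock time, or inf on CUDA out-of-memory.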
def time_func(func, x):
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
t = time.perf_counter()
if not args.with_backward:
with torch.no_grad():
for _ in range(iters):
func(x)
else:
x = x.requires_grad_()
for _ in range(iters):
out = func(x)
out = out[0] if isinstance(out, tuple) else out
torch.autograd.grad(out, x, out, only_inputs=True)
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.perf_counter() - t
except RuntimeError as e:
if 'out of memory' not in str(e):
            raise
torch.cuda.empty_cache()
return float('inf')
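# Benchmarks SciPy, torch.sparse and the new CSR-based spmm on the given
# matrix over a range of feature sizes, printing the fastest time in bold.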
def timing(dataset):
group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    mat_coo = mat_scipy.tocoo()
    row = torch.from_numpy(mat_coo.row).to(args.device, torch.long)
    col = torch.from_numpy(mat_coo.col).to(args.device, torch.long)
index = torch.stack([row, col], dim=0)
mat_own = SparseTensor(index, sparse_size=mat_scipy.shape)
rowptr, col, value = mat_own.csr()
mat_pytorch = mat_own.to_torch_sparse_coo_tensor().coalesce()
def spmm_scipy(x):
return mat_scipy @ x
def spmm_pytorch(x):
return mat_pytorch @ x
def spmm_own(x):
return spmm_cpu.spmm(rowptr, col, value, x, 'sum')
t1, t2, t3 = [], [], []
for size in sizes:
try:
x = torch.randn((mat_own.size(1), size), device=args.device)
t1 += [time_func(spmm_scipy, x)]
t2 += [time_func(spmm_pytorch, x)]
t3 += [time_func(spmm_own, x)]
del x
except RuntimeError as e:
if 'out of memory' not in str(e):
                raise
torch.cuda.empty_cache()
for t in (t1, t2, t3):
t.append(float('inf'))
ts = torch.tensor([t1, t2, t3])
winner = torch.zeros_like(ts, dtype=torch.bool)
winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
winner = winner.tolist()
name = f'{group}/{name}'
print(f'{bold(name)} (avg row length: {mat_own.avg_row_length():.2f}):')
print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
print('\t'.join([bold('SPMM SciPy ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
print('\t'.join([bold('SPMM PyTorch')] +
[bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
print('\t'.join([bold('SPMM Own ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
iters = 1 if args.device == 'cpu' else 20
sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:4] if args.device == 'cpu' else sizes
for _ in range(10): # Warmup.
torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows):
download(dataset)
correctness(dataset)
timing(dataset)
#include <torch/extension.h>

#include <limits>
#include <map>
#include <string>
#include <tuple>
#include <vector>

#include "compat.h"

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
enum ReductionType { SUM, MEAN, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};
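// Dispatches on the runtime `reduce` string so that `REDUCE` is available as
// a compile-time constant inside the lambda passed via __VA_ARGS__.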
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
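// Same pattern for the optional value tensor: `HAS_VAL` becomes a
// compile-time constant inside the lambda.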
#define AT_DISPATCH_HAS_VAL(value_opt, ...) \
[&] { \
switch (value_opt.has_value()) { \
case true: { \
const bool HAS_VAL = true; \
return __VA_ARGS__(); \
} \
case false: { \
const bool HAS_VAL = false; \
return __VA_ARGS__(); \
} \
} \
}()
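// Per-reduction helpers: `init` yields the identity element, `update` folds
// in one candidate value (tracking the winning index for min/max), and
// `write` stores the final result, dividing by the count for mean.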
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
return std::numeric_limits<scalar_t>::lowest();
} else {
return (scalar_t)0;
}
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (count > 0 ? count : (scalar_t)1);
} else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else {
*address = (scalar_t)0;
}
}
}
};
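// Sparse-dense matrix multiplication on CPU: multiplies the CSR matrix given
// by (rowptr, col, optional value) with the dense matrix `mat`, combining
// entries with `reduce`. Returns the dense result and, for min/max, the index
// of the contributing nonzero per output entry (nullopt otherwise).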
std::tuple<at::Tensor, at::optional<at::Tensor>>
spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at::Tensor mat, std::string reduce) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
if (value_opt.has_value())
CHECK_CPU(value_opt.value());
CHECK_CPU(mat);
mat = mat.contiguous();
AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
AT_ASSERTM(col.dim() == 1, "Input mismatch");
if (value_opt.has_value())
    AT_ASSERTM(value_opt.value().dim() == 1, "Input mismatch");
AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
AT_ASSERTM(rowptr.numel() - 1 == mat.size(-2), "Input mismatch");
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = at::empty(sizes, mat.options());
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, mat.size(-2), rowptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
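  // N: number of sparse rows, M: number of dense input rows, K: feature
  // dimension, B: batch size taken from the leading dimensions of `mat`.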
int N = rowptr.numel() - 1;
int M = mat.size(-2);
int K = mat.size(-1);
int B = mat.numel() / (M * K);
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
scalar_t *value_data = nullptr;
    auto mat_data = mat.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, col_idx;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
AT_DISPATCH_HAS_VAL(value_opt, [&] {
if (HAS_VAL) {
value_data = value_opt.value().DATA_PTR<scalar_t>();
}
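        // For every batch and sparse row: initialize the per-feature
        // accumulators, fold in (value[e] *) mat[col[e], k] for each nonzero
        // `e` of the row, then write the reduced result to `out`.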
for (int b = 0; b < B; b++) {
for (int n = 0; n < N; n++) {
row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
for (int k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
          int64_t offset = (int64_t)b * M * K;
          for (int64_t e = row_start; e < row_end; e++) {
col_idx = col_data[e];
if (HAS_VAL)
val = value_data[e];
for (int k = 0; k < K; k++) {
if (HAS_VAL)
Reducer<scalar_t, REDUCE>::update(
&vals[k], val * mat_data[offset + col_idx * K + k],
&args[k], e);
else
Reducer<scalar_t, REDUCE>::update(
&vals[k], mat_data[offset + col_idx * K + k], &args[k],
e);
}
}
          offset = (int64_t)b * N * K + (int64_t)n * K;
for (int k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
arg_out_data + offset + k,
args[k], row_end - row_start);
}
}
});
});
});
return std::make_tuple(out, arg_out);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spmm", &spmm, "Sparse-Dense Matrix Multiplication (CPU)");
}
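A minimal usage sketch of the new kernel from Python (assuming the extension is built and importable as `spmm_cpu`, as in the benchmark script above; the 3x3 example matrix is made up for illustration):

import torch
from torch_sparse import spmm_cpu

# CSR form of the 3x3 matrix [[1, 0, 2], [0, 3, 0], [0, 0, 4]].
rowptr = torch.tensor([0, 2, 3, 4])
col = torch.tensor([0, 2, 1, 2])
value = torch.tensor([1., 2., 3., 4.])
mat = torch.randn(3, 4)

# 'sum' returns (out, None); 'min'/'max' also return the argmin/argmax edges.
out, arg_out = spmm_cpu.spmm(rowptr, col, value, mat, 'sum')
print(out.size())  # torch.Size([3, 4])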