Unverified commit 7671fcb0, authored by Matthias Fey and committed by GitHub

Merge pull request #33 from rusty1s/adj

[WIP] SparseTensor Format
parents 1fb5fa4f 704ad420
......@@ -13,24 +13,6 @@ jobs:
env:
- CC=gcc-5
- CXX=g++-5
- os: osx
language: sh
before_cache:
- brew cleanup
cache:
directories:
- $HOME/Library/Caches/Homebrew
- /usr/local/Homebrew
addons:
homebrew:
packages: python3
before_install:
- python3 -m pip install --upgrade virtualenv
- virtualenv -p python3 --system-site-packages "$HOME/venv"
- source "$HOME/venv/bin/activate"
env:
- CC=clang
- CXX=clang++
install:
- pip install numpy
- pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
......@@ -40,7 +22,7 @@ install:
- pip install torch-scatter
script:
- python -c "import torch; print(torch.__version__)"
- pycodestyle .
- pycodestyle --ignore=E731,W504 .
- flake8 .
- python setup.py install
- python setup.py test
......
Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>
Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
......
include LICENSE
recursive-include cpu *
recursive-include cuda *
recursive-include csrc *
import time
import os.path as osp
import itertools
import argparse

import wget
import torch
from scipy.io import loadmat
from torch_scatter import scatter_add
from torch_sparse.tensor import SparseTensor

short_rows = [
    ('DIMACS10', 'citationCiteseer'),
    ('SNAP', 'web-Stanford'),
]
long_rows = [
    ('Janna', 'StocF-1465'),
    ('GHS_psdef', 'ldoor'),
]


def download(dataset):
    url = 'https://sparse.tamu.edu/mat/{}/{}.mat'
    for group, name in itertools.chain(long_rows, short_rows):
        if not osp.exists(f'{name}.mat'):
            print(f'Downloading {group}/{name}:')
            wget.download(url.format(group, name))
            print('')


def bold(text, flag=True):
    return f'\033[1m{text}\033[0m' if flag else text


@torch.no_grad()
def correctness(dataset):
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    row = torch.from_numpy(mat_scipy.tocoo().row).to(args.device, torch.long)
    col = torch.from_numpy(mat_scipy.tocoo().col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()

    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)

            out1 = mat @ x
            out2 = mat_pytorch @ x

            assert torch.allclose(out1, out2, atol=1e-4)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise RuntimeError(e)
            torch.cuda.empty_cache()


def time_func(func, x):
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t = time.perf_counter()

        if not args.with_backward:
            with torch.no_grad():
                for _ in range(iters):
                    func(x)
        else:
            x = x.requires_grad_()
            for _ in range(iters):
                out = func(x)
                out = out[0] if isinstance(out, tuple) else out
                torch.autograd.grad(out, x, out, only_inputs=True)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter() - t
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise RuntimeError(e)
        torch.cuda.empty_cache()
        return float('inf')


def timing(dataset):
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    row = torch.from_numpy(mat_scipy.tocoo().row).to(args.device, torch.long)
    col = torch.from_numpy(mat_scipy.tocoo().col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()
    mat_scipy = mat.to_scipy(layout='csr')

    def scatter(x):
        return scatter_add(x[col], row, dim=0, dim_size=mat_scipy.shape[0])

    def spmm_scipy(x):
        if x.is_cuda:
            raise RuntimeError('out of memory')
        return mat_scipy @ x

    def spmm_pytorch(x):
        return mat_pytorch @ x

    def spmm(x):
        return mat @ x

    t1, t2, t3, t4 = [], [], [], []
    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)

            t1 += [time_func(scatter, x)]
            t2 += [time_func(spmm_scipy, x)]
            t3 += [time_func(spmm_pytorch, x)]
            t4 += [time_func(spmm, x)]

            del x
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise RuntimeError(e)
            torch.cuda.empty_cache()
            for t in (t1, t2, t3, t4):
                t.append(float('inf'))

    ts = torch.tensor([t1, t2, t3, t4])
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {mat.avg_row_length():.2f}):')
    print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
    print('\t'.join([bold('Scatter     ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
    print('\t'.join([bold('SPMM SciPy  ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
    print('\t'.join([bold('SPMM PyTorch')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
    print('\t'.join([bold('SPMM Own    ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
    print()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--with_backward', action='store_true')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()
    iters = 1 if args.device == 'cpu' else 20
    sizes = [1, 16, 32, 64, 128, 256, 512]
    sizes = sizes[:4] if args.device == 'cpu' else sizes

    for _ in range(10):  # Warmup.
        torch.randn(100, 100, device=args.device).sum()

    for dataset in itertools.chain(short_rows, long_rows):
        download(dataset)
        correctness(dataset)
        timing(dataset)
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>

#include "compat.h"

at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  auto zero = at::zeros(num_nodes, row.options());
  auto one = at::ones(row.size(0), row.options());
  return zero.scatter_add_(0, row, one);
}

std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  // Assert already coalesced input.
  row = degree(row, num_nodes).cumsum(0);
  row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
  return std::make_tuple(row, col);
}

at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
                     at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
                     size_t rowB_max) {
  int64_t *index_data = index.DATA_PTR<int64_t>();
  auto value = at::zeros(index.size(1), valueA.options());

  at::Tensor rowA, colA;
  std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
  int64_t *rowA_data = rowA.DATA_PTR<int64_t>();
  int64_t *colA_data = colA.DATA_PTR<int64_t>();

  at::Tensor rowB, colB;
  std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
  int64_t *rowB_data = rowB.DATA_PTR<int64_t>();
  int64_t *colB_data = colB.DATA_PTR<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
    scalar_t *value_data = value.DATA_PTR<scalar_t>();
    scalar_t *valueA_data = valueA.DATA_PTR<scalar_t>();
    scalar_t *valueB_data = valueB.DATA_PTR<scalar_t>();

    for (int64_t e = 0; e < value.size(0); e++) {
      int64_t i = index_data[e], j = index_data[value.size(0) + e];

      for (ptrdiff_t dA = rowA_data[i]; dA < rowA_data[i + 1]; dA++) {
        int64_t cA = colA_data[dA];

        for (ptrdiff_t dB = rowB_data[j]; dB < rowB_data[j + 1]; dB++) {
          int64_t cB = colB_data[dB];

          if (cA == cB) {
            value_data[e] += valueA_data[dA] * valueB_data[dB];
          }

          if (cB >= cA) {
            break;
          }
        }
      }
    }
  });

  return value;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spspmm_bw", &spspmm_bw,
        "Sparse-Sparse Matrix Multiplication Backward (CPU)");
}
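A note on the kernel above: spspmm_bw evaluates the product A @ B only at the coordinates listed in index, which is exactly what the gradient of a sparse-sparse matmul with respect to its values needs. A minimal SciPy cross-check sketch; the helper name and the shape arguments (M and K correspond to rowA_max and rowB_max) are illustrative, not part of the extension:

import scipy.sparse as sp
import torch


def spspmm_bw_reference(index, indexA, valueA, indexB, valueB, M, K, N):
    # Build A (M x K) and B (K x N) from COO and sample (A @ B) at `index`.
    A = sp.coo_matrix((valueA.numpy(), (indexA[0].numpy(), indexA[1].numpy())),
                      shape=(M, K)).tocsr()
    B = sp.coo_matrix((valueB.numpy(), (indexB[0].numpy(), indexB[1].numpy())),
                      shape=(K, N)).tocsr()
    C = A @ B
    return torch.tensor([C[i, j] for i, j in zip(index[0].tolist(),
                                                 index[1].tolist())])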
#include <torch/script.h>

#include "cpu/convert_cpu.h"

#ifdef WITH_CUDA
#include "cuda/convert_cuda.h"
#endif

torch::Tensor ind2ptr(torch::Tensor ind, int64_t M) {
  if (ind.device().is_cuda()) {
#ifdef WITH_CUDA
    return ind2ptr_cuda(ind, M);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return ind2ptr_cpu(ind, M);
  }
}

torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E) {
  if (ptr.device().is_cuda()) {
#ifdef WITH_CUDA
    return ptr2ind_cuda(ptr, E);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return ptr2ind_cpu(ptr, E);
  }
}

static auto registry = torch::RegisterOperators()
                           .op("torch_sparse::ind2ptr", &ind2ptr)
                           .op("torch_sparse::ptr2ind", &ptr2ind);
#include "convert_cpu.h"
#include "utils.h"
torch::Tensor ind2ptr_cpu(torch::Tensor ind, int64_t M) {
CHECK_CPU(ind);
auto out = torch::empty(M + 1, ind.options());
auto ind_data = ind.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
int64_t numel = ind.numel(), idx = ind_data[0], next_idx;
for (auto i = 0; i <= idx; i++)
out_data[i] = 0;
for (auto i = 0; i < numel - 1; i++) {
next_idx = ind_data[i + 1];
for (auto j = idx; j < next_idx; j++)
out_data[j + 1] = i + 1;
idx = next_idx;
}
for (auto i = idx + 1; i < M + 1; i++)
out_data[i] = numel;
return out;
}
torch::Tensor ptr2ind_cpu(torch::Tensor ptr, int64_t E) {
CHECK_CPU(ptr);
auto out = torch::empty(E, ptr.options());
auto ptr_data = ptr.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
int64_t idx = ptr_data[0], next_idx;
for (auto i = 0; i < ptr.numel() - 1; i++) {
next_idx = ptr_data[i + 1];
for (auto e = idx; e < next_idx; e++)
out_data[e] = i;
idx = next_idx;
}
return out;
}
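The two CPU loops above amount to a bincount/cumsum and a repeat_interleave, respectively. A hedged pure-PyTorch reference (helper names are hypothetical) that can serve as a sanity check for sorted input:

import torch


def ind2ptr_reference(ind, M):
    # CSR row pointer = a leading zero followed by the cumulative per-row counts.
    counts = torch.bincount(ind, minlength=M)
    return torch.cat([counts.new_zeros(1), counts.cumsum(0)])


def ptr2ind_reference(ptr, E):
    # Repeat each row index by its number of stored entries (E = ptr[-1]).
    counts = ptr[1:] - ptr[:-1]
    return torch.repeat_interleave(torch.arange(ptr.numel() - 1), counts)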
#pragma once
#include <torch/extension.h>
torch::Tensor ind2ptr_cpu(torch::Tensor ind, int64_t M);
torch::Tensor ptr2ind_cpu(torch::Tensor ptr, int64_t E);
#include "diag_cpu.h"
#include "utils.h"
torch::Tensor non_diag_mask_cpu(torch::Tensor row, torch::Tensor col, int64_t M,
int64_t N, int64_t k) {
CHECK_CPU(row);
CHECK_CPU(col);
auto E = row.size(0);
auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>();
int64_t r, c;
if (k < 0) {
for (int64_t i = 0; i < E; i++) {
r = row_data[i], c = col_data[i];
if (r + k < 0) {
mask_data[i] = true;
} else if (r + k >= N) {
mask_data[i + num_diag] = true;
} else if (r + k > c) {
mask_data[i + r + k] = true;
} else if (r + k < c) {
mask_data[i + r + k + 1] = true;
}
}
} else {
for (int64_t i = 0; i < E; i++) {
r = row_data[i], c = col_data[i];
if (r + k >= N) {
mask_data[i + num_diag] = true;
} else if (r + k > c) {
mask_data[i + r] = true;
} else if (r + k < c) {
mask_data[i + r + 1] = true;
}
}
}
return mask;
}
#pragma once
#include <torch/extension.h>
torch::Tensor non_diag_mask_cpu(torch::Tensor row, torch::Tensor col, int64_t M,
                                int64_t N, int64_t k);
#pragma once
#include <limits>
#include <map>
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
    {"div", DIV}, {"min", MIN},   {"max", MAX},
};

#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
  [&] { \
    switch (reduce2REDUCE.at(reduce)) { \
    case SUM: { \
      const ReductionType REDUCE = SUM; \
      return __VA_ARGS__(); \
    } \
    case MEAN: { \
      const ReductionType REDUCE = MEAN; \
      return __VA_ARGS__(); \
    } \
    case MUL: { \
      const ReductionType REDUCE = MUL; \
      return __VA_ARGS__(); \
    } \
    case DIV: { \
      const ReductionType REDUCE = DIV; \
      return __VA_ARGS__(); \
    } \
    case MIN: { \
      const ReductionType REDUCE = MIN; \
      return __VA_ARGS__(); \
    } \
    case MAX: { \
      const ReductionType REDUCE = MAX; \
      return __VA_ARGS__(); \
    } \
    } \
  }()

template <typename scalar_t> struct Reducer {
  static inline scalar_t init(ReductionType REDUCE) {
    if (REDUCE == MUL || REDUCE == DIV)
      return (scalar_t)1;
    else if (REDUCE == MIN)
      return std::numeric_limits<scalar_t>::max();
    else if (REDUCE == MAX)
      return std::numeric_limits<scalar_t>::lowest();
    else
      return (scalar_t)0;
  }

  static inline void update(ReductionType REDUCE, scalar_t *val,
                            scalar_t new_val, int64_t *arg, int64_t new_arg) {
    if (REDUCE == SUM || REDUCE == MEAN)
      *val = *val + new_val;
    else if (REDUCE == MUL)
      *val = *val * new_val;
    else if (REDUCE == DIV)
      *val = *val / new_val;
    else if ((REDUCE == MIN && new_val < *val) ||
             (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline void write(ReductionType REDUCE, scalar_t *address,
                           scalar_t val, int64_t *arg_address, int64_t arg,
                           int count) {
    if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
      *address = val;
    else if (REDUCE == MEAN)
      *address = val / (count > 0 ? count : (scalar_t)1);
    else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else
        *address = (scalar_t)0;
    }
  }
};
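In Python terms, Reducer folds the entries of one CSR row as sketched below (argmin/argmax bookkeeping omitted); this is only a reading aid under those assumptions, not part of the package:

import math


def reduce_row(values, reduce):
    # Mirrors Reducer::init / update / write for a single output element.
    init = {'mul': 1.0, 'div': 1.0, 'min': math.inf, 'max': -math.inf}
    acc = init.get(reduce, 0.0)  # 'sum' and 'mean' start at zero
    for v in values:
        if reduce in ('sum', 'mean'):
            acc += v
        elif reduce == 'mul':
            acc *= v
        elif reduce == 'div':
            acc /= v
        elif (reduce == 'min' and v < acc) or (reduce == 'max' and v > acc):
            acc = v
    if reduce == 'mean':
        acc /= max(len(values), 1)
    if reduce in ('min', 'max') and len(values) == 0:
        acc = 0.0  # empty rows write zero, matching Reducer::write
    return acc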
#include "spmm_cpu.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
if (optional_value.has_value())
CHECK_CPU(optional_value.value());
CHECK_CPU(mat);
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
if (optional_value.has_value()) {
CHECK_INPUT(optional_value.value().dim() == 1);
CHECK_INPUT(optional_value.value().size(0) == col.size(0));
}
CHECK_INPUT(mat.dim() >= 2);
mat = mat.contiguous();
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = torch::empty(sizes, mat.options());
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, col.numel(), rowptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto M = rowptr.numel() - 1;
auto N = mat.size(-2);
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
scalar_t *value_data = nullptr;
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, c;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
AT_DISPATCH_HAS_VALUE(optional_value, [&] {
if (HAS_VALUE) {
value_data = optional_value.value().data_ptr<scalar_t>();
}
for (auto b = 0; b < B; b++) {
for (auto m = 0; m < M; m++) {
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t>::init(REDUCE);
auto offset = b * N * K;
for (auto e = row_start; e < row_end; e++) {
c = col_data[e];
if (HAS_VALUE)
val = value_data[e];
for (auto k = 0; k < K; k++) {
if (HAS_VALUE)
Reducer<scalar_t>::update(REDUCE, &vals[k],
val * mat_data[offset + c * K + k],
&args[k], e);
else
Reducer<scalar_t>::update(REDUCE, &vals[k],
mat_data[offset + c * K + k],
&args[k], e);
}
}
offset = b * M * K + m * K;
for (auto k = 0; k < K; k++)
Reducer<scalar_t>::write(REDUCE, out_data + offset + k, vals[k],
arg_out_data + offset + k, args[k],
row_end - row_start);
}
}
});
});
});
return std::make_tuple(out, arg_out);
}
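For the 2D, single-batch case, spmm_cpu reduces the gathered neighbor rows of mat per CSR row. A hedged dense reference covering the 'sum', 'mean', 'min' and 'max' reductions (illustrative only, no argmin/argmax output) that could be compared against the kernel in tests:

import torch


def spmm_reference(rowptr, col, value, mat, reduce='sum'):
    # Mirrors the CSR loop above for a single batch (B = 1).
    M, K = rowptr.numel() - 1, mat.size(-1)
    out = mat.new_zeros(M, K)
    for m in range(M):
        e0, e1 = rowptr[m].item(), rowptr[m + 1].item()
        if e0 == e1:
            continue  # empty rows stay zero, as in Reducer::write
        rows = mat[col[e0:e1]]
        if value is not None:
            rows = rows * value[e0:e1].unsqueeze(-1)
        if reduce == 'sum':
            out[m] = rows.sum(dim=0)
        elif reduce == 'mean':
            out[m] = rows.mean(dim=0)
        elif reduce == 'min':
            out[m] = rows.min(dim=0).values
        else:  # 'max'
            out[m] = rows.max(dim=0).values
    return out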
torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
                                torch::Tensor col, torch::Tensor mat,
                                torch::Tensor grad, std::string reduce) {
  CHECK_CPU(row);
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(mat);
  CHECK_CPU(grad);

  mat = mat.contiguous();
  grad = grad.contiguous();

  auto M = grad.size(-2);
  auto N = mat.size(-2);
  auto E = row.numel();
  auto K = mat.size(-1);
  auto B = mat.numel() / (N * K);

  auto out = torch::zeros(row.numel(), grad.options());

  auto row_data = row.data_ptr<int64_t>();
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();

  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_value_bw", [&] {
    auto mat_data = mat.data_ptr<scalar_t>();
    auto grad_data = grad.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    scalar_t val;
    int64_t row, col;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (int b = 0; b < B; b++) {
        for (int e = 0; e < E; e++) {
          row = row_data[e], col = col_data[e], val = (scalar_t)0;
          for (int k = 0; k < K; k++) {
            val += mat_data[b * N * K + col * K + k] *
                   grad_data[b * M * K + row * K + k];
          }
          if (REDUCE == MEAN) {
            int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
            val /= (scalar_t)std::max(row_end - row_start, 1);
          }
          out_data[e] += val;
        }
      }
    });
  });

  return out;
}
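The value backward above computes, for each nonzero e, the inner product of mat[col[e]] with grad[row[e]], scaled by the inverse row length for the 'mean' reduction. A hedged vectorized reference for the 2D case (the helper name is illustrative):

import torch


def spmm_value_bw_reference(row, rowptr, col, mat, grad, reduce='sum'):
    # grad_value[e] = <mat[col[e]], grad[row[e]]>, divided by deg(row[e]) for 'mean'.
    out = (mat[col] * grad[row]).sum(dim=-1)
    if reduce == 'mean':
        deg = (rowptr[1:] - rowptr[:-1]).clamp(min=1)
        out = out / deg[row].to(out.dtype)
    return out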
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
         torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
         std::string reduce);

torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
                                torch::Tensor col, torch::Tensor mat,
                                torch::Tensor grad, std::string reduce);
#include "spspmm_cpu.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce) {
CHECK_CPU(rowptrA);
CHECK_CPU(colA);
if (optional_valueA.has_value())
CHECK_CPU(optional_valueA.value());
CHECK_CPU(rowptrB);
CHECK_CPU(colB);
if (optional_valueB.has_value())
CHECK_CPU(optional_valueB.value());
CHECK_INPUT(rowptrA.dim() == 1);
CHECK_INPUT(colA.dim() == 1);
if (optional_valueA.has_value()) {
CHECK_INPUT(optional_valueA.value().dim() == 1);
CHECK_INPUT(optional_valueA.value().size(0) == colA.size(0));
}
CHECK_INPUT(rowptrB.dim() == 1);
CHECK_INPUT(colB.dim() == 1);
if (optional_valueB.has_value()) {
CHECK_INPUT(optional_valueB.value().dim() == 1);
CHECK_INPUT(optional_valueB.value().size(0) == colB.size(0));
}
if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options());
if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options());
auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value())
scalar_type = optional_valueA.value().scalar_type();
auto rowptrA_data = rowptrA.data_ptr<int64_t>();
auto colA_data = colA.data_ptr<int64_t>();
auto rowptrB_data = rowptrB.data_ptr<int64_t>();
auto colB_data = colB.data_ptr<int64_t>();
// Pass 1: Compute CSR row pointer.
auto rowptrC = torch::empty_like(rowptrA);
auto rowptrC_data = rowptrC.data_ptr<int64_t>();
rowptrC_data[0] = 0;
std::vector<int64_t> mask(K, -1);
int64_t nnz = 0, row_nnz, rowA_start, rowA_end, rowB_start, rowB_end, cA, cB;
for (auto n = 0; n < rowptrA.numel() - 1; n++) {
row_nnz = 0;
for (auto eA = rowptrA_data[n]; eA < rowptrA_data[n + 1]; eA++) {
cA = colA_data[eA];
for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1]; eB++) {
cB = colB_data[eB];
if (mask[cB] != n) {
mask[cB] = n;
row_nnz++;
}
}
}
nnz += row_nnz;
rowptrC_data[n + 1] = nnz;
}
// Pass 2: Compute CSR entries.
auto colC = torch::empty(nnz, rowptrC.options());
auto colC_data = colC.data_ptr<int64_t>();
torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
if (optional_valueA.has_value())
optional_valueC = torch::empty(nnz, optional_valueA.value().options());
AT_DISPATCH_ALL_TYPES(scalar_type, "spspmm", [&] {
AT_DISPATCH_HAS_VALUE(optional_valueC, [&] {
scalar_t *valA_data = nullptr, *valB_data = nullptr, *valC_data = nullptr;
if (HAS_VALUE) {
valA_data = optional_valueA.value().data_ptr<scalar_t>();
valB_data = optional_valueB.value().data_ptr<scalar_t>();
valC_data = optional_valueC.value().data_ptr<scalar_t>();
}
scalar_t valA;
rowA_start = 0, nnz = 0;
std::vector<scalar_t> vals(K, 0);
for (auto n = 1; n < rowptrA.numel(); n++) {
rowA_end = rowptrA_data[n];
for (auto eA = rowA_start; eA < rowA_end; eA++) {
cA = colA_data[eA];
if (HAS_VALUE)
valA = valA_data[eA];
rowB_start = rowptrB_data[cA], rowB_end = rowptrB_data[cA + 1];
for (auto eB = rowB_start; eB < rowB_end; eB++) {
cB = colB_data[eB];
if (HAS_VALUE)
vals[cB] += valA * valB_data[eB];
else
vals[cB] += 1;
}
}
for (auto k = 0; k < K; k++) {
if (vals[k] != 0) {
colC_data[nnz] = k;
if (HAS_VALUE)
valC_data[nnz] = vals[k];
nnz++;
}
vals[k] = (scalar_t)0;
}
rowA_start = rowA_end;
}
});
});
return std::make_tuple(rowptrC, colC, optional_valueC);
}
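spspmm_cpu is a two-pass Gustavson-style CSR product: the first pass counts the nonzeros per output row, the second fills columns and values through a dense accumulator of width K. A hedged SciPy cross-check; the helper name and the (M, N, K) shape arguments are illustrative:

import scipy.sparse as sp
import torch


def spspmm_reference(rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
    # A is M x N, B is N x K; K matches the `K` argument of spspmm_cpu.
    A = sp.csr_matrix((valueA.numpy(), colA.numpy(), rowptrA.numpy()), shape=(M, N))
    B = sp.csr_matrix((valueB.numpy(), colB.numpy(), rowptrB.numpy()), shape=(N, K))
    C = A @ B
    C.sort_indices()  # match the ascending column order produced above
    return (torch.from_numpy(C.indptr).long(),
            torch.from_numpy(C.indices).long(),
            torch.from_numpy(C.data))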
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
           torch::optional<torch::Tensor> optional_valueA,
           torch::Tensor rowptrB, torch::Tensor colB,
           torch::optional<torch::Tensor> optional_valueB, int64_t K,
           std::string reduce);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define AT_DISPATCH_HAS_VALUE(optional_value, ...) \
  [&] { \
    if (optional_value.has_value()) { \
      const bool HAS_VALUE = true; \
      return __VA_ARGS__(); \
    } else { \
      const bool HAS_VALUE = false; \
      return __VA_ARGS__(); \
    } \
  }()
#include "convert_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 256
__global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
int64_t M, int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx == 0) {
for (int64_t i = 0; i <= ind_data[0]; i++)
out_data[i] = 0;
} else if (thread_idx < numel) {
for (int64_t i = ind_data[thread_idx - 1]; i < ind_data[thread_idx]; i++)
out_data[i + 1] = thread_idx;
} else if (thread_idx == numel) {
for (int64_t i = ind_data[numel - 1] + 1; i < M + 1; i++)
out_data[i] = numel;
}
}
torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
CHECK_CUDA(ind);
cudaSetDevice(ind.get_device());
auto out = torch::empty(M + 1, ind.options());
auto ind_data = ind.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
ind2ptr_kernel<<<(ind.numel() + 2 + THREADS - 1) / THREADS, THREADS, 0,
stream>>>(ind_data, out_data, M, ind.numel());
return out;
}
__global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
int64_t E, int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx < numel) {
int64_t idx = ptr_data[thread_idx], next_idx = ptr_data[thread_idx + 1];
for (int64_t i = idx; i < next_idx; i++) {
out_data[i] = thread_idx;
}
}
}
torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
CHECK_CUDA(ptr);
cudaSetDevice(ptr.get_device());
auto out = torch::empty(E, ptr.options());
auto ptr_data = ptr.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
ptr2ind_kernel<<<(ptr.numel() - 1 + THREADS - 1) / THREADS, THREADS, 0,
stream>>>(ptr_data, out_data, E, ptr.numel() - 1);
return out;
}
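The CUDA kernels assign one source element per thread and otherwise mirror the CPU conversion. A hedged consistency check against the CPU path, assuming the operators are registered and the extension is loaded as sketched earlier:

import torch

if torch.cuda.is_available():
    row = torch.randint(0, 1000, (10000,), dtype=torch.long).sort().values
    ptr_cpu = torch.ops.torch_sparse.ind2ptr(row, 1000)
    ptr_cuda = torch.ops.torch_sparse.ind2ptr(row.cuda(), 1000)
    assert torch.equal(ptr_cpu, ptr_cuda.cpu())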
#pragma once
#include <torch/extension.h>
torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M);
torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E);