"src/vscode:/vscode.git/clone" did not exist on "02ba50c6104d40b745163fd14e84214b3db90112"
Commit 96b0abcb authored by rusty1s

setdiag implementation

parent ae35b8a5
#include <torch/extension.h>

#include <algorithm>

#include "compat.h"

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
  CHECK_CPU(index);

  int64_t E = index.size(1);
  index = index.contiguous();
  auto index_data = index.DATA_PTR<int64_t>();

  // Number of entries on the k-th diagonal of an (M, N) matrix.
  int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);

  // One slot per existing entry plus one per diagonal entry. Slots left
  // `false` mark where new diagonal entries are to be inserted.
  auto mask = at::zeros(E + num_diag, index.options().dtype(at::kBool));
  auto mask_data = mask.DATA_PTR<bool>();

  // The offset added to `i` counts the diagonal entries that precede entry
  // (r, c) in row-major order: one per earlier row that carries a diagonal
  // entry, plus one more if (r, c) lies to the right of its own row's
  // diagonal column.
  int64_t r, c;
  if (k < 0) {
    for (int64_t i = 0; i < E; i++) {
      r = index_data[i], c = index_data[i + E];
      if (r + k < 0) {
        mask_data[i] = true;
      } else if (r + k >= N) {
        mask_data[i + num_diag] = true;
      } else if (r + k > c) {
        mask_data[i + r + k] = true;
      } else if (r + k < c) {
        mask_data[i + r + k + 1] = true;
      }
    }
  } else {
    for (int64_t i = 0; i < E; i++) {
      r = index_data[i], c = index_data[i + E];
      if (r + k >= N) {
        mask_data[i + num_diag] = true;
      } else if (r + k > c) {
        mask_data[i + r] = true;
      } else if (r + k < c) {
        mask_data[i + r + 1] = true;
      }
    }
  }

  return mask;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("non_diag_mask", &non_diag_mask, "Non-Diagonal Mask (CPU)");
}
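Since this mask layout is the heart of the setdiag scheme, a minimal pure-Python restatement of the CPU loop above may help; it is for illustration only (the name `non_diag_mask_ref` is not part of the library, and it assumes a row-major sorted COO `index` with no entry on the k-th diagonal, which `set_diag` below guarantees by calling `remove_diag` first):

import torch

def non_diag_mask_ref(index, M, N, k):
    # Mirrors the CPU kernel: mark the output slot of every existing entry;
    # the slots left False are where the new diagonal entries belong.
    E = index.size(1)
    num_diag = min(M + k, N) if k < 0 else min(M, N - k)
    mask = torch.zeros(E + num_diag, dtype=torch.bool)
    for i in range(E):
        r, c = int(index[0, i]), int(index[1, i])
        if k < 0:
            if r + k < 0:
                mask[i] = True
            elif r + k >= N:
                mask[i + num_diag] = True
            elif r + k > c:
                mask[i + r + k] = True
            elif r + k < c:
                mask[i + r + k + 1] = True
        else:
            if r + k >= N:
                mask[i + num_diag] = True
            elif r + k > c:
                mask[i + r] = True
            elif r + k < c:
                mask[i + r + 1] = True
    return mask

# Entries (0, 1) and (1, 0) of a 2 x 2 matrix, main diagonal (k = 0):
index = torch.tensor([[0, 1], [1, 0]])
print(non_diag_mask_ref(index, 2, 2, 0).tolist())
# -> [False, True, True, False]: the False slots (positions 0 and 3) are
#    exactly where the diagonal entries (0, 0) and (1, 1) belong in
#    row-major order.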
@@ -4,11 +4,11 @@
 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

-at::Tensor rowptr(at::Tensor row, int64_t size) {
+at::Tensor rowptr(at::Tensor row, int64_t M) {
   CHECK_CPU(row);
   AT_ASSERTM(row.dim() == 1, "Row needs to be one-dimensional");

-  auto out = at::empty(size + 1, row.options());
+  auto out = at::empty(M + 1, row.options());
   auto row_data = row.DATA_PTR<int64_t>();
   auto out_data = out.DATA_PTR<int64_t>();
@@ -16,14 +16,14 @@ at::Tensor rowptr(at::Tensor row, int64_t size) {
   for (int64_t i = 0; i <= idx; i++)
     out_data[i] = 0;

-  for (int64_t i = 0; i < row.size(0) - 1; i++) {
+  for (int64_t i = 0; i < numel - 1; i++) {
     next_idx = row_data[i + 1];
     for (int64_t j = idx; j < next_idx; j++)
       out_data[j + 1] = i + 1;
     idx = next_idx;
   }

-  for (int64_t i = idx + 1; i < size + 1; i++)
+  for (int64_t i = idx + 1; i < M + 1; i++)
     out_data[i] = numel;

   return out;
...
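For orientation, the `rowptr` conversion above (sorted COO row vector to CSR row pointers) amounts to the following standalone sketch; the helper name `rowptr_ref` is mine, not the library's:

import torch

def rowptr_ref(row, M):
    # out[r] is the position of the first entry of row r; out[M] == numel.
    numel = row.numel()
    out = torch.empty(M + 1, dtype=torch.long)
    idx = int(row[0])
    out[:idx + 1] = 0
    for i in range(numel - 1):
        next_idx = int(row[i + 1])
        out[idx + 1:next_idx + 1] = i + 1
        idx = next_idx
    out[idx + 1:] = numel
    return out

print(rowptr_ref(torch.tensor([0, 0, 1, 3]), 4).tolist())
# -> [0, 2, 3, 3, 4]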
#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

at::Tensor non_diag_mask_cuda(at::Tensor index, int64_t M, int64_t N,
                              int64_t k);

at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
  CHECK_CUDA(index);
  return non_diag_mask_cuda(index, M, N, k);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("non_diag_mask", &non_diag_mask, "Non-Diagonal Mask (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <algorithm>

#include "compat.cuh"

#define THREADS 1024

// One thread per existing entry; mirrors the CPU implementation above.
__global__ void non_diag_mask_kernel(const int64_t *index_data, bool *out_data,
                                     int64_t N, int64_t k, int64_t num_diag,
                                     int64_t numel) {
  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (thread_idx < numel) {
    int64_t r = index_data[thread_idx], c = index_data[thread_idx + numel];
    if (k < 0) {
      if (r + k < 0) {
        out_data[thread_idx] = true;
      } else if (r + k >= N) {
        out_data[thread_idx + num_diag] = true;
      } else if (r + k > c) {
        out_data[thread_idx + r + k] = true;
      } else if (r + k < c) {
        out_data[thread_idx + r + k + 1] = true;
      }
    } else {
      if (r + k >= N) {
        out_data[thread_idx + num_diag] = true;
      } else if (r + k > c) {
        out_data[thread_idx + r] = true;
      } else if (r + k < c) {
        out_data[thread_idx + r + 1] = true;
      }
    }
  }
}

at::Tensor non_diag_mask_cuda(at::Tensor index, int64_t M, int64_t N,
                              int64_t k) {
  int64_t E = index.size(1);
  index = index.contiguous();
  auto index_data = index.DATA_PTR<int64_t>();

  int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
  auto mask = at::zeros(E + num_diag, index.options().dtype(at::kBool));
  auto mask_data = mask.DATA_PTR<bool>();

  auto stream = at::cuda::getCurrentCUDAStream();
  non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
      index_data, mask_data, N, k, num_diag, E);

  return mask;
}
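Assuming both extensions are built (the module names `diag_cpu` and `diag_cuda` are the ones imported in the Python code below), the CPU and CUDA paths should agree exactly; a quick hedged parity check:

import torch
from torch_sparse import diag_cpu, diag_cuda

index = torch.tensor([[0, 0, 1], [1, 2, 0]])  # no diagonal entries
cpu_mask = diag_cpu.non_diag_mask(index, 3, 3, 0)
cuda_mask = diag_cuda.non_diag_mask(index.cuda(), 3, 3, 0).cpu()
assert torch.equal(cpu_mask, cuda_mask)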
@@ -2,11 +2,11 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

-at::Tensor rowptr_cuda(at::Tensor row, int64_t size);
+at::Tensor rowptr_cuda(at::Tensor row, int64_t M);

-at::Tensor rowptr(at::Tensor row, int64_t size) {
+at::Tensor rowptr(at::Tensor row, int64_t M) {
   CHECK_CUDA(row);
-  return rowptr_cuda(row, size);
+  return rowptr_cuda(row, M);
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -3,10 +3,10 @@
 #include "compat.cuh"

-#define THREADS 256
+#define THREADS 1024

 __global__ void rowptr_kernel(const int64_t *row_data, int64_t *out_data,
-                              int64_t numel, int64_t size) {
+                              int64_t M, int64_t numel) {
   int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -17,21 +17,21 @@ __global__ void rowptr_kernel(const int64_t *row_data, int64_t *out_data,
     for (int64_t i = row_data[thread_idx - 1]; i < row_data[thread_idx]; i++)
       out_data[i + 1] = thread_idx;
   } else if (thread_idx == numel) {
-    for (int64_t i = row_data[numel - 1] + 1; i < size + 1; i++)
+    for (int64_t i = row_data[numel - 1] + 1; i < M + 1; i++)
       out_data[i] = numel;
   }
 }

-at::Tensor rowptr_cuda(at::Tensor row, int64_t size) {
+at::Tensor rowptr_cuda(at::Tensor row, int64_t M) {
   AT_ASSERTM(row.dim() == 1, "Row needs to be one-dimensional");

-  auto out = at::empty(size + 1, row.options());
+  auto out = at::empty(M + 1, row.options());
   auto row_data = row.DATA_PTR<int64_t>();
   auto out_data = out.DATA_PTR<int64_t>();

   auto stream = at::cuda::getCurrentCUDAStream();
   rowptr_kernel<<<(row.numel() + 2 + THREADS - 1) / THREADS, THREADS, 0,
-                  stream>>>(row_data, out_data, row.numel(), size);
+                  stream>>>(row_data, out_data, M, row.numel());
   return out;
 }
@@ -4,12 +4,20 @@ import pytest
 import torch

 from torch_sparse.tensor import SparseTensor
+from torch_sparse.diag_cpu import non_diag_mask

 from .utils import dtypes, devices, tensor

+dtypes = [torch.float]
+devices = ['cpu']
+

 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_remove_diag(dtype, device):
-    index = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
+    index = tensor([
+        [0, 0, 1, 2],
+        [0, 1, 2, 2],
+    ], torch.long, device)
     value = tensor([1, 2, 3, 4], dtype, device)
     mat = SparseTensor(index, value)
     mat.fill_cache_()
@@ -30,3 +38,52 @@ def test_remove_diag(dtype, device):
     assert len(mat.cached_keys()) == 2
     assert mat.storage.rowcount.tolist() == [1, 0, 1]
     assert mat.storage.colcount.tolist() == [1, 0, 1]
+@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+def test_set_diag(dtype, device):
+    index = tensor([
+        [0, 0, 9, 9],
+        [0, 1, 0, 1],
+    ], torch.long, device)
+    value = tensor([1, 2, 3, 4], dtype, device)
+    mat = SparseTensor(index, value)
+
+    print()
+    k = -8
+    print("k = ", k)
+    mat = mat.remove_diag(k)
+    print(mat.to_dense())
+
+    # row, col = mat.storage.index
+    # print('k', k)
+    # mask = row != col - k
+    # index = index[:, mask]
+    # row, col = index
+    # print(row)
+    # print(col)
+
+    mask = non_diag_mask(mat.storage.index, mat.size(0), mat.size(1), k)
+    print(mask)
+
+    # bla = col - row
+    # print(bla)
+
+    # DETECT SIGN CHANGE
+    # mask = row.new_ones(index.size(1) + 3, dtype=torch.bool)
+    # mask[1:] = row[1:] != row[:-1]
+    # # mask = row[1:] != row[:-1]
+    # print(mask)
+
+    # mask = (row <= col)
+    # print(row)
+    # print(col)
+    # print(mask)
+
+    # mask = (row[1:] == row[:-1])
+    # print(mask)
+
+    # UNION
+    # idx1 = ...
+    # idx2 = ...
@@ -54,13 +54,22 @@ def add(src, other):
                      '`torch.tensor` or `torch_sparse.SparseTensor`.')


 def add_(src, other):
     pass


 def add_nnz(src, other, layout=None):
     if isinstance(other, int) or isinstance(other, float):
-        return src.set_value(src.storage.value + other if src.has_value(
-        ) else torch.full((src.nnz(), ), 1 + other, device=src.device))
+        return src.set_value(src.storage.value +
+                             other if src.has_value() else torch.full((
+                                 src.nnz(), ), 1 + other, device=src.device))
     elif torch.is_tensor(other):
         return src.set_value(src.storage.value +
                              other if src.has_value() else other + 1)
     raise ValueError('Argument `other` needs to be of type `int`, `float` or '
                      '`torch.tensor`.')


 def add_nnz_(src, other, layout=None):
     pass
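A brief hedged usage note for `add_nnz`: it adds `other` to every stored (nnz) value, and a value-less tensor is treated as all-ones, which is why the fallback builds `torch.full(..., 1 + other)`. The import path `torch_sparse.add` is an assumption; the diff does not name the file:

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import add_nnz  # assumed module path

index = torch.tensor([[0, 1], [1, 0]])
mat = SparseTensor(index, torch.tensor([10., 20.]))
print(add_nnz(mat, 5).storage.value)       # -> tensor([15., 25.])

novalue = SparseTensor(index)              # implicit all-ones values
print(add_nnz(novalue, 5).storage.value)   # -> tensor([6., 6.])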
 import torch

+from torch_sparse import diag_cpu
+
+try:
+    from torch_sparse import diag_cuda
+except ImportError:
+    diag_cuda = None


 def remove_diag(src, k=0):
     index, value = src.coo()
     row, col = index
@@ -17,15 +27,11 @@ def remove_diag(src, k=0):
         rowcount = src.storage.rowcount.clone()
         rowcount[row[mask]] -= 1
         # TODO: Maintain `rowptr`.

     colcount = None
     if src.storage.has_colcount():
         colcount = src.storage.colcount.clone()
         colcount[col[mask]] -= 1
         # TODO: Maintain `colptr`.

     storage = src.storage.__class__(index, value,
                                     sparse_size=src.sparse_size(),
                                     rowcount=rowcount, colcount=colcount,
@@ -34,4 +40,44 @@ def remove_diag(src, k=0):

 def set_diag(src, value=None, k=0):
-    pass
+    # Clear the k-th diagonal first so that `non_diag_mask` may assume no
+    # entry of `index` lies on it.
+    src = src.remove_diag(k=k)
+    index, src_value = src.coo()
+
+    func = diag_cuda if index.is_cuda else diag_cpu
+    mask = func.non_diag_mask(index, src.size(0), src.size(1), k)
+    inv_mask = ~mask
+
+    # Scatter the existing entries into their new positions; the slots left
+    # free (`inv_mask`) receive the diagonal entries.
+    new_index = index.new_empty((2, mask.size(0)))
+    new_index[:, mask] = index
+
+    num_diag = mask.numel() - index.size(1)
+    start = -k if k < 0 else 0
+    diag_row = torch.arange(start, start + num_diag, device=src.device)
+    new_index[0, inv_mask] = diag_row
+    # In-place add is safe here: the row indices were already copied above.
+    diag_col = diag_row.add_(k)
+    new_index[1, inv_mask] = diag_col
+
+    new_value = None
+    if src.has_value():
+        new_value = src_value.new_empty((mask.size(0), ) +
+                                        src_value.size()[1:])
+        new_value[mask] = src_value
+        # Fill the diagonal with the requested value (defaults to one).
+        new_value[inv_mask] = 1 if value is None else value
+
+    rowcount = None
+    if src.storage.has_rowcount():
+        rowcount = src.storage.rowcount.clone()
+        rowcount[start:start + num_diag] += 1
+
+    colcount = None
+    if src.storage.has_colcount():
+        colcount = src.storage.colcount.clone()
+        colcount[start + k:start + num_diag + k] += 1
+
+    storage = src.storage.__class__(new_index, new_value,
+                                    sparse_size=src.sparse_size(),
+                                    rowcount=rowcount, colcount=colcount,
+                                    is_sorted=True)
+    return src.__class__.from_storage(storage)
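An end-to-end sketch of what `set_diag` is meant to produce, in the style of the tests above (hedged: it assumes the extensions are built and that `SparseTensor` exposes `set_diag` as a method, as it already does for `remove_diag`):

import torch
from torch_sparse.tensor import SparseTensor

index = torch.tensor([[0, 1], [1, 0]])  # off-diagonal entries only
value = torch.tensor([2., 3.])
mat = SparseTensor(index, value)

out = mat.set_diag()  # insert ones on the main diagonal (k = 0)
print(out.to_dense())
# tensor([[1., 2.],
#         [3., 1.]])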
import torch


def matmul(src, other, reduce='add'):
    if torch.is_tensor(other):
        pass

    if isinstance(other, src.__class__):
        if reduce != 'add':
            raise NotImplementedError(
                f'Reduce argument "{reduce}" not implemented for sparse-'
                f'sparse matrix multiplication')
@@ -5,8 +5,10 @@ from torch_scatter import segment_csr, scatter_add
 from torch_sparse import rowptr_cpu

-if torch.cuda.is_available():
+try:
     from torch_sparse import rowptr_cuda
+except ImportError:
+    rowptr_cuda = None

 __cache__ = {'enabled': True}
@@ -190,10 +192,48 @@ class SparseStorage(object):
     def sparse_size(self, dim=None):
         return self._sparse_size if dim is None else self._sparse_size[dim]

-    def sparse_resize_(self, *sizes):
+    def sparse_resize(self, *sizes):
         assert len(sizes) == 2
-        self._sparse_size = sizes
-        return self
+        old_sizes, nnz = self.sparse_size(), self.nnz()
+
+        diff_0 = sizes[0] - old_sizes[0]
+        rowcount, rowptr = self._rowcount, self._rowptr
+        if diff_0 > 0:
+            # New trailing rows are empty: zero counts, pointers stay at nnz.
+            if self.has_rowcount():
+                rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
+            if self.has_rowptr():
+                rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
+        else:
+            if self.has_rowcount():
+                rowcount = rowcount[:sizes[0]]
+            if self.has_rowptr():
+                rowptr = rowptr[:sizes[0] + 1]
+
+        diff_1 = sizes[1] - old_sizes[1]
+        colcount, colptr = self._colcount, self._colptr
+        if diff_1 > 0:
+            if self.has_colcount():
+                colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
+            if self.has_colptr():
+                colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
+        else:
+            if self.has_colcount():
+                colcount = colcount[:sizes[1]]
+            if self.has_colptr():
+                colptr = colptr[:sizes[1] + 1]
+
+        return self.__class__(
+            self._index,
+            self._value,
+            sizes,
+            rowcount=rowcount,
+            rowptr=rowptr,
+            colcount=colcount,
+            colptr=colptr,
+            csr2csc=self._csr2csc,
+            csc2csr=self._csc2csr,
+            is_sorted=True,
+        )

     def has_rowcount(self):
         return self._rowcount is not None
...
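A short illustration of the new non-mutating `sparse_resize` (hedged sketch; it assumes `SparseTensor` infers the sparse size from `index` and supports the `fill_cache_` helper used in the tests):

import torch
from torch_sparse.tensor import SparseTensor

index = torch.tensor([[0, 1], [0, 1]])
value = torch.tensor([1., 2.])
mat = SparseTensor(index, value)      # inferred sparse size: 2 x 2
mat.fill_cache_()                     # materialize rowcount, rowptr, ...

big = mat.sparse_resize(4, 3)
print(big.sparse_size())              # -> (4, 3)
print(big.storage.rowcount.tolist())  # -> [1, 1, 0, 0]  (new rows are empty)
print(big.storage.rowptr.tolist())    # -> [0, 1, 2, 2, 2]  (padded with nnz)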
@@ -100,9 +100,8 @@ class SparseTensor(object):
     def sparse_size(self, dim=None):
         return self.storage.sparse_size(dim)

-    def sparse_resize_(self, *sizes):
-        self.storage.sparse_resize_(*sizes)
-        return self
+    def sparse_resize(self, *sizes):
+        return self.from_storage(self.storage.sparse_resize(*sizes))

     def is_coalesced(self):
         return self.storage.is_coalesced()
...