Commit 1be2aaf7 authored by rusty1s

cpu fixes

parent 516527f1
@@ -40,8 +40,8 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
     } \
   }()

-template <typename scalar_t, ReductionType REDUCE> struct Reducer {
-  static inline scalar_t init() {
+template <typename scalar_t> struct Reducer {
+  static inline scalar_t init(ReductionType REDUCE) {
     if (REDUCE == MUL || REDUCE == DIV)
       return (scalar_t)1;
     else if (REDUCE == MIN)
@@ -52,8 +52,8 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
       return (scalar_t)0;
   }

-  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
-                            int64_t new_arg) {
+  static inline void update(ReductionType REDUCE, scalar_t *val,
+                            scalar_t new_val, int64_t *arg, int64_t new_arg) {
     if (REDUCE == SUM || REDUCE == MEAN)
       *val = *val + new_val;
     else if (REDUCE == MUL)
@@ -67,8 +67,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     }
   }

-  static inline void write(scalar_t *address, scalar_t val,
-                           int64_t *arg_address, int64_t arg, int count) {
+  static inline void write(ReductionType REDUCE, scalar_t *address,
+                           scalar_t val, int64_t *arg_address, int64_t arg,
+                           int count) {
     if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
       *address = val;
     else if (REDUCE == MEAN)
...
@@ -63,7 +63,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
         row_start = rowptr_data[m], row_end = rowptr_data[m + 1];

         for (auto k = 0; k < K; k++)
-          vals[k] = Reducer<scalar_t, REDUCE>::init();
+          vals[k] = Reducer<scalar_t>::init(REDUCE);

         auto offset = b * N * K;
         for (auto e = row_start; e < row_end; e++) {
@@ -72,19 +72,20 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
             val = value_data[e];
           for (auto k = 0; k < K; k++) {
             if (HAS_VALUE)
-              Reducer<scalar_t, REDUCE>::update(
-                  &vals[k], val * mat_data[offset + c * K + k], &args[k],
-                  e);
+              Reducer<scalar_t>::update(REDUCE, &vals[k],
+                                        val * mat_data[offset + c * K + k],
+                                        &args[k], e);
             else
-              Reducer<scalar_t, REDUCE>::update(
-                  &vals[k], mat_data[offset + c * K + k], &args[k], e);
+              Reducer<scalar_t>::update(REDUCE, &vals[k],
+                                        mat_data[offset + c * K + k],
+                                        &args[k], e);
           }
         }
         offset = b * M * K + m * K;
         for (auto k = 0; k < K; k++)
-          Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
-                                           arg_out_data + offset + k,
-                                           args[k], row_end - row_start);
+          Reducer<scalar_t>::write(REDUCE, out_data + offset + k, vals[k],
+                                   arg_out_data + offset + k, args[k],
+                                   row_end - row_start);
       }
     }
   });
...
@@ -166,7 +166,7 @@ class SparseTensor(object):
     def is_coalesced(self) -> bool:
         return self.storage.is_coalesced()

-    def coalesce(self, reduce: str = "add"):
+    def coalesce(self, reduce: str = "sum"):
         return self.from_storage(self.storage.coalesce(reduce))

     def fill_cache_(self):
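Not part of the diff: a minimal Python sketch of what the new coalesce default means in practice. The tensor values below are invented for illustration; the constructor keywords mirror the ones used by to_symmetric further down in this commit.

import torch
from torch_sparse import SparseTensor

# duplicate coordinate (0, 1) appears twice, with values 1.0 and 2.0
row = torch.tensor([0, 0, 1])
col = torch.tensor([1, 1, 0])
value = torch.tensor([1.0, 2.0, 3.0])
mat = SparseTensor(row=row, col=col, value=value,
                   sparse_sizes=torch.Size([2, 2]))

# duplicates are now merged with "sum" by default, so (0, 1) ends up as 3.0
mat = mat.coalesce()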
@@ -252,6 +252,20 @@ class SparseTensor(object):
         else:
             return bool((value1 == value2).all())

+    def to_symmetric(self, reduce: str = "sum"):
+        row, col, value = self.coo()
+
+        row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
+        if value is not None:
+            value = torch.cat([value, value], dim=0)
+
+        N = max(self.size(0), self.size(1))
+        out = SparseTensor(row=row, rowptr=None, col=col, value=value,
+                           sparse_sizes=torch.Size([N, N]), is_sorted=False)
+        out = out.coalesce(reduce)
+        return out
+
     def detach_(self):
         value = self.storage.value()
         if value is not None:
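Not part of the diff: a short usage sketch for the new to_symmetric method. As the added code shows, it mirrors every entry across the diagonal, pads the result to a square N x N shape with N = max(num_rows, num_cols), and merges duplicates via coalesce with the given reduction. The sizes and values below are made up for illustration.

import torch
from torch_sparse import SparseTensor

# an asymmetric, rectangular 2 x 3 pattern with entries (0, 1) and (1, 2)
row = torch.tensor([0, 1])
col = torch.tensor([1, 2])
value = torch.tensor([1.0, 2.0])
mat = SparseTensor(row=row, col=col, value=value,
                   sparse_sizes=torch.Size([2, 3]))

# the result is 3 x 3 and also contains the mirrored entries (1, 0) and (2, 1)
sym = mat.to_symmetric(reduce="sum")
assert sym.size(0) == 3 and sym.size(1) == 3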
@@ -496,7 +510,7 @@ ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.
 @torch.jit.ignore
-def from_scipy(mat: ScipySparseMatrix) -> SparseTensor:
+def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
     colptr = None
     if isinstance(mat, scipy.sparse.csc_matrix):
         colptr = torch.from_numpy(mat.indptr).to(torch.long)
@@ -506,7 +520,9 @@ def from_scipy(mat: ScipySparseMatrix) -> SparseTensor:
     mat = mat.tocoo()
     row = torch.from_numpy(mat.row).to(torch.long)
     col = torch.from_numpy(mat.col).to(torch.long)
-    value = torch.from_numpy(mat.data)
+    value = None
+    if has_value:
+        value = torch.from_numpy(mat.data)
     sparse_sizes = mat.shape[:2]

     storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
...
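Not part of the diff: a sketch of the extended from_scipy signature. With has_value=False only the sparsity pattern is imported and value stays None. The import path below is an assumption about where this helper lives in the packaged module; the scipy matrix is made up for illustration.

import scipy.sparse

from torch_sparse.tensor import from_scipy  # assumed module path for the helper above

# a 3 x 3 identity matrix in CSR format, just for illustration
mat = scipy.sparse.eye(3, format="csr")

sp_full = from_scipy(mat)                      # default: values copied from mat.data
sp_pattern = from_scipy(mat, has_value=False)  # pattern only, value is None
assert sp_pattern.storage.value() is None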