Commit 1be2aaf7 authored by rusty1s's avatar rusty1s
Browse files

cpu fixes

parent 516527f1
...@@ -40,8 +40,8 @@ const std::map<std::string, ReductionType> reduce2REDUCE = { ...@@ -40,8 +40,8 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
} \ } \
}() }()
template <typename scalar_t, ReductionType REDUCE> struct Reducer { template <typename scalar_t> struct Reducer {
static inline scalar_t init() { static inline scalar_t init(ReductionType REDUCE) {
if (REDUCE == MUL || REDUCE == DIV) if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1; return (scalar_t)1;
else if (REDUCE == MIN) else if (REDUCE == MIN)
...@@ -52,8 +52,8 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -52,8 +52,8 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
return (scalar_t)0; return (scalar_t)0;
} }
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg, static inline void update(ReductionType REDUCE, scalar_t *val,
int64_t new_arg) { scalar_t new_val, int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN) if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val; *val = *val + new_val;
else if (REDUCE == MUL) else if (REDUCE == MUL)
...@@ -67,8 +67,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -67,8 +67,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
} }
} }
static inline void write(scalar_t *address, scalar_t val, static inline void write(ReductionType REDUCE, scalar_t *address,
int64_t *arg_address, int64_t arg, int count) { scalar_t val, int64_t *arg_address, int64_t arg,
int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV) if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val; *address = val;
else if (REDUCE == MEAN) else if (REDUCE == MEAN)
......
...@@ -63,7 +63,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -63,7 +63,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
row_start = rowptr_data[m], row_end = rowptr_data[m + 1]; row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (auto k = 0; k < K; k++) for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init(); vals[k] = Reducer<scalar_t>::init(REDUCE);
auto offset = b * N * K; auto offset = b * N * K;
for (auto e = row_start; e < row_end; e++) { for (auto e = row_start; e < row_end; e++) {
...@@ -72,19 +72,20 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -72,19 +72,20 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
val = value_data[e]; val = value_data[e];
for (auto k = 0; k < K; k++) { for (auto k = 0; k < K; k++) {
if (HAS_VALUE) if (HAS_VALUE)
Reducer<scalar_t, REDUCE>::update( Reducer<scalar_t>::update(REDUCE, &vals[k],
&vals[k], val * mat_data[offset + c * K + k], &args[k], val * mat_data[offset + c * K + k],
e); &args[k], e);
else else
Reducer<scalar_t, REDUCE>::update( Reducer<scalar_t>::update(REDUCE, &vals[k],
&vals[k], mat_data[offset + c * K + k], &args[k], e); mat_data[offset + c * K + k],
&args[k], e);
} }
} }
offset = b * M * K + m * K; offset = b * M * K + m * K;
for (auto k = 0; k < K; k++) for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k], Reducer<scalar_t>::write(REDUCE, out_data + offset + k, vals[k],
arg_out_data + offset + k, arg_out_data + offset + k, args[k],
args[k], row_end - row_start); row_end - row_start);
} }
} }
}); });
......
...@@ -166,7 +166,7 @@ class SparseTensor(object): ...@@ -166,7 +166,7 @@ class SparseTensor(object):
def is_coalesced(self) -> bool: def is_coalesced(self) -> bool:
return self.storage.is_coalesced() return self.storage.is_coalesced()
def coalesce(self, reduce: str = "add"): def coalesce(self, reduce: str = "sum"):
return self.from_storage(self.storage.coalesce(reduce)) return self.from_storage(self.storage.coalesce(reduce))
def fill_cache_(self): def fill_cache_(self):
...@@ -252,6 +252,20 @@ class SparseTensor(object): ...@@ -252,6 +252,20 @@ class SparseTensor(object):
else: else:
return bool((value1 == value2).all()) return bool((value1 == value2).all())
def to_symmetric(self, reduce: str = "sum"):
row, col, value = self.coo()
row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
if value is not None:
value = torch.cat([value, value], dim=0)
N = max(self.size(0), self.size(1))
out = SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=torch.Size([N, N]), is_sorted=False)
out = out.coalesce(reduce)
return out
def detach_(self): def detach_(self):
value = self.storage.value() value = self.storage.value()
if value is not None: if value is not None:
...@@ -496,7 +510,7 @@ ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse. ...@@ -496,7 +510,7 @@ ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.
@torch.jit.ignore @torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix) -> SparseTensor: def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
colptr = None colptr = None
if isinstance(mat, scipy.sparse.csc_matrix): if isinstance(mat, scipy.sparse.csc_matrix):
colptr = torch.from_numpy(mat.indptr).to(torch.long) colptr = torch.from_numpy(mat.indptr).to(torch.long)
...@@ -506,7 +520,9 @@ def from_scipy(mat: ScipySparseMatrix) -> SparseTensor: ...@@ -506,7 +520,9 @@ def from_scipy(mat: ScipySparseMatrix) -> SparseTensor:
mat = mat.tocoo() mat = mat.tocoo()
row = torch.from_numpy(mat.row).to(torch.long) row = torch.from_numpy(mat.row).to(torch.long)
col = torch.from_numpy(mat.col).to(torch.long) col = torch.from_numpy(mat.col).to(torch.long)
value = torch.from_numpy(mat.data) value = None
if has_value:
value = torch.from_numpy(mat.data)
sparse_sizes = mat.shape[:2] sparse_sizes = mat.shape[:2]
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment