Commit 0a4a5579 authored by rusty1s's avatar rusty1s
Browse files

linting

parent 15ff09d5
......@@ -57,7 +57,6 @@ jobs:
- os: windows
env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
install:
- source script/cuda.sh
- source script/conda.sh
......
from typing import List
from typing import Optional, List
import torch
from torch_sparse.storage import SparseStorage
......@@ -63,18 +63,10 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0)
storage = SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return tensors[0].from_storage(storage)
elif dim == 1:
......@@ -126,18 +118,10 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0)
storage = SparseStorage(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=colptr,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=False)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=False)
return tensors[0].from_storage(storage)
elif dim > 1 and dim < tensors[0].dim():
......@@ -251,16 +235,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
if len(csc2csrs) == len(tensors):
csc2csr = torch.cat(csc2csrs, dim=0)
storage = SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=colptr, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return tensors[0].from_storage(storage)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment