Commit 0a4a5579 authored by rusty1s's avatar rusty1s
Browse files

linting

parent 15ff09d5
...@@ -57,7 +57,6 @@ jobs: ...@@ -57,7 +57,6 @@ jobs:
- os: windows - os: windows
env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101 env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
install: install:
- source script/cuda.sh - source script/cuda.sh
- source script/conda.sh - source script/conda.sh
......
from typing import List from typing import Optional, List
import torch import torch
from torch_sparse.storage import SparseStorage from torch_sparse.storage import SparseStorage
...@@ -63,18 +63,10 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -63,18 +63,10 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if len(rowcounts) == len(tensors): if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0) rowcount = torch.cat(rowcounts, dim=0)
storage = SparseStorage( storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
row=row, sparse_sizes=sparse_sizes, rowcount=rowcount,
rowptr=rowptr, colptr=None, colcount=None, csr2csc=None,
col=col, csc2csr=None, is_sorted=True)
value=value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
elif dim == 1: elif dim == 1:
...@@ -126,18 +118,10 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -126,18 +118,10 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
if len(colcounts) == len(tensors): if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0) colcount = torch.cat(colcounts, dim=0)
storage = SparseStorage( storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
row=row, sparse_sizes=sparse_sizes, rowcount=None,
rowptr=None, colptr=colptr, colcount=colcount, csr2csc=None,
col=col, csc2csr=None, is_sorted=False)
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=colptr,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=False)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
elif dim > 1 and dim < tensors[0].dim(): elif dim > 1 and dim < tensors[0].dim():
...@@ -251,16 +235,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: ...@@ -251,16 +235,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
if len(csc2csrs) == len(tensors): if len(csc2csrs) == len(tensors):
csc2csr = torch.cat(csc2csrs, dim=0) csc2csr = torch.cat(csc2csrs, dim=0)
storage = SparseStorage( storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
row=row, sparse_sizes=sparse_sizes, rowcount=rowcount,
rowptr=rowptr, colptr=colptr, colcount=colcount, csr2csc=csr2csc,
col=col, csc2csr=csc2csr, is_sorted=True)
value=value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment