tensor.py 21.8 KB
Newer Older
rusty1s's avatar
repr  
rusty1s committed
1
from textwrap import indent
rusty1s's avatar
rusty1s committed
2
from typing import Any, Dict, List, Optional, Tuple, Union
rusty1s's avatar
rusty1s committed
3

4
import numpy as np
rusty1s's avatar
rusty1s committed
5
import scipy.sparse
rusty1s's avatar
rusty1s committed
6
import torch
rusty1s's avatar
rusty1s committed
7
from torch_scatter import segment_csr
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
from torch_sparse.storage import SparseStorage, get_layout
rusty1s's avatar
rusty1s committed
10
11


rusty1s's avatar
rusty1s committed
12
@torch.jit.script
class SparseTensor(object):
    """A sparse matrix, optionally with dense trailing value dimensions,
    backed by a single ``SparseStorage``.

    Most methods delegate to ``self.storage``; methods that produce a new
    tensor wrap the resulting storage with ``from_storage``. The class is
    compiled with ``torch.jit.script``, so method bodies must stay within the
    TorchScript language subset (for example, dtypes are passed as ``int``).
    """

    # All index/value tensors and cached layouts live in this storage object.
    storage: SparseStorage

    def __init__(
        self,
        row: Optional[torch.Tensor] = None,
        rowptr: Optional[torch.Tensor] = None,
        col: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
        is_sorted: bool = False,
        trust_data: bool = False,
    ):
        """Build the storage from COO row indices and/or CSR row pointers.

        Args:
            row: COO row index per stored entry (may be ``None`` if
                ``rowptr`` is given).
            rowptr: CSR row pointer (may be ``None`` if ``row`` is given).
            col: Column index per stored entry.
            value: Optional values; leading dim is per entry.
            sparse_sizes: ``(num_rows, num_cols)``; entries may be ``None``.
            is_sorted: Forwarded to ``SparseStorage``. This file passes
                ``True`` when entries are already in row-major order.
            trust_data: Forwarded to ``SparseStorage`` unchanged.

        All cache slots (rowcount, colptr, colcount, csr2csc, csc2csr)
        start out empty.
        """
        self.storage = SparseStorage(
            row=row,
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=sparse_sizes,
            rowcount=None,
            colptr=None,
            colcount=None,
            csr2csc=None,
            csc2csr=None,
            is_sorted=is_sorted,
            trust_data=trust_data,
        )

    @classmethod
    def from_storage(self, storage: SparseStorage):
        """Wrap an existing ``SparseStorage`` in a new ``SparseTensor``.

        The index/value tensors are reused without copying. The storage's
        cached tensors are carried over, so cached layouts are kept. Note
        that the first parameter is the class (``@classmethod``), even
        though it is named ``self`` in this file.
        """
        out = SparseTensor(
            row=storage._row,
            rowptr=storage._rowptr,
            col=storage._col,
            value=storage._value,
            sparse_sizes=storage._sparse_sizes,
            is_sorted=True,
            trust_data=True,
        )
        # Carry over every cache slot from the source storage.
        out.storage._rowcount = storage._rowcount
        out.storage._colptr = storage._colptr
        out.storage._colcount = storage._colcount
        out.storage._csr2csc = storage._csr2csc
        out.storage._csc2csr = storage._csc2csr
        return out

    @classmethod
    def from_edge_index(
        self,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
        is_sorted: bool = False,
        trust_data: bool = False,
    ):
        """Create a tensor from an edge index.

        ``edge_index[0]`` holds row indices and ``edge_index[1]`` holds
        column indices. ``edge_attr`` (optional) becomes the values.
        """
        return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
                            value=edge_attr, sparse_sizes=sparse_sizes,
                            is_sorted=is_sorted, trust_data=trust_data)

    @classmethod
    def from_dense(self, mat: torch.Tensor, has_value: bool = True):
        """Convert a dense tensor into a ``SparseTensor``.

        The first two dims of ``mat`` are the sparse dims. For
        ``mat.dim() > 2``, position ``(i, j)`` is kept when any element of
        ``mat[i, j]`` is non-zero. Values are gathered as ``mat[row, col]``
        only when ``has_value`` is true.
        """
        if mat.dim() > 2:
            # Collapse trailing dims: non-zero iff any element is non-zero.
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()
        index = index.t()

        # ``nonzero`` returns indices in lexicographic (row-major) order,
        # which is why ``is_sorted=True`` is passed below.
        row = index[0]
        col = index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat[row, col]

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True, trust_data=True)

    @classmethod
    def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
                                     has_value: bool = True):
        """Convert a ``torch.sparse_coo_tensor``.

        ``mat`` is coalesced first; its indices and (if ``has_value``) its
        values are then reused.
        """
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat.values()

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True, trust_data=True)

    @classmethod
    def from_torch_sparse_csr_tensor(self, mat: torch.Tensor,
                                     has_value: bool = True):
        """Convert a ``torch.sparse_csr_tensor``.

        ``crow_indices`` becomes the row pointer and ``col_indices`` the
        column indices. Values are reused only if ``has_value`` is true.
        """
        rowptr = mat.crow_indices()
        col = mat.col_indices()

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat.values()

        return SparseTensor(row=None, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True, trust_data=True)

    @classmethod
    def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
            dtype: Optional[int] = None, device: Optional[torch.device] = None,
            fill_cache: bool = False):
        """Create an ``M x N`` matrix with ones on the main diagonal.

        Args:
            M: Number of rows.
            N: Number of columns; defaults to ``M``.
            has_value: If false, no value tensor is stored.
            dtype: Value dtype (an ``int`` under TorchScript).
            device: Device of all created tensors.
            fill_cache: If true, also precompute rowcount, colptr and
                colcount, plus the csr2csc/csc2csr permutations.
        """
        N = M if N is None else N

        # The diagonal has min(M, N) entries (i, i).
        row = torch.arange(min(M, N), device=device)
        col = row

        # One entry per row for the first min(M, N) rows. When M > N the
        # trailing rows are empty, so their pointers stay at N.
        rowptr = torch.arange(M + 1, device=row.device)
        if M > N:
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            value = torch.ones(row.numel(), dtype=dtype, device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0

            # Column-side mirror: trailing columns are empty when N > M.
            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                colptr[M + 1:] = M
                colcount[M:] = 0

            # Diagonal entries appear in the same order in CSR and CSC, so
            # both permutations are the identity arange(min(M, N)).
            csr2csc = csc2csr = row

        out = SparseTensor(
            row=row,
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=(M, N),
            is_sorted=True,
            trust_data=True,
        )

        out.storage._rowcount = rowcount
        out.storage._colptr = colptr
        out.storage._colcount = colcount
        out.storage._csr2csc = csr2csc
        out.storage._csc2csr = csc2csr
        return out

    def copy(self):
        """Return a new ``SparseTensor`` sharing this tensor's storage
        tensors and caches (no data is copied)."""
        return self.from_storage(self.storage)

    def clone(self):
        """Return a new ``SparseTensor`` wrapping ``self.storage.clone()``."""
        return self.from_storage(self.storage.clone())

    def type(self, dtype: torch.dtype, non_blocking: bool = False):
        """Cast the values to ``dtype``.

        Returns ``self`` unchanged when there are no values or they already
        have that dtype.
        """
        value = self.storage.value()
        if value is None or dtype == value.dtype:
            return self

        return self.from_storage(
            self.storage.type(dtype=dtype, non_blocking=non_blocking))

    def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Cast the values to the dtype of ``tensor``."""
        return self.type(dtype=tensor.dtype, non_blocking=non_blocking)

    def to_device(self, device: torch.device, non_blocking: bool = False):
        """Move to ``device``; returns ``self`` if it is already there."""
        if device == self.device():
            return self

        return self.from_storage(
            self.storage.to_device(device=device, non_blocking=non_blocking))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Move to the device of ``tensor``."""
        return self.to_device(device=tensor.device, non_blocking=non_blocking)

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(row, col, value)`` in COO form."""
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(rowptr, col, value)`` in CSR form."""
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(colptr, row, value)`` in CSC form.

        Row indices and values are reordered with the ``csr2csc``
        permutation.
        """
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        """Return whether a value tensor is stored."""
        return self.storage.has_value()

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """Replace the values in place and return ``self``.

        ``layout`` is forwarded to the storage and names the entry order of
        ``value`` (e.g. ``'coo'``).
        """
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Return a new tensor with the given values (see ``set_value_``)."""
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> Tuple[int, int]:
        """Return ``(num_rows, num_cols)``."""
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        """Return the sparse size of dimension ``dim`` (0 or 1)."""
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        """Return a new tensor with sparse sizes ``sparse_sizes``."""
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def sparse_reshape(self, num_rows: int, num_cols: int):
        """Return a new tensor reshaped to ``num_rows x num_cols``."""
        return self.from_storage(
            self.storage.sparse_reshape(num_rows, num_cols))

    def is_coalesced(self) -> bool:
        """Return the storage's coalesced flag."""
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "sum"):
        """Return a coalesced tensor; ``reduce`` is forwarded to the storage
        and names how duplicate entries are combined."""
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        """Populate the storage caches in place and return ``self``."""
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        """Drop the storage caches in place and return ``self``."""
        self.storage.clear_cache_()
        return self

    def __eq__(self, other) -> bool:
        """Exact equality.

        Two tensors are equal when they have the same sizes, identical CSR
        indices, and identical values (or both have no values). Any
        non-``SparseTensor`` compares unequal.
        """
        if not isinstance(other, self.__class__):
            return False

        if self.sizes() != other.sizes():
            return False

        rowptrA, colA, valueA = self.csr()
        rowptrB, colB, valueB = other.csr()

        # Exactly one side having values means "not equal".
        if valueA is None and valueB is not None:
            return False
        if valueA is not None and valueB is None:
            return False
        if not torch.equal(rowptrA, rowptrB):
            return False
        if not torch.equal(colA, colB):
            return False
        if valueA is None and valueB is None:
            return True
        return torch.equal(valueA, valueB)

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
        """Set every stored value to ``fill_value`` in place; returns self."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float, dtype: Optional[int] = None):
        """Return a new tensor whose stored values all equal ``fill_value``."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        """Return the sparse sizes followed by the value's trailing dims."""
        sparse_sizes = self.sparse_sizes()
        value = self.storage.value()
        if value is not None:
            return list(sparse_sizes) + list(value.size())[1:]
        else:
            return list(sparse_sizes)

    def size(self, dim: int) -> int:
        """Return ``sizes()[dim]``."""
        return self.sizes()[dim]

    def dim(self) -> int:
        """Return the number of dimensions (sparse + dense)."""
        return len(self.sizes())

    def nnz(self) -> int:
        """Return the number of stored entries (length of ``col``)."""
        return self.storage.col().numel()

    def numel(self) -> int:
        """Return the number of stored scalars: ``value.numel()`` if values
        exist, otherwise ``nnz()``."""
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        """Return ``nnz / (num_rows * num_cols)``.

        A zero sparse size is not guarded against.
        """
        return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))

    def sparsity(self) -> float:
        """Return ``1 - density()``."""
        return 1 - self.density()

    def avg_row_length(self) -> float:
        """Return the mean number of stored entries per row."""
        return self.nnz() / self.sparse_size(0)

    def avg_col_length(self) -> float:
        """Return the mean number of stored entries per column."""
        return self.nnz() / self.sparse_size(1)

    def bandwidth(self) -> int:
        """Return ``max |row - col|`` over stored entries.

        Calling this with ``nnz() == 0`` is not guarded against.
        """
        row, col, _ = self.coo()
        return int((row - col).abs_().max())

    def avg_bandwidth(self) -> float:
        """Return the mean of ``|row - col|`` over stored entries."""
        row, col, _ = self.coo()
        return float((row - col).abs_().to(torch.float).mean())

    def bandwidth_proportion(self, bandwidth: int) -> float:
        """Return the fraction of entries with ``|row - col| <= bandwidth``."""
        row, col, _ = self.coo()
        tmp = (row - col).abs_()
        return int((tmp <= bandwidth).sum()) / self.nnz()

    def is_quadratic(self) -> bool:
        """Return whether the sparse part is square."""
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        """Return whether the matrix equals its transpose.

        This compares the CSR view with the CSC view. A symmetric matrix has
        identical pointers, indices and values in both. Without values, only
        the structure is compared.
        """
        if not self.is_quadratic():
            return False

        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def to_symmetric(self, reduce: str = "sum"):
        """Return ``A`` merged with ``A^T`` as an ``N x N`` tensor.

        ``N`` is the larger sparse size. Each entry (r, c) is mirrored to
        (c, r). Values that land on the same position are combined with
        ``segment_csr(..., reduce=reduce)``.
        """
        N = max(self.size(0), self.size(1))

        row, col, value = self.coo()
        # Linear key r * N + c for every entry, followed by the key of its
        # mirror. Slot 0 holds a -1 sentinel so that the first real key
        # always compares greater than its predecessor after sorting.
        idx = col.new_full((2 * col.numel() + 1, ), -1)
        idx[1:row.numel() + 1] = row
        idx[row.numel() + 1:] = col
        idx[1:] *= N
        idx[1:row.numel() + 1] += col
        idx[row.numel() + 1:] += row

        idx, perm = idx.sort()
        # mask[k] is True where sorted key k + 1 starts a new unique key.
        mask = idx[1:] > idx[:-1]
        # Drop the sentinel and shift back to indices into the concatenated
        # [row, col] / [col, row] arrays.
        perm = perm[1:].sub_(1)
        # Position of the first entry of each unique key, in key order.
        idx = perm[mask]

        if value is not None:
            # Segment boundaries: each unique key starts at a True position
            # of ``mask``; the last segment ends at perm.size(0).
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))])
            value = torch.cat([value, value])[perm]
            value = segment_csr(value, ptr, reduce=reduce)

        # ``out=perm`` reuses perm's memory as scratch (perm is no longer
        # needed past this point); indexing with ``[idx]`` makes a copy, so
        # the second call may overwrite the buffer again.
        new_row = torch.cat([row, col], dim=0, out=perm)[idx]
        new_col = torch.cat([col, row], dim=0, out=perm)[idx]

        # Keys were sorted ascending by r * N + c, i.e. row-major order.
        out = SparseTensor(
            row=new_row,
            rowptr=None,
            col=new_col,
            value=value,
            sparse_sizes=(N, N),
            is_sorted=True,
            trust_data=True,
        )
        return out

    def detach_(self):
        """Detach the values from autograd in place; returns ``self``."""
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        """Return a new tensor with detached values."""
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        """Return ``value.requires_grad``, or ``False`` without values."""
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self, requires_grad: bool = True,
                       dtype: Optional[int] = None):
        """Set ``requires_grad`` on the values in place; returns ``self``.

        If gradients are requested but there are no values, every value is
        first filled with ``1.`` (of ``dtype``).
        """
        if requires_grad and not self.has_value():
            self.fill_value_(1., dtype)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        """Return a new tensor wrapping ``self.storage.pin_memory()``."""
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        """Return the storage's pinned-memory flag."""
        return self.storage.is_pinned()

    def device(self):
        """Return the device of the column index tensor."""
        return self.storage.col().device

    def cpu(self):
        """Move to the CPU (``self`` if already there).

        NOTE: the module-level ``SparseTensor.cpu = cpu`` assignment later in
        this file rebinds this name on the Python class.
        """
        return self.to_device(device=torch.device('cpu'), non_blocking=False)

    def cuda(self):
        """Return a new tensor wrapping ``self.storage.cuda()``.

        NOTE: the module-level ``SparseTensor.cuda = cuda`` assignment later
        in this file rebinds this name on the Python class.
        """
        return self.from_storage(self.storage.cuda())

    def is_cuda(self) -> bool:
        """Return whether the column index tensor is on a CUDA device."""
        return self.storage.col().is_cuda

    def dtype(self):
        """Return the value dtype, or ``torch.float`` when there are no
        values."""
        value = self.storage.value()
        return value.dtype if value is not None else torch.float

    def is_floating_point(self) -> bool:
        """Return whether values are floating point (``True`` without
        values)."""
        value = self.storage.value()
        return torch.is_floating_point(value) if value is not None else True

    # dtype shortcuts: each one delegates to ``type`` with a fixed dtype.

    def bfloat16(self):
        return self.type(dtype=torch.bfloat16, non_blocking=False)

    def bool(self):
        return self.type(dtype=torch.bool, non_blocking=False)

    def byte(self):
        return self.type(dtype=torch.uint8, non_blocking=False)

    def char(self):
        return self.type(dtype=torch.int8, non_blocking=False)

    def half(self):
        return self.type(dtype=torch.half, non_blocking=False)

    def float(self):
        return self.type(dtype=torch.float, non_blocking=False)

    def double(self):
        return self.type(dtype=torch.double, non_blocking=False)

    def short(self):
        return self.type(dtype=torch.short, non_blocking=False)

    def int(self):
        return self.type(dtype=torch.int, non_blocking=False)

    def long(self):
        return self.type(dtype=torch.long, non_blocking=False)

    # Conversions #############################################################

    def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a dense tensor of shape ``sizes()``.

        Without values, each stored entry becomes ``1`` of ``dtype``.
        ``dtype`` is ignored when values exist. Entries are written by
        assignment, so duplicate ``(row, col)`` pairs are not summed.
        """
        row, col, value = self.coo()

        if value is not None:
            mat = torch.zeros(self.sizes(), dtype=value.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(
            self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a ``torch.sparse_coo_tensor`` with the same entries.

        Without values, ones of ``dtype`` are used.
        """
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)

        if value is None:
            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())

    def to_torch_sparse_csr_tensor(
            self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a ``torch.sparse_csr_tensor`` with the same entries.

        Without values, ones of ``dtype`` are used.
        """
        rowptr, col, value = self.csr()

        if value is None:
            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

        return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())

rusty1s's avatar
rusty1s committed
513
514
515
516
517
518

# Python Bindings #############################################################


def share_memory_(self: SparseTensor) -> SparseTensor:
    """Move the underlying storage into shared memory, in place.

    Returns ``self`` so the call can be chained.
    """
    storage = self.storage
    storage.share_memory_()
    return self
rusty1s's avatar
rusty1s committed
520
521
522
523
524
525


def is_shared(self: SparseTensor) -> bool:
    """Return what the underlying storage reports for ``is_shared()``."""
    storage = self.storage
    shared = storage.is_shared()
    return shared


rusty1s's avatar
typing  
rusty1s committed
526
527
def to(self, *args: Optional[List[Any]],
       **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
    """Mirror ``torch.Tensor.to``: optionally cast dtype, then move device.

    Arguments are parsed with PyTorch's own ``_parse_to``. The dtype cast is
    applied before the device move, and ``non_blocking`` is forwarded to
    both steps.
    """
    parsed = torch._C._nn._parse_to(*args, **kwargs)
    device, dtype, non_blocking = parsed[0], parsed[1], parsed[2]

    result = self
    if dtype is not None:
        result = result.type(dtype=dtype, non_blocking=non_blocking)
    if device is not None:
        result = result.to_device(device=device, non_blocking=non_blocking)
    return result


rusty1s's avatar
rusty1s committed
538
539
540
541
542
543
544
545
546
def cpu(self) -> SparseTensor:
    """Return the tensor on the CPU, using a CPU scalar as device reference."""
    reference = torch.tensor(0., device='cpu')
    return self.device_as(reference)


def cuda(self, device: Optional[Union[int, str]] = None,
         non_blocking: bool = False):
    """Move the tensor to a CUDA device (``device_as`` semantics).

    Args:
        device: Target device as an ordinal (e.g. ``0``) or a string
            (e.g. ``'cuda:1'``). ``None`` selects the current CUDA device.
        non_blocking: Forwarded to the device transfer.
    """
    # ``device or 'cuda'`` would silently turn the valid ordinal ``0`` into
    # the *current* device; only ``None`` should fall back to it.
    target = 'cuda' if device is None else device
    # A scalar allocated on the target yields a concrete ``torch.device``
    # (with index), which ``device_as`` compares against our own device.
    probe = torch.tensor(0., device=target)
    return self.device_as(probe, non_blocking=non_blocking)


rusty1s's avatar
typing  
rusty1s committed
547
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
    """Index a ``SparseTensor``, one index item per dimension.

    Supported items (a tuple supplies one item per dimension):

    * ``int`` or NumPy integer scalar: forwarded to ``select``;
    * ``slice`` without a step: forwarded to ``narrow``, with bounds
      resolved and clamped like Python sequence slicing;
    * bool tensor / array: forwarded to ``masked_select``;
    * signed-integer tensor / array, or a list/tuple of ints: converted to
      ``torch.long`` and forwarded to ``index_select``;
    * ``...``: skips ahead so the remaining items address the trailing
      dimensions.

    After every item except ``...``, the next item addresses the following
    dimension.

    Raises:
        SyntaxError: More than one ``...``, a ``...`` that cannot fit, or an
            unsupported index item.
        ValueError: A slice with an explicit step.
        IndexError: A tensor/array index whose dtype is neither bool nor a
            signed integer type.
    """
    items = list(index) if isinstance(index, tuple) else [index]
    # More than one `Ellipsis` is not allowed...
    if sum(1 for item in items if item is Ellipsis) > 1:
        raise SyntaxError('an index can only have a single ellipsis (...)')

    int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)

    dim = 0
    out = self
    for pos, item in enumerate(items):
        remaining = len(items) - pos - 1  # items left after this one

        if isinstance(item, (list, tuple)):
            item = torch.tensor(item, device=self.device())
            if item.numel() == 0:
                # ``torch.tensor([])`` defaults to float32; an empty index
                # list should select nothing instead of being rejected.
                item = item.to(torch.long)
        if isinstance(item, np.ndarray):
            item = torch.from_numpy(item).to(self.device())
        if isinstance(item, np.integer):
            item = int(item)

        if isinstance(item, int):
            out = out.select(dim, item)
            dim += 1
        elif isinstance(item, slice):
            if item.step is not None:
                raise ValueError('Step parameter not yet supported.')
            # ``slice.indices`` resolves ``None`` and negative bounds and
            # clamps out-of-range ones exactly like Python slicing.
            start, stop, _ = item.indices(self.size(dim))
            out = out.narrow(dim, start, max(stop - start, 0))
            dim += 1
        elif torch.is_tensor(item):
            if item.dtype == torch.bool:
                out = out.masked_select(dim, item)
            elif item.dtype in int_dtypes:
                # Previously only int64 was handled; other integer dtypes
                # (e.g. from int32 NumPy arrays) were silently ignored.
                out = out.index_select(dim, item.to(torch.long))
            else:
                raise IndexError(
                    'tensors used as indices must be bool or signed integer '
                    f'tensors (got {item.dtype})')
            dim += 1
        elif item is Ellipsis:
            if self.dim() - remaining < dim:
                raise SyntaxError('too many indices for SparseTensor')
            dim = self.dim() - remaining
        else:
            raise SyntaxError(
                f'unsupported index type: {type(item).__name__}')

    return out


rusty1s's avatar
typing  
rusty1s committed
597
def __repr__(self: SparseTensor) -> str:
    """Render the index tensors, optional values and summary statistics.

    Each tensor is shown as ``name=<tensor repr>``, with its continuation
    lines aligned under the repr. The whole body is indented to sit inside
    ``ClassName(...)``.
    """
    row, col, value = self.coo()
    field_pad = ' ' * 6

    def render(name: str, tensor: torch.Tensor) -> str:
        # Indent every line, then strip the pad from the first line because
        # it directly follows the ``name=`` prefix.
        text = indent(tensor.__repr__(), field_pad)
        return f'{name}={text[len(field_pad):]}'

    fields = [render('row', row), render('col', col)]
    if value is not None:
        fields.append(render('val', value))
    fields.append(f'size={tuple(self.sizes())}, nnz={self.nnz()}, '
                  f'density={100 * self.density():.02f}%')

    cls_name = self.__class__.__name__
    body_pad = ' ' * (len(cls_name) + 1)
    body = indent(',\n'.join(fields), body_pad)[len(body_pad):]
    return f'{cls_name}({body})'
rusty1s's avatar
repr  
rusty1s committed
616
617


rusty1s's avatar
rusty1s committed
618
619
620
# Attach the module-level helpers above as attributes of ``SparseTensor``.
# ``cpu`` and ``cuda`` rebind names that the class body already defined.
for _attr_name, _attr_func in (
        ('share_memory_', share_memory_),
        ('is_shared', is_shared),
        ('to', to),
        ('cpu', cpu),
        ('cuda', cuda),
        ('__getitem__', __getitem__),
        ('__repr__', __repr__),
):
    setattr(SparseTensor, _attr_name, _attr_func)
del _attr_name, _attr_func
rusty1s's avatar
rusty1s committed
625
626
627

# Scipy Conversions ###########################################################

rusty1s's avatar
typo  
rusty1s committed
628
629
# SciPy sparse formats accepted by ``from_scipy`` and returned by
# ``to_scipy``.
ScipySparseMatrix = Union[
    scipy.sparse.coo_matrix,
    scipy.sparse.csr_matrix,
    scipy.sparse.csc_matrix,
]
rusty1s's avatar
rusty1s committed
630
631
632


@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
    """Create a ``SparseTensor`` from a SciPy COO, CSR or CSC matrix.

    Args:
        mat: Input matrix. If it is CSR with unsorted column indices, the
            indices are sorted on a copy.
        has_value: If false, ``mat.data`` is dropped and only the sparsity
            pattern is kept.

    Returns:
        A ``SparseTensor`` of shape ``mat.shape[:2]``. The index tensors are
        ``torch.long``; values keep the dtype of ``mat.data``. For a CSC
        input, its ``indptr`` is also stored as the cached ``colptr``.
    """
    colptr = None
    if isinstance(mat, scipy.sparse.csc_matrix):
        # A CSC input already carries the column pointer; keep it as cache.
        colptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocsr()
    # ``is_sorted=True`` below declares row-major order (sorted by row, then
    # column). ``tocsr()`` returns a CSR input unchanged, and CSR matrices
    # may carry unsorted column indices; ``sorted_indices()`` returns a
    # sorted copy and leaves the caller's matrix untouched.
    if not mat.has_sorted_indices:
        mat = mat.sorted_indices()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long)
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long)
    col = torch.from_numpy(mat.col).to(torch.long)

    value = None
    if has_value:
        # NOTE: ``torch.from_numpy`` does not copy, so ``value`` may share
        # memory with the SciPy data array.
        value = torch.from_numpy(mat.data)

    sparse_sizes = mat.shape[:2]

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=None,
                            colptr=colptr, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)

    return SparseTensor.from_storage(storage)


@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
             dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
    """Convert a 2-D ``SparseTensor`` into a SciPy sparse matrix.

    Args:
        layout: Target format, normalised by ``get_layout``. ``'coo'``,
            ``'csr'`` and ``'csc'`` are handled; any other result falls
            through and ``None`` is returned.
        dtype: Value dtype, used only when the tensor has no values; every
            stored entry then becomes ``1``.

    Returns:
        A ``coo_matrix``, ``csr_matrix`` or ``csc_matrix`` of shape
        ``self.sizes()``. Tensors are detached and moved to the CPU before
        conversion to NumPy.
    """
    assert self.dim() == 2
    layout = get_layout(layout)

    def to_numpy(tensor: torch.Tensor):
        return tensor.detach().cpu().numpy()

    if self.has_value():
        fill = None
    else:
        fill = torch.ones(self.nnz(), dtype=dtype).numpy()

    if layout == 'coo':
        row, col, value = self.coo()
        data = to_numpy(value) if fill is None else fill
        coords = (to_numpy(row), to_numpy(col))
        return scipy.sparse.coo_matrix((data, coords), self.sizes())
    if layout == 'csr':
        rowptr, col, value = self.csr()
        data = to_numpy(value) if fill is None else fill
        return scipy.sparse.csr_matrix(
            (data, to_numpy(col), to_numpy(rowptr)), self.sizes())
    if layout == 'csc':
        colptr, row, value = self.csc()
        data = to_numpy(value) if fill is None else fill
        return scipy.sparse.csc_matrix(
            (data, to_numpy(row), to_numpy(colptr)), self.sizes())


# Expose the SciPy converters on the class, e.g. ``SparseTensor.from_scipy(m)``.
setattr(SparseTensor, 'from_scipy', from_scipy)
setattr(SparseTensor, 'to_scipy', to_scipy)