tensor.py 21.1 KB
Newer Older
rusty1s's avatar
repr  
rusty1s committed
1
from textwrap import indent
rusty1s's avatar
typing  
rusty1s committed
2
from typing import Optional, List, Tuple, Dict, Union, Any
rusty1s's avatar
rusty1s committed
3
4

import torch
5
import numpy as np
rusty1s's avatar
rusty1s committed
6
import scipy.sparse
rusty1s's avatar
rusty1s committed
7
from torch_scatter import segment_csr
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
from torch_sparse.storage import SparseStorage, get_layout
rusty1s's avatar
rusty1s committed
10
11


rusty1s's avatar
rusty1s committed
12
@torch.jit.script
class SparseTensor(object):
    """Sparse matrix whose indices, values and cached layouts live in a
    :class:`SparseStorage`.

    The first two dimensions are sparse (see :meth:`sparse_sizes`). When a
    value tensor is present, its trailing dimensions (everything after the
    first, per-non-zero dimension) are dense and follow the sparse ones in
    :meth:`sizes`.
    """
    storage: SparseStorage

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[Tuple[Optional[int],
                                              Optional[int]]] = None,
                 is_sorted: bool = False,
                 trust_data: bool = False):
        # Index/value arguments are forwarded to `SparseStorage` as-is; the
        # derived caches (rowcount, colptr, colcount, csr2csc, csc2csr) start
        # out empty here.
        self.storage = SparseStorage(
            row=row,
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=sparse_sizes,
            rowcount=None,
            colptr=None,
            colcount=None,
            csr2csc=None,
            csc2csr=None,
            is_sorted=is_sorted,
            trust_data=trust_data)

    @classmethod
    def from_storage(self, storage: SparseStorage):
        """Wrap ``storage`` in a new SparseTensor.

        The new tensor shares ``storage``'s index/value tensors and copies
        over references to its cached layouts. (``self`` is the class here,
        since this is a classmethod.)
        """
        out = SparseTensor(
            row=storage._row,
            rowptr=storage._rowptr,
            col=storage._col,
            value=storage._value,
            sparse_sizes=storage._sparse_sizes,
            is_sorted=True,
            trust_data=True)
        out.storage._rowcount = storage._rowcount
        out.storage._colptr = storage._colptr
        out.storage._colcount = storage._colcount
        out.storage._csr2csc = storage._csr2csc
        out.storage._csc2csr = storage._csc2csr
        return out

    @classmethod
    def from_edge_index(self, edge_index: torch.Tensor,
                        edge_attr: Optional[torch.Tensor] = None,
                        sparse_sizes: Optional[Tuple[Optional[int],
                                                     Optional[int]]] = None,
                        is_sorted: bool = False,
                        trust_data: bool = False):
        """Build from a ``[2, nnz]`` index tensor (row = ``edge_index[0]``,
        col = ``edge_index[1]``) and optional per-edge values."""
        return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
                            value=edge_attr, sparse_sizes=sparse_sizes,
                            is_sorted=is_sorted, trust_data=trust_data)

    @classmethod
    def from_dense(self, mat: torch.Tensor, has_value: bool = True):
        """Build from a dense tensor, keeping only its non-zero positions.

        For ``mat.dim() > 2`` a position ``(i, j)`` is kept if any entry of
        ``mat[i, j]`` is non-zero; the full ``mat[i, j]`` slice becomes the
        value. ``has_value=False`` drops the values.
        """
        if mat.dim() > 2:
            # Collapse the dense trailing dims into one magnitude per (i, j).
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()
        index = index.t()

        row = index[0]
        col = index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat[row, col]

        # `nonzero()` yields indices in row-major order, hence `is_sorted`.
        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True, trust_data=True)

    @classmethod
    def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
                                     has_value: bool = True):
        """Build from a ``torch.sparse_coo_tensor`` (coalesced first)."""
        # Coalescing sorts the indices and merges duplicates.
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat.values()

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True, trust_data=True)

    @classmethod
    def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
            dtype: Optional[int] = None, device: Optional[torch.device] = None,
            fill_cache: bool = False):
        """Return an ``M x N`` (``N`` defaults to ``M``) matrix with ones on
        the main diagonal.

        Args:
            has_value: Store explicit all-ones values (of ``dtype``).
            fill_cache: Also precompute rowcount, colptr, colcount and the
                csr2csc/csc2csr permutations.
        """
        N = M if N is None else N

        row = torch.arange(min(M, N), device=device)
        col = row

        # Row i holds exactly one entry while i < N; rows past N are empty.
        rowptr = torch.arange(M + 1, device=row.device)
        if M > N:
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            value = torch.ones(row.numel(), dtype=dtype, device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0

            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                colptr[M + 1:] = M
                colcount[M:] = 0
            # CSR and CSC order coincide on the diagonal, so both
            # permutations are the identity `arange(min(M, N))`.
            csr2csc = csc2csr = row

        out = SparseTensor(
            row=row,
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=(M, N),
            is_sorted=True,
            trust_data=True)
        out.storage._rowcount = rowcount
        out.storage._colptr = colptr
        out.storage._colcount = colcount
        out.storage._csr2csc = csr2csc
        out.storage._csc2csr = csc2csr
        return out

    def copy(self):
        """Return a new SparseTensor sharing this tensor's storage data."""
        return self.from_storage(self.storage)

    def clone(self):
        """Return a SparseTensor built from ``self.storage.clone()``."""
        return self.from_storage(self.storage.clone())

    def type(self, dtype: torch.dtype, non_blocking: bool = False):
        """Cast values to ``dtype``.

        Returns ``self`` unchanged when there are no values or the dtype
        already matches.
        """
        value = self.storage.value()
        if value is None or dtype == value.dtype:
            return self
        return self.from_storage(self.storage.type(
            dtype=dtype, non_blocking=non_blocking))

    def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Cast values to ``tensor.dtype``."""
        return self.type(dtype=tensor.dtype, non_blocking=non_blocking)

    def to_device(self, device: torch.device, non_blocking: bool = False):
        """Move to ``device``; returns ``self`` if already there."""
        if device == self.device():
            return self
        return self.from_storage(self.storage.to_device(
            device=device, non_blocking=non_blocking))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Move to ``tensor.device``."""
        return self.to_device(device=tensor.device, non_blocking=non_blocking)

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(row, col, value)``."""
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(rowptr, col, value)``."""
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(colptr, row, value)`` with row/value reordered by the
        storage's ``csr2csc`` permutation."""
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        """Return whether explicit values are stored."""
        return self.storage.has_value()

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """Replace the values in place (``layout`` is forwarded to the
        storage, e.g. ``'coo'``) and return ``self``."""
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Return a new SparseTensor with ``value`` replacing the values."""
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> Tuple[int, int]:
        """Return ``(num_rows, num_cols)``."""
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        """Return the size of sparse dimension ``dim`` (0 or 1)."""
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        """Return a copy with the given sparse sizes."""
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def sparse_reshape(self, num_rows: int, num_cols: int):
        """Return a copy reshaped to ``num_rows x num_cols``."""
        return self.from_storage(
            self.storage.sparse_reshape(num_rows, num_cols))

    def is_coalesced(self) -> bool:
        """Return ``self.storage.is_coalesced()``."""
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "sum"):
        """Return a coalesced copy; ``reduce`` is forwarded to the storage."""
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        """Call ``storage.fill_cache_()`` and return ``self``."""
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        """Call ``storage.clear_cache_()`` and return ``self``."""
        self.storage.clear_cache_()
        return self

    def __eq__(self, other) -> bool:
        """Structural equality: same sizes, same CSR indices, and either
        equal values or no values on both sides."""
        if not isinstance(other, self.__class__):
            return False

        if self.sizes() != other.sizes():
            return False

        rowptrA, colA, valueA = self.csr()
        rowptrB, colB, valueB = other.csr()

        if valueA is None and valueB is not None:
            return False
        if valueA is not None and valueB is None:
            return False
        if not torch.equal(rowptrA, rowptrB):
            return False
        if not torch.equal(colA, colB):
            return False
        if valueA is None and valueB is None:
            return True
        return torch.equal(valueA, valueB)

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
        """Set every value to ``fill_value`` in place and return ``self``."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float, dtype: Optional[int] = None):
        """Return a copy whose values are all ``fill_value``."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        """Return the sparse sizes followed by any dense value dims."""
        sparse_sizes = self.sparse_sizes()
        value = self.storage.value()
        if value is not None:
            return list(sparse_sizes) + list(value.size())[1:]
        else:
            return list(sparse_sizes)

    def size(self, dim: int) -> int:
        """Return ``self.sizes()[dim]``."""
        return self.sizes()[dim]

    def dim(self) -> int:
        """Return the number of (sparse + dense) dimensions."""
        return len(self.sizes())

    def nnz(self) -> int:
        """Return the number of stored (row, col) entries."""
        return self.storage.col().numel()

    def numel(self) -> int:
        """Return the number of stored scalars (``nnz`` if valueless)."""
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        """Return ``nnz / (num_rows * num_cols)``.

        A matrix with zero rows or columns has density ``0.0``. Without this
        guard ``__repr__`` (which prints the density) would fail on it.
        """
        total = self.sparse_size(0) * self.sparse_size(1)
        if total == 0:
            return 0.0
        return self.nnz() / total

    def sparsity(self) -> float:
        """Return ``1 - density()``."""
        return 1 - self.density()

    def avg_row_length(self) -> float:
        """Return the mean number of entries per row (0.0 if no rows)."""
        num_rows = self.sparse_size(0)
        if num_rows == 0:
            return 0.0
        return self.nnz() / num_rows

    def avg_col_length(self) -> float:
        """Return the mean number of entries per column (0.0 if no cols)."""
        num_cols = self.sparse_size(1)
        if num_cols == 0:
            return 0.0
        return self.nnz() / num_cols

    def bandwidth(self) -> int:
        """Return the maximum ``|row - col|`` over all entries (0 if empty)."""
        if self.nnz() == 0:
            # `.max()` of an empty tensor raises.
            return 0
        row, col, _ = self.coo()
        return int((row - col).abs_().max())

    def avg_bandwidth(self) -> float:
        """Return the mean ``|row - col|`` over all entries."""
        row, col, _ = self.coo()
        return float((row - col).abs_().to(torch.float).mean())

    def bandwidth_proportion(self, bandwidth: int) -> float:
        """Return the fraction of entries with ``|row - col| <= bandwidth``."""
        row, col, _ = self.coo()
        tmp = (row - col).abs_()
        return int((tmp <= bandwidth).sum()) / self.nnz()

    def is_quadratic(self) -> bool:
        """Return whether the sparse dimensions are equal."""
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        """Return whether the matrix equals its transpose.

        Compares CSR against CSC structure; values are compared only when
        they are present.
        """
        if not self.is_quadratic():
            return False

        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def to_symmetric(self, reduce: str = "sum"):
        """Return ``A + A^T`` as an ``N x N`` matrix, ``N = max(rows, cols)``.

        Entries that occur at the same position in ``A`` and ``A^T`` are
        merged with ``segment_csr(..., reduce=reduce)``.
        """
        N = max(self.size(0), self.size(1))

        row, col, value = self.coo()
        # Linearize (row, col) of A and of A^T into `row * N + col`, behind a
        # leading -1 sentinel so the first real key always counts as "new".
        idx = col.new_full((2 * col.numel() + 1, ), -1)
        idx[1:row.numel() + 1] = row
        idx[row.numel() + 1:] = col
        idx[1:] *= N
        idx[1:row.numel() + 1] += col
        idx[row.numel() + 1:] += row

        idx, perm = idx.sort()
        # mask[i] marks the first occurrence of each distinct key.
        mask = idx[1:] > idx[:-1]
        # Drop the sentinel and shift positions back into [0, 2 * nnz).
        perm = perm[1:].sub_(1)
        idx = perm[mask]

        if value is not None:
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))])
            value = torch.cat([value, value])[perm]
            value = segment_csr(value, ptr, reduce=reduce)

        # Plain concatenation: no reuse of `perm` as an output buffer.
        new_row = torch.cat([row, col], dim=0)[idx]
        new_col = torch.cat([col, row], dim=0)[idx]

        out = SparseTensor(
            row=new_row,
            rowptr=None,
            col=new_col,
            value=value,
            sparse_sizes=(N, N),
            is_sorted=True,
            trust_data=True)
        return out

    def detach_(self):
        """Detach the values in place and return ``self``."""
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        """Return a copy whose values are detached."""
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        """Return whether the values require grad (False if valueless)."""
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self, requires_grad: bool = True,
                       dtype: Optional[int] = None):
        """Set ``requires_grad`` on the values in place and return ``self``.

        A valueless tensor first gets all-ones values of ``dtype`` when
        ``requires_grad`` is True.
        """
        if requires_grad and not self.has_value():
            self.fill_value_(1., dtype)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        """Return a copy built from ``storage.pin_memory()``."""
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        """Return ``storage.is_pinned()``."""
        return self.storage.is_pinned()

    def device(self):
        """Return the device of the column index tensor."""
        return self.storage.col().device

    def cpu(self):
        """Move to the CPU."""
        return self.to_device(device=torch.device('cpu'), non_blocking=False)

    def cuda(self):
        """Return a copy built from ``storage.cuda()``."""
        return self.from_storage(self.storage.cuda())

    def is_cuda(self) -> bool:
        """Return whether the column index tensor lives on CUDA."""
        return self.storage.col().is_cuda

    def dtype(self):
        """Return the value dtype (``torch.float`` if valueless)."""
        value = self.storage.value()
        return value.dtype if value is not None else torch.float

    def is_floating_point(self) -> bool:
        """Return whether values are floating point (True if valueless)."""
        value = self.storage.value()
        return torch.is_floating_point(value) if value is not None else True

    def bfloat16(self):
        return self.type(dtype=torch.bfloat16, non_blocking=False)

    def bool(self):
        return self.type(dtype=torch.bool, non_blocking=False)

    def byte(self):
        return self.type(dtype=torch.uint8, non_blocking=False)

    def char(self):
        return self.type(dtype=torch.int8, non_blocking=False)

    def half(self):
        return self.type(dtype=torch.half, non_blocking=False)

    def float(self):
        return self.type(dtype=torch.float, non_blocking=False)

    def double(self):
        return self.type(dtype=torch.double, non_blocking=False)

    def short(self):
        return self.type(dtype=torch.short, non_blocking=False)

    def int(self):
        return self.type(dtype=torch.int, non_blocking=False)

    def long(self):
        return self.type(dtype=torch.long, non_blocking=False)

    # Conversions #############################################################

    def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a dense tensor of shape ``self.sizes()``.

        Uses the value dtype when values exist (``dtype`` is then ignored);
        otherwise fills ones of ``dtype`` at the stored positions.
        """
        row, col, value = self.coo()

        if value is not None:
            mat = torch.zeros(self.sizes(), dtype=value.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(
            self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return an equivalent ``torch.sparse_coo_tensor``.

        Valueless tensors get all-ones values of ``dtype``.
        """
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)

        if value is None:
            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())
rusty1s's avatar
rusty1s committed
481

rusty1s's avatar
rusty1s committed
482
483
484
485
486
487

# Python Bindings #############################################################


def share_memory_(self: SparseTensor) -> SparseTensor:
    """Call ``share_memory_`` on the underlying storage and return ``self``."""
    storage = self.storage
    storage.share_memory_()
    return self
rusty1s's avatar
rusty1s committed
489
490
491
492
493
494


def is_shared(self: SparseTensor) -> bool:
    """Return the result of ``is_shared()`` on the underlying storage."""
    storage = self.storage
    return storage.is_shared()


rusty1s's avatar
typing  
rusty1s committed
495
496
def to(self, *args: Optional[List[Any]],
       **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
    """Mimic ``torch.Tensor.to``: accepts the same device/dtype arguments.

    The dtype cast (if requested) is applied before the device move; both
    honour ``non_blocking``.
    """
    parsed = torch._C._nn._parse_to(*args, **kwargs)
    device, dtype, non_blocking = parsed[0], parsed[1], parsed[2]

    out = self
    if dtype is not None:
        out = out.type(dtype=dtype, non_blocking=non_blocking)
    if device is not None:
        out = out.to_device(device=device, non_blocking=non_blocking)
    return out


rusty1s's avatar
rusty1s committed
507
508
509
510
511
512
513
514
515
def cpu(self) -> SparseTensor:
    """Move to the CPU (no-op if already there)."""
    # Equivalent to `device_as` on a CPU tensor, without allocating one.
    cpu_device = torch.device('cpu')
    return self.to_device(device=cpu_device, non_blocking=False)


def cuda(self, device: Optional[Union[int, str]] = None,
         non_blocking: bool = False):
    """Move to a CUDA device.

    Args:
        device: ``None`` for the current CUDA device, an ``int`` CUDA
            ordinal, or anything ``torch.device`` accepts (e.g. ``'cuda:1'``).
        non_blocking: Forwarded to the copy.
    """
    # `device or 'cuda'` would treat ordinal 0 as "unspecified"; test for
    # None explicitly instead.
    if device is None:
        target = torch.device('cuda')
    elif isinstance(device, int):
        target = torch.device('cuda', device)
    else:
        target = torch.device(device)

    if target.type == 'cuda' and target.index is None:
        # CUDA tensors always report an indexed device, so resolve the bare
        # 'cuda' to the current ordinal to let `to_device` detect a no-op.
        target = torch.device('cuda', torch.cuda.current_device())

    return self.to_device(device=target, non_blocking=non_blocking)


rusty1s's avatar
typing  
rusty1s committed
516
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
    """Index a SparseTensor, one index item per dimension, left to right.

    Supported items:

    * ``int`` or NumPy integer scalar -> ``select``
    * ``slice`` without a step -> ``narrow``; bounds follow Python slicing
      rules (negative values count from the end, out-of-range values are
      clamped)
    * boolean tensor -> ``masked_select``; integer tensor -> ``index_select``
      (lists, tuples and NumPy arrays are converted to tensors first)
    * a single ``Ellipsis``, which skips ahead so the remaining items address
      the trailing dimensions

    Raises:
        SyntaxError: on more than one ``Ellipsis``, an ``Ellipsis`` followed
            by more items than remaining dimensions, or an unsupported item.
        ValueError: if a slice has a step.
        IndexError: if an index tensor is neither boolean nor integer.
    """
    index = list(index) if isinstance(index, tuple) else [index]
    # More than one `Ellipsis` is not allowed...
    if len([
            i for i in index
            if not isinstance(i, (torch.Tensor, np.ndarray)) and i == ...
    ]) > 1:
        raise SyntaxError

    dim = 0
    out = self
    while len(index) > 0:
        item = index.pop(0)
        if isinstance(item, (list, tuple)):
            item = torch.tensor(item, device=self.device())
        if isinstance(item, np.ndarray):
            item = torch.from_numpy(item).to(self.device())
        if isinstance(item, np.integer):
            item = int(item)

        if isinstance(item, int):
            # NOTE(review): advancing `dim` assumes `select` keeps the selected
            # dimension (size 1) rather than removing it — confirm.
            out = out.select(dim, item)
            dim += 1
        elif isinstance(item, slice):
            if item.step is not None:
                raise ValueError('Step parameter not yet supported.')
            # `slice.indices` resolves None/negative bounds and clamps them
            # to [0, size], exactly like built-in sequence slicing.
            start, stop, _ = item.indices(out.size(dim))
            out = out.narrow(dim, start, max(stop - start, 0))
            dim += 1
        elif torch.is_tensor(item):
            if item.dtype == torch.bool:
                out = out.masked_select(dim, item)
            elif item.dtype in (torch.long, torch.int, torch.short,
                                torch.int8):
                # Narrower integer indices (e.g. from int32 NumPy arrays)
                # used to fall through both branches and be silently dropped.
                out = out.index_select(dim, item.to(torch.long))
            else:
                raise IndexError(
                    f'Index tensors must be of dtype torch.bool or a signed '
                    f'integer type (got {item.dtype}).')
            dim += 1
        elif item == Ellipsis:
            if self.dim() - len(index) < dim:
                raise SyntaxError
            dim = self.dim() - len(index)
        else:
            raise SyntaxError

    return out


rusty1s's avatar
typing  
rusty1s committed
566
def __repr__(self: SparseTensor) -> str:
    """Render the COO indices, the values (if any) and summary statistics,
    aligning multi-line tensor reprs under their field names."""
    field_pad = ' ' * 6

    def render(label: str, tensor: torch.Tensor) -> str:
        # Indent every line, then strip the indent of the first one so it
        # sits right after `label=`.
        return f'{label}={indent(tensor.__repr__(), field_pad)[len(field_pad):]}'

    row, col, value = self.coo()
    parts = [render('row', row), render('col', col)]
    if value is not None:
        parts.append(render('val', value))
    parts.append(f'size={tuple(self.sizes())}, nnz={self.nnz()}, '
                 f'density={100 * self.density():.02f}%')

    body = ',\n'.join(parts)
    cls_name = self.__class__.__name__
    outer_pad = ' ' * (len(cls_name) + 1)
    return f'{cls_name}({indent(body, outer_pad)[len(outer_pad):]})'
rusty1s's avatar
repr  
rusty1s committed
585
586


rusty1s's avatar
rusty1s committed
587
588
589
# Attach the Python-only helpers defined above to the scripted class.
for _attr_name, _attr_fn in (
    ('share_memory_', share_memory_),
    ('is_shared', is_shared),
    ('to', to),
    ('cpu', cpu),
    ('cuda', cuda),
    ('__getitem__', __getitem__),
    ('__repr__', __repr__),
):
    setattr(SparseTensor, _attr_name, _attr_fn)
del _attr_name, _attr_fn
rusty1s's avatar
rusty1s committed
594
595
596

# Scipy Conversions ###########################################################

rusty1s's avatar
typo  
rusty1s committed
597
598
# SciPy sparse formats accepted by `from_scipy` / produced by `to_scipy`.
ScipySparseMatrix = Union[
    scipy.sparse.coo_matrix,
    scipy.sparse.csr_matrix,
    scipy.sparse.csc_matrix,
]
rusty1s's avatar
rusty1s committed
599
600
601


@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
    """Convert a SciPy COO/CSR/CSC matrix into a :class:`SparseTensor`.

    The result is built from the canonical CSR form of ``mat`` (sorted column
    indices, duplicates summed), so it can safely be marked as sorted. A
    canonical CSC input additionally seeds the ``colptr`` cache.

    Args:
        mat: The SciPy sparse matrix to convert (it is never modified).
        has_value: Keep ``mat.data`` as values; otherwise the result is
            valueless.
    """
    colptr = None
    if isinstance(mat, scipy.sparse.csc_matrix) and mat.has_canonical_format:
        # Only a duplicate-free CSC's indptr is consistent with the
        # canonicalized data below.
        colptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocsr()
    if not mat.has_canonical_format:
        # A CSR input with unsorted or duplicate column indices would
        # otherwise be passed through unchanged yet declared `is_sorted`.
        # `tocsr()` may share data with the caller's matrix, so canonicalize
        # a copy instead of mutating it.
        mat = mat.copy()
        mat.sum_duplicates()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long)
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long)
    col = torch.from_numpy(mat.col).to(torch.long)

    value = None
    if has_value:
        value = torch.from_numpy(mat.data)

    sparse_sizes = mat.shape[:2]

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=None,
                            colptr=colptr, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)

    return SparseTensor.from_storage(storage)


@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
             dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
    """Export a 2-dimensional SparseTensor as a SciPy sparse matrix.

    Args:
        layout: Target layout, normalized by ``get_layout``; ``'coo'``,
            ``'csr'`` and ``'csc'`` produce the matching SciPy class.
        dtype: dtype of the all-ones data used when the tensor stores no
            values; ignored when values are present.
    """
    assert self.dim() == 2
    layout = get_layout(layout)

    def host(tensor: torch.Tensor):
        # SciPy needs CPU memory without autograd history.
        return tensor.detach().cpu().numpy()

    has_value = self.has_value()
    if not has_value:
        ones = torch.ones(self.nnz(), dtype=dtype).numpy()

    if layout == 'coo':
        row, col, value = self.coo()
        data = host(value) if has_value else ones
        return scipy.sparse.coo_matrix((data, (host(row), host(col))),
                                       self.sizes())
    if layout == 'csr':
        rowptr, col, value = self.csr()
        data = host(value) if has_value else ones
        return scipy.sparse.csr_matrix((data, host(col), host(rowptr)),
                                       self.sizes())
    if layout == 'csc':
        colptr, row, value = self.csc()
        data = host(value) if has_value else ones
        return scipy.sparse.csc_matrix((data, host(row), host(colptr)),
                                       self.sizes())


# `from_scipy` is an alternate constructor: bind it as a staticmethod so that
# `some_tensor.from_scipy(mat)` does not pass the instance as `mat`.
# `SparseTensor.from_scipy(mat)` keeps working exactly as before.
SparseTensor.from_scipy = staticmethod(from_scipy)
# `to_scipy` takes the tensor as its first argument, i.e. a regular method.
SparseTensor.to_scipy = to_scipy