tensor.py 20 KB
Newer Older
rusty1s's avatar
repr  
rusty1s committed
1
from textwrap import indent
rusty1s's avatar
rusty1s committed
2
from typing import Optional, List, Tuple, Union
rusty1s's avatar
rusty1s committed
3
4
5
6

import torch
import scipy.sparse

rusty1s's avatar
rusty1s committed
7
from torch_sparse.storage import SparseStorage, get_layout
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
10
11
# from torch_sparse.index_select import index_select, index_select_nnz
# from torch_sparse.masked_select import masked_select, masked_select_nnz
# from torch_sparse.diag import remove_diag, set_diag
rusty1s's avatar
rusty1s committed
12
# import torch_sparse.reduce
rusty1s's avatar
rusty1s committed
13
14
15
# from torch_sparse.matmul import matmul
# from torch_sparse.add import add, add_, add_nnz, add_nnz_
# from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
rusty1s's avatar
rusty1s committed
16
from torch_sparse.utils import is_scalar
rusty1s's avatar
rusty1s committed
17
18


rusty1s's avatar
rusty1s committed
19
@torch.jit.script
class SparseTensor(object):
    # Sparse matrix whose first two dimensions are sparse; stored values may
    # carry extra dense trailing dimensions (see `sizes()`). The index/value
    # tensors and any derived layouts are held by `self.storage`.
    storage: SparseStorage

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: List[int] = None, is_sorted: bool = False):
        # Indices may be given as COO (`row`) and/or CSR (`rowptr`) plus
        # `col`; no derived layouts (rowcount, colptr, ...) are supplied.
        self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
                                     value=value, sparse_sizes=sparse_sizes,
                                     rowcount=None, colptr=None, colcount=None,
                                     csr2csc=None, csc2csr=None,
                                     is_sorted=is_sorted)

    @classmethod
    def from_storage(self, storage: SparseStorage):
        # Wrap `storage` directly (no copy), bypassing `__init__`.
        self = SparseTensor.__new__(SparseTensor)
        self.storage = storage
        return self

    @classmethod
    def from_dense(self, mat: torch.Tensor):
        # Keep position (i, j) of the first two dims whenever mat[i, j]
        # (summed in absolute value over any trailing dims) is non-zero.
        if mat.dim() > 2:
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()
        index = index.t()

        row, col = index[0], index[1]
        # `nonzero()` returns indices in lexicographic (row-major) order,
        # which is what `is_sorted=True` promises.
        return SparseTensor(row=row, rowptr=None, col=col, value=mat[row, col],
                            sparse_sizes=mat.size()[:2], is_sorted=True)

    @classmethod
    def from_torch_sparse_coo_tensor(self, mat: torch.Tensor):
        # `coalesce()` sorts the indices and sums duplicate entries.
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]
        return SparseTensor(row=row, rowptr=None, col=col, value=mat._values(),
                            sparse_sizes=mat.size()[:2], is_sorted=True)

    @classmethod
    def eye(self, M: int, N: Optional[int] = None,
            options: Optional[torch.Tensor] = None, has_value: bool = True,
            fill_cache: bool = False):
        # M x N matrix with min(M, N) entries on the main diagonal (N defaults
        # to M). `options`, if given, provides the device and value dtype.
        # With `fill_cache`, the CSR/CSC helper tensors are built eagerly.
        N = M if N is None else N

        if options is not None:
            row = torch.arange(min(M, N), device=options.device)
        else:
            row = torch.arange(min(M, N))
        col = row

        rowptr = torch.arange(M + 1, dtype=torch.long, device=row.device)
        if M > N:
            # Rows N..M-1 hold no entries, so their pointers stay at the
            # total entry count, which is N (not M).
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            if options is not None:
                value = torch.ones(row.numel(), dtype=options.dtype,
                                   device=row.device)
            else:
                value = torch.ones(row.numel(), device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            # One entry per row/column up to min(M, N), zero afterwards.
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0

            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                colptr[M + 1:] = M
                colcount[M:] = 0
            # Diagonal entries are ordered identically in CSR and CSC, so
            # both permutations are the identity `arange`.
            csr2csc = csc2csr = row

        storage: SparseStorage = SparseStorage(
            row=row, rowptr=rowptr, col=col, value=value,
            sparse_sizes=torch.Size([M, N]), rowcount=rowcount, colptr=colptr,
            colcount=colcount, csr2csc=csr2csc, csc2csr=csc2csr,
            is_sorted=True)

        self = SparseTensor.__new__(SparseTensor)
        self.storage = storage
        return self

    def copy(self):
        # New wrapper around the same storage object (shallow).
        return self.from_storage(self.storage)

    def clone(self):
        # New wrapper around `storage.clone()`.
        return self.from_storage(self.storage.clone())

    def type_as(self, tensor: torch.Tensor):
        # Cast stored values to `tensor.dtype`; returns `self` unchanged when
        # there are no values or the dtype already matches.
        value = self.storage._value
        if value is None or tensor.dtype == value.dtype:
            return self
        return self.from_storage(self.storage.type_as(tensor))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        # Move to `tensor.device`; returns `self` if already there.
        if tensor.device == self.device():
            return self
        return self.from_storage(self.storage.device_as(tensor, non_blocking))

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # (row, col, value) in the storage's order.
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # (rowptr, col, value) in the storage's order.
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # (colptr, row, value), with row indices and values reordered by the
        # storage's `csr2csc` permutation to match `colptr`.
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        return self.storage.has_value()

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        # In-place: replaces the storage's values and returns `self`.
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        # Out-of-place counterpart of `set_value_`.
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> List[int]:
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: List[int]):
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def is_coalesced(self) -> bool:
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "add"):
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        self.storage.clear_cache_()
        return self

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float,
                    options: Optional[torch.Tensor] = None):
        # In-place: set every stored value to `fill_value` (dtype taken from
        # `options` when given, else PyTorch's default dtype).
        if options is not None:
            value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
                               device=self.device())
        else:
            value = torch.full((self.nnz(), ), fill_value,
                               device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float,
                   options: Optional[torch.Tensor] = None):
        # Out-of-place counterpart of `fill_value_`.
        if options is not None:
            value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
                               device=self.device())
        else:
            value = torch.full((self.nnz(), ), fill_value,
                               device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        # Sparse sizes followed by the values' trailing dense dimensions.
        # Copy first: `+=` on a list extends it in place, which would
        # otherwise mutate the list returned by `sparse_sizes()` (possibly
        # the storage's own list) on every call.
        sizes = list(self.sparse_sizes())
        value = self.storage.value()
        if value is not None:
            sizes += list(value.size()[1:])
        return sizes

    def size(self, dim: int) -> int:
        return self.sizes()[dim]

    def dim(self) -> int:
        return len(self.sizes())

    def nnz(self) -> int:
        return self.storage.col().numel()

    def numel(self) -> int:
        # Number of stored scalars: all value elements, or nnz without values.
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))

    def sparsity(self) -> float:
        return 1 - self.density()

    def avg_row_length(self) -> float:
        return self.nnz() / self.sparse_size(0)

    def avg_col_length(self) -> float:
        return self.nnz() / self.sparse_size(1)

    def is_quadratic(self) -> bool:
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        # Symmetric iff the CSR and CSC representations coincide (and, when
        # values are stored, the values agree elementwise).
        if not self.is_quadratic():
            return False

        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def detach_(self):
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self, requires_grad: bool = True,
                       options: Optional[torch.Tensor] = None):
        # Gradients need values, so a value-less tensor is first filled with
        # ones before enabling `requires_grad`.
        if requires_grad and not self.has_value():
            self.fill_value_(1., options=options)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        return self.storage.is_pinned()

    def options(self) -> torch.Tensor:
        # A tensor carrying this object's dtype/device: the values if
        # present, otherwise a float scalar on the index device.
        value = self.storage.value()
        if value is not None:
            return value
        else:
            return torch.tensor(0., device=self.storage.col().device)

    def device(self):
        return self.storage.col().device

    def cpu(self):
        return self.device_as(torch.tensor(0.), non_blocking=False)

    def cuda(self, options: Optional[torch.Tensor] = None,
             non_blocking: bool = False):
        # Move to the device of `options`, or to the current CUDA device.
        if options is not None:
            return self.device_as(options, non_blocking)
        else:
            options = torch.tensor(0.).cuda()
            return self.device_as(options, non_blocking)

    def is_cuda(self) -> bool:
        return self.storage.col().is_cuda

    def dtype(self):
        return self.options().dtype

    def is_floating_point(self) -> bool:
        return torch.is_floating_point(self.options())

    def bfloat16(self):
        return self.type_as(torch.tensor(0, dtype=torch.bfloat16))

    def bool(self):
        return self.type_as(torch.tensor(0, dtype=torch.bool))

    def byte(self):
        return self.type_as(torch.tensor(0, dtype=torch.uint8))

    def char(self):
        return self.type_as(torch.tensor(0, dtype=torch.int8))

    def half(self):
        return self.type_as(torch.tensor(0, dtype=torch.half))

    def float(self):
        return self.type_as(torch.tensor(0, dtype=torch.float))

    def double(self):
        return self.type_as(torch.tensor(0, dtype=torch.double))

    def short(self):
        return self.type_as(torch.tensor(0, dtype=torch.short))

    def int(self):
        return self.type_as(torch.tensor(0, dtype=torch.int))

    def long(self):
        return self.type_as(torch.tensor(0, dtype=torch.long))

    # Conversions #############################################################

    def to_dense(self, options: Optional[torch.Tensor] = None):
        # Materialize as a dense tensor of shape `self.sizes()`. dtype is
        # `options.dtype` if given, else the stored values' dtype, else
        # PyTorch's default dtype (entries then become ones).
        row, col, value = self.coo()

        if options is not None:
            mat = torch.zeros(self.sizes(), dtype=options.dtype,
                              device=self.device())
        elif value is not None:
            # Keep the values' own dtype instead of silently casting them
            # (e.g. float64 -> float32) into a default-dtype buffer.
            mat = torch.zeros(self.sizes(), dtype=value.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), device=self.device())

        # NOTE(review): indexed assignment does not accumulate duplicate
        # (row, col) pairs; coalesce first if duplicates are possible.
        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(self,
                                   options: Optional[torch.Tensor] = None):
        # Convert to a `torch.sparse_coo_tensor`; without stored values the
        # entries become ones (dtype from `options`, else the default).
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)
        if value is None:
            if options is not None:
                value = torch.ones(self.nnz(), dtype=options.dtype,
                                   device=self.device())
            else:
                value = torch.ones(self.nnz(), device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())

    # Standard Operators ######################################################

    # def __add__(self, other):
    #     return self.add(other)

    # def __radd__(self, other):
    #     return self.add(other)

    # def __iadd__(self, other):
    #     return self.add_(other)

    # def __mul__(self, other):
    #     return self.mul(other)

    # def __rmul__(self, other):
    #     return self.mul(other)

    # def __imul__(self, other):
    #     return self.mul_(other)

    # def __matmul__(self, other):
    #     return matmul(self, other, reduce='sum')
rusty1s's avatar
rusty1s committed
404

rusty1s's avatar
rusty1s committed
405

rusty1s's avatar
rusty1s committed
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
# SparseTensor.narrow = narrow
# SparseTensor.select = select
# SparseTensor.index_select = index_select
# SparseTensor.index_select_nnz = index_select_nnz
# SparseTensor.masked_select = masked_select
# SparseTensor.masked_select_nnz = masked_select_nnz
# SparseTensor.reduction = torch_sparse.reduce.reduction
# SparseTensor.sum = torch_sparse.reduce.sum
# SparseTensor.mean = torch_sparse.reduce.mean
# SparseTensor.min = torch_sparse.reduce.min
# SparseTensor.max = torch_sparse.reduce.max
# SparseTensor.remove_diag = remove_diag
# SparseTensor.set_diag = set_diag
# SparseTensor.matmul = matmul
# SparseTensor.add = add
# SparseTensor.add_ = add_
# SparseTensor.add_nnz = add_nnz
# SparseTensor.add_nnz_ = add_nnz_
# SparseTensor.mul = mul
# SparseTensor.mul_ = mul_
# SparseTensor.mul_nnz = mul_nnz
# SparseTensor.mul_nnz_ = mul_nnz_

# Python Bindings #############################################################

# Type aliases for the optional dtype / device arguments parsed by `to`.
Dtype = Union[torch.dtype, None]
Device = Union[torch.device, str, None]


@torch.jit.ignore
def share_memory_(self: SparseTensor) -> SparseTensor:
    """Call ``share_memory_()`` on the underlying storage, in place.

    Returns ``self``, as the return annotation promises and as
    ``torch.Tensor.share_memory_`` does, so calls can be chained.
    """
    self.storage.share_memory_()
    return self


@torch.jit.ignore
def is_shared(self: SparseTensor) -> bool:
    """Return whatever the wrapped storage's ``is_shared()`` reports."""
    storage = self.storage
    shared = storage.is_shared()
    return shared


@torch.jit.ignore
def to(self, *args, **kwargs):
    """Convert dtype and/or device, mirroring ``torch.Tensor.to``.

    Accepts ``dtype=``, ``device=`` and ``non_blocking=`` keywords, as well
    as positional ``torch.dtype``, ``torch.device`` or device-string
    arguments. A positional argument overrides the matching keyword.
    The dtype cast (via ``type_as``) is applied before the device move
    (via ``device_as``).
    """
    # `kwargs` is a dict: keys must be read with `.get`. The previous
    # `getattr(kwargs, ...)` never found them, so every keyword argument
    # was silently ignored.
    dtype: Dtype = kwargs.get('dtype', None)
    device: Device = kwargs.get('device', None)
    non_blocking: bool = kwargs.get('non_blocking', False)

    for arg in args:
        if isinstance(arg, (str, torch.device)):
            device = arg
        elif isinstance(arg, torch.dtype):
            dtype = arg

    if dtype is not None:
        self = self.type_as(torch.tensor(0., dtype=dtype))
    if device is not None:
        self = self.device_as(torch.tensor(0., device=device), non_blocking)

    return self


rusty1s's avatar
repr  
rusty1s committed
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
@torch.jit.ignore
def __getitem__(self, index):
    """Index a SparseTensor (currently disabled).

    This revision raises ``NotImplementedError`` unconditionally. Everything
    after the ``raise`` is unreachable. That code dispatches ints to
    ``select``, slices to ``narrow``, bool tensors to ``masked_select``,
    long tensors to ``index_select`` and ``...`` to a dimension jump. The
    bindings for those helpers are commented out in this file.
    """
    raise NotImplementedError
    index = list(index) if isinstance(index, tuple) else [index]
    # More than one `Ellipsis` is not allowed...
    if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
        raise SyntaxError

    dim = 0
    out = self
    while len(index) > 0:
        item = index.pop(0)
        if isinstance(item, int):
            out = out.select(dim, item)
            dim += 1
        elif isinstance(item, slice):
            if item.step is not None:
                raise ValueError('Step parameter not yet supported.')

            # Negative bounds are wrapped against the size of `dim`.
            # NOTE(review): these use `self.size(dim)`, not `out.size(dim)`;
            # confirm this is right after earlier select/narrow calls.
            start = 0 if item.start is None else item.start
            start = self.size(dim) + start if start < 0 else start

            stop = self.size(dim) if item.stop is None else item.stop
            stop = self.size(dim) + stop if stop < 0 else stop

            out = out.narrow(dim, start, max(stop - start, 0))
            dim += 1
        elif torch.is_tensor(item):
            # NOTE(review): tensors of any other dtype are silently skipped
            # without advancing `dim`.
            if item.dtype == torch.bool:
                out = out.masked_select(dim, item)
                dim += 1
            elif item.dtype == torch.long:
                out = out.index_select(dim, item)
                dim += 1
        elif item == Ellipsis:
            # Jump so that the remaining index items address the last dims.
            if self.dim() - len(index) < dim:
                raise SyntaxError
            dim = self.dim() - len(index)
        else:
            raise SyntaxError

    return out


@torch.jit.ignore
def __repr__(self):
    """Render row/col (and values, if any) plus size, nnz and density.

    Continuation lines of each field are indented to line up under the
    text following ``<ClassName>(``.
    """
    pad = ' ' * 6

    def field(label, obj):
        # Indent every line of obj's repr, then drop the pad from line one
        # so only continuation lines are shifted.
        return f'{label}={indent(obj.__repr__(), pad)[len(pad):]}'

    row, col, value = self.coo()
    parts = [field('row', row), field('col', col)]
    if value is not None:
        parts.append(field('val', value))
    parts.append(f'size={tuple(self.sizes())}, '
                 f'nnz={self.nnz()}, '
                 f'density={100 * self.density():.02f}%')

    name = self.__class__.__name__
    body = ',\n'.join(parts)
    lead = ' ' * (len(name) + 1)
    return f'{name}({indent(body, lead)[len(lead):]})'


rusty1s's avatar
rusty1s committed
531
532
533
# Attach the Python-only (`torch.jit.ignore`) helpers defined in this
# module as attributes of the scripted `SparseTensor` class.
for _attr, _impl in (('share_memory_', share_memory_),
                     ('is_shared', is_shared),
                     ('to', to),
                     ('__getitem__', __getitem__),
                     ('__repr__', __repr__)):
    setattr(SparseTensor, _attr, _impl)
del _attr, _impl
rusty1s's avatar
rusty1s committed
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600

# Scipy Conversions ###########################################################

# The SciPy sparse layouts accepted by `from_scipy` and produced by
# `to_scipy`.
ScipySparseMatrix = Union[
    scipy.sparse.coo_matrix,
    scipy.sparse.csr_matrix,
    scipy.sparse.csc_matrix,
]


@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix) -> SparseTensor:
    """Build a ``SparseTensor`` from a SciPy COO, CSR or CSC matrix.

    The storage is created with ``is_sorted=True``. SciPy CSR matrices may
    hold column indices that are unsorted within a row (e.g. results of
    sparse products), so the CSR form is sorted first when needed. Sorting
    uses a copy, so the caller's matrix is not modified. Index tensors
    become ``torch.long``; values keep the NumPy dtype of ``mat.data``.

    NOTE: ``torch.from_numpy`` shares memory with its NumPy array, so the
    result may alias the input's buffers when SciPy's conversions avoid a
    copy.
    """
    colptr = None
    if isinstance(mat, scipy.sparse.csc_matrix):
        # A CSC input already provides the column pointer; reuse it.
        colptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocsr()
    if not mat.has_sorted_indices:
        # `sorted_indices()` returns a sorted copy, so `is_sorted=True`
        # below holds.
        mat = mat.sorted_indices()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long)
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long)
    col = torch.from_numpy(mat.col).to(torch.long)
    value = torch.from_numpy(mat.data)
    sparse_sizes = mat.shape[:2]

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=None,
                            colptr=colptr, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)

    return SparseTensor.from_storage(storage)


@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
             dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
    """Export a 2-D sparse tensor as a SciPy COO, CSR or CSC matrix.

    ``layout`` is normalized by ``get_layout`` and selects the output
    format. Index and value tensors are detached and copied to the CPU.
    Without stored values, every entry becomes a one of ``dtype``
    (PyTorch's default dtype when ``dtype`` is None). ``dtype`` is unused
    when values exist.
    """
    assert self.dim() == 2
    layout = get_layout(layout)

    def as_numpy(tensor):
        return tensor.detach().cpu().numpy()

    has_value = self.has_value()
    ones = None if has_value else torch.ones(self.nnz(), dtype=dtype).numpy()

    if layout == 'coo':
        row, col, value = self.coo()
        data = as_numpy(value) if has_value else ones
        return scipy.sparse.coo_matrix(
            (data, (as_numpy(row), as_numpy(col))), self.sizes())
    if layout == 'csr':
        rowptr, col, value = self.csr()
        data = as_numpy(value) if has_value else ones
        return scipy.sparse.csr_matrix(
            (data, as_numpy(col), as_numpy(rowptr)), self.sizes())
    if layout == 'csc':
        colptr, row, value = self.csc()
        data = as_numpy(value) if has_value else ones
        return scipy.sparse.csc_matrix(
            (data, as_numpy(row), as_numpy(colptr)), self.sizes())


# `from_scipy` is a constructor taking only the SciPy matrix. Wrapping it in
# `staticmethod` keeps `SparseTensor.from_scipy(mat)` unchanged and makes
# `instance.from_scipy(mat)` work too; a plain function would bind the
# instance as `mat`. `to_scipy` takes the tensor as `self`, so it stays a
# regular method.
SparseTensor.from_scipy = staticmethod(from_scipy)
SparseTensor.to_scipy = to_scipy

# Hacky fixes #################################################################

# Fix standard operators of `torch.Tensor` for PyTorch<=1.3.
# https://github.com/pytorch/pytorch/pull/31769
rusty1s's avatar
rusty1s committed
601
602
TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
if (TORCH_MAJOR, TORCH_MINOR) < (1, 4):

    def add(self, other):
        # Route `Tensor + x` through `Tensor.add` for tensors and scalars;
        # anything else gets NotImplemented so Python can fall back to the
        # right operand's reflected operator.
        if not (torch.is_tensor(other) or is_scalar(other)):
            return NotImplemented
        return self.add(other)

    def mul(self, other):
        # Same dispatch rule as `add`, for `Tensor * x`.
        if not (torch.is_tensor(other) or is_scalar(other)):
            return NotImplemented
        return self.mul(other)

    torch.Tensor.__add__ = add
    torch.Tensor.__mul__ = mul