tensor.py 19.8 KB
Newer Older
rusty1s's avatar
repr  
rusty1s committed
1
from textwrap import indent
rusty1s's avatar
typing  
rusty1s committed
2
from typing import Optional, List, Tuple, Dict, Union, Any
rusty1s's avatar
rusty1s committed
3
4
5
6

import torch
import scipy.sparse

rusty1s's avatar
rusty1s committed
7
from torch_sparse.storage import SparseStorage, get_layout
rusty1s's avatar
rusty1s committed
8
from torch_sparse.utils import is_scalar
rusty1s's avatar
rusty1s committed
9
10


rusty1s's avatar
rusty1s committed
11
@torch.jit.script
class SparseTensor(object):
    """Sparse matrix whose index/value data lives in a :class:`SparseStorage`.

    The two leading dimensions are sparse; if ``value`` has more than one
    dimension, its trailing dimensions become dense dimensions of the tensor
    (see :meth:`sizes`). Most methods either forward to :attr:`storage` or
    wrap a transformed storage in a new ``SparseTensor`` via
    :meth:`from_storage`.
    """

    storage: SparseStorage

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[Tuple[int, int]] = None,
                 is_sorted: bool = False):
        """Create a sparse tensor from COO row indices and/or CSR row
        pointers, column indices ``col`` and optional ``value``.

        ``is_sorted`` is forwarded unchanged to the storage (callers in this
        class pass ``True`` only for row-major ordered indices). All cached
        conversions (``rowcount``, ``colptr``, ``colcount``, ``csr2csc``,
        ``csc2csr``) start as ``None``.
        """
        self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
                                     value=value, sparse_sizes=sparse_sizes,
                                     rowcount=None, colptr=None, colcount=None,
                                     csr2csc=None, csc2csr=None,
                                     is_sorted=is_sorted)

    @classmethod
    def from_storage(cls, storage: SparseStorage):
        """Wrap an existing :class:`SparseStorage` (no copy is made)."""
        out = SparseTensor.__new__(SparseTensor)
        out.storage = storage
        return out

    @classmethod
    def from_dense(cls, mat: torch.Tensor, has_value: bool = True):
        """Convert a dense tensor into a sparse tensor.

        The first two dimensions of ``mat`` become the sparse dimensions.
        For ``mat.dim() > 2`` the position ``(i, j)`` is kept when the
        absolute values of ``mat[i, j]`` do not sum to zero, and ``mat[i, j]``
        becomes its (dense) value. With ``has_value=False`` only the sparsity
        pattern is kept.
        """
        if mat.dim() > 2:
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()
        index = index.t()

        row = index[0]
        col = index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat[row, col]

        # `nonzero()` yields indices in lexicographic (row-major) order.
        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def from_torch_sparse_coo_tensor(cls, mat: torch.Tensor,
                                     has_value: bool = True):
        """Convert a ``torch.sparse_coo_tensor`` (coalesced first).

        With ``has_value=False`` only the sparsity pattern is kept.
        """
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat._values()

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def eye(cls, M: int, N: Optional[int] = None,
            options: Optional[torch.Tensor] = None, has_value: bool = True,
            fill_cache: bool = False):
        """Return an ``M x N`` (``N`` defaults to ``M``) sparse matrix with
        ``min(M, N)`` entries on the main diagonal.

        Args:
            options: if given, supplies the device and the value dtype.
            has_value: if ``False`` no value tensor is stored.
            fill_cache: if ``True`` the cached ``rowcount``, ``colptr``,
                ``colcount`` and CSR<->CSC permutations are precomputed.
        """
        N = M if N is None else N

        if options is not None:
            row = torch.arange(min(M, N), device=options.device)
        else:
            row = torch.arange(min(M, N))
        col = row

        # Row i < min(M, N) holds exactly one entry; rows i >= N are empty.
        rowptr = torch.arange(M + 1, dtype=torch.long, device=row.device)
        if M > N:
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            if options is not None:
                value = torch.ones(row.numel(), dtype=options.dtype,
                                   device=row.device)
            else:
                value = torch.ones(row.numel(), device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0

            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                colptr[M + 1:] = M
                colcount[M:] = 0
            # A diagonal matrix has the same entry order in CSR and CSC, so
            # both permutations are the identity `arange`.
            csr2csc = csc2csr = row

        storage: SparseStorage = SparseStorage(
            row=row, rowptr=rowptr, col=col, value=value, sparse_sizes=(M, N),
            rowcount=rowcount, colptr=colptr, colcount=colcount,
            csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True)

        out = SparseTensor.__new__(SparseTensor)
        out.storage = storage
        return out

    def copy(self):
        """Return a new ``SparseTensor`` sharing the same storage object."""
        return self.from_storage(self.storage)

    def clone(self):
        """Return a new ``SparseTensor`` around ``storage.clone()``."""
        return self.from_storage(self.storage.clone())

    def type_as(self, tensor: torch.Tensor):
        """Cast the values to ``tensor.dtype``.

        Returns ``self`` unchanged when there are no values or the dtype
        already matches.
        """
        value = self.storage._value
        if value is None or tensor.dtype == value.dtype:
            return self
        return self.from_storage(self.storage.type_as(tensor))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Move to ``tensor.device``; returns ``self`` if already there."""
        if tensor.device == self.device():
            return self
        return self.from_storage(self.storage.device_as(tensor, non_blocking))

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(row, col, value)``."""
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(rowptr, col, value)``."""
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(colptr, row, value)``, with ``row`` and ``value``
        permuted by ``storage.csr2csc()``."""
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        return self.storage.has_value()

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """In-place ``storage.set_value_``; returns ``self``."""
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Out-of-place variant of :meth:`set_value_`."""
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> Tuple[int, int]:
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def is_coalesced(self) -> bool:
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "sum"):
        """Return ``storage.coalesce(reduce)`` wrapped in a new tensor."""
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        """In-place ``storage.fill_cache_()``; returns ``self``."""
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        """In-place ``storage.clear_cache_()``; returns ``self``."""
        self.storage.clear_cache_()
        return self

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float,
                    options: Optional[torch.Tensor] = None):
        """Replace the values in place by ``nnz()`` copies of ``fill_value``
        (dtype from ``options`` if given); returns ``self``."""
        if options is not None:
            value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
                               device=self.device())
        else:
            value = torch.full((self.nnz(), ), fill_value,
                               device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float,
                   options: Optional[torch.Tensor] = None):
        """Out-of-place variant of :meth:`fill_value_`."""
        if options is not None:
            value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
                               device=self.device())
        else:
            value = torch.full((self.nnz(), ), fill_value,
                               device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        """Sparse sizes followed by the trailing dims of ``value`` (if any)."""
        sparse_sizes = self.sparse_sizes()
        value = self.storage.value()
        if value is not None:
            return list(sparse_sizes) + list(value.size())[1:]
        else:
            return list(sparse_sizes)

    def size(self, dim: int) -> int:
        return self.sizes()[dim]

    def dim(self) -> int:
        return len(self.sizes())

    def nnz(self) -> int:
        """Number of stored (row, col) entries."""
        return self.storage.col().numel()

    def numel(self) -> int:
        """``value.numel()`` if values exist, otherwise ``nnz()``."""
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        """``nnz() / (M * N)`` for sparse sizes ``(M, N)``."""
        return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))

    def sparsity(self) -> float:
        return 1 - self.density()

    def avg_row_length(self) -> float:
        return self.nnz() / self.sparse_size(0)

    def avg_col_length(self) -> float:
        return self.nnz() / self.sparse_size(1)

    def bandwidth(self) -> int:
        """Return ``max |row - col|`` over the stored entries (0 if empty)."""
        if self.nnz() == 0:
            # `max()` of an empty tensor raises; an empty matrix has no band.
            return 0
        row, col, _ = self.coo()
        return int((row - col).abs_().max())

    def bandwidth_proportion(self, bandwidth: int) -> float:
        """Fraction of entries with ``|row - col| <= bandwidth``.

        NOTE(review): divides by ``nnz()``, so undefined for an empty tensor.
        """
        row, col, _ = self.coo()
        tmp = (row - col).abs_()
        return int((tmp <= bandwidth).sum()) / self.nnz()

    def is_quadratic(self) -> bool:
        """Whether the two sparse sizes are equal."""
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        """Whether the matrix equals its transpose.

        Compares the CSR representation against the CSC representation
        (indices and, if both present, values).
        """
        if not self.is_quadratic():
            return False

        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def to_symmetric(self, reduce: str = "sum"):
        """Return ``A`` combined with its transpose on an ``N x N`` grid,
        ``N = max(size(0), size(1))``.

        Every entry is inserted as both ``(i, j)`` and ``(j, i)``; duplicates
        are then merged by ``coalesce(reduce)``. With the default ``"sum"``
        this presumably doubles diagonal entries.
        """
        row, col, value = self.coo()

        row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
        if value is not None:
            value = torch.cat([value, value], dim=0)

        N = max(self.size(0), self.size(1))

        out = SparseTensor(row=row, rowptr=None, col=col, value=value,
                           sparse_sizes=(N, N), is_sorted=False)
        out = out.coalesce(reduce)
        return out

    def detach_(self):
        """Detach the values in place; returns ``self``."""
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        """Return a tensor whose values are detached from the graph."""
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        """``value.requires_grad``, or ``False`` without values."""
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self, requires_grad: bool = True,
                       options: Optional[torch.Tensor] = None):
        """Set ``requires_grad`` on the values; returns ``self``.

        When enabling gradients on a value-less tensor, values are first
        filled with ones (dtype from ``options`` if given).
        """
        if requires_grad and not self.has_value():
            self.fill_value_(1., options=options)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        return self.storage.is_pinned()

    def options(self) -> torch.Tensor:
        """A tensor describing dtype/device: the values themselves, or a
        float scalar on the index device when there are no values."""
        value = self.storage.value()
        if value is not None:
            return value
        else:
            return torch.tensor(0., dtype=torch.float,
                                device=self.storage.col().device)

    def device(self):
        """Device of the column index tensor."""
        return self.storage.col().device

    def cpu(self):
        return self.device_as(torch.tensor(0.), non_blocking=False)

    def cuda(self, options: Optional[torch.Tensor] = None,
             non_blocking: bool = False):
        """Move to ``options.device`` if given, else to the default CUDA
        device."""
        if options is not None:
            return self.device_as(options, non_blocking)
        else:
            options = torch.tensor(0.).cuda()
            return self.device_as(options, non_blocking)

    def is_cuda(self) -> bool:
        return self.storage.col().is_cuda

    def dtype(self):
        """Value dtype (``torch.float`` when there are no values)."""
        return self.options().dtype

    def is_floating_point(self) -> bool:
        return torch.is_floating_point(self.options())

    # dtype casts: each forwards to `type_as`, so it is a no-op (returns
    # `self`) when there are no values or the dtype already matches.

    def bfloat16(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.bfloat16, device=self.device()))

    def bool(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.bool, device=self.device()))

    def byte(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.uint8, device=self.device()))

    def char(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.int8, device=self.device()))

    def half(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.half, device=self.device()))

    def float(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.float, device=self.device()))

    def double(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.double, device=self.device()))

    def short(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.short, device=self.device()))

    def int(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.int, device=self.device()))

    def long(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.long, device=self.device()))

    # Conversions #############################################################

    def to_dense(self, options: Optional[torch.Tensor] = None):
        """Return a dense tensor of shape :meth:`sizes`.

        Missing values are written as ones. The dtype comes from the values,
        else from ``options``, else the default dtype.
        """
        row, col, value = self.coo()

        if value is not None:
            mat = torch.zeros(self.sizes(), dtype=value.dtype,
                              device=self.device())
        elif options is not None:
            mat = torch.zeros(self.sizes(), dtype=options.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), device=self.device())

        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(self,
                                   options: Optional[torch.Tensor] = None):
        """Return a ``torch.sparse_coo_tensor``.

        Missing values are replaced by ones (dtype from ``options`` if
        given).
        """
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)
        if value is None:
            if options is not None:
                value = torch.ones(self.nnz(), dtype=options.dtype,
                                   device=self.device())
            else:
                value = torch.ones(self.nnz(), device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())

# Python Bindings #############################################################

# Aliases for optional dtype / device arguments.
Dtype = Optional[torch.dtype]
Device = Union[torch.device, str, None]


def share_memory_(self: SparseTensor) -> SparseTensor:
    """Call ``share_memory_()`` on the backing storage, in place.

    Returns ``self`` so the call can be chained, as the return annotation
    promises (previously ``None`` was returned).
    """
    self.storage.share_memory_()
    return self


def is_shared(self: SparseTensor) -> bool:
    """Report what the backing storage's ``is_shared()`` returns."""
    backing = self.storage
    return backing.is_shared()


def to(self, *args: Any, **kwargs: Any) -> SparseTensor:
    """Convert dtype and/or device, mirroring ``torch.Tensor.to``.

    The arguments are parsed by PyTorch's internal ``_parse_to`` helper (the
    parser behind ``torch.Tensor.to``); only the resulting device, dtype and
    ``non_blocking`` flag are used. The dtype cast is applied before the
    device move. Either step is skipped when not requested.
    """
    device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]

    out = self
    if dtype is not None:
        out = out.type_as(torch.tensor(0., dtype=dtype))
    if device is not None:
        out = out.device_as(torch.tensor(0., device=device), non_blocking)

    return out


def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
    """Index with integers, slices, index/mask tensors and ``...``.

    ``index`` is a single item or a tuple of items. Items are applied left
    to right, and each one (except ``...``) consumes one dimension:

    * ``int`` -> ``select(dim, item)``
    * ``slice`` with step ``None`` or ``1`` -> ``narrow``. Bounds are
      normalised and clamped exactly like Python sequence slicing, so
      ``t[:100]`` on a shorter dimension stays in range.
    * bool tensor -> ``masked_select(dim, item)``
    * integer tensor -> ``index_select(dim, item)``; non-``long`` integer
      dtypes are cast to ``long`` first.
    * ``...`` -> jumps ahead so the remaining items address the trailing
      dimensions.

    Raises:
        SyntaxError: for more than one ``...``, an ``...`` that cannot be
            placed, an index tensor of unsupported dtype, or any other
            unsupported index item.
        ValueError: for a slice whose step is not ``None`` or ``1``.
    """
    items = list(index) if isinstance(index, tuple) else [index]
    if sum(1 for item in items if item is Ellipsis) > 1:
        raise SyntaxError('Only a single Ellipsis (`...`) is allowed.')

    int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)

    dim = 0
    out = self
    for pos, item in enumerate(items):
        if isinstance(item, int):
            out = out.select(dim, item)
            dim += 1
        elif isinstance(item, slice):
            if item.step not in (None, 1):
                raise ValueError('Step parameter not yet supported.')
            # `slice.indices` resolves negative bounds and clamps both to
            # [0, size], matching Python slicing semantics.
            start, stop, _ = item.indices(out.size(dim))
            out = out.narrow(dim, start, max(stop - start, 0))
            dim += 1
        elif torch.is_tensor(item):
            if item.dtype == torch.bool:
                out = out.masked_select(dim, item)
            elif item.dtype in int_dtypes:
                out = out.index_select(dim, item.to(torch.long))
            else:
                # Previously such tensors were silently ignored.
                raise SyntaxError(
                    f'Unsupported index tensor dtype {item.dtype}.')
            dim += 1
        elif item is Ellipsis:
            remaining = len(items) - pos - 1
            if self.dim() - remaining < dim:
                raise SyntaxError('Too many indices around the Ellipsis.')
            dim = self.dim() - remaining
        else:
            raise SyntaxError(f'Unsupported index type {type(item)}.')

    return out


def __repr__(self: SparseTensor) -> str:
    """Multi-line summary: indices, values (if any), size, nnz and density."""
    pad = ' ' * 6
    row, col, value = self.coo()

    def field(label, tensor):
        # Indent continuation lines of the tensor repr; the first line
        # follows directly after "label=".
        return f'{label}=' + indent(repr(tensor), pad)[len(pad):]

    parts = [field('row', row), field('col', col)]
    if value is not None:
        parts.append(field('val', value))
    parts.append(f'size={tuple(self.sizes())}, nnz={self.nnz()}, '
                 f'density={100 * self.density():.02f}%')

    name = self.__class__.__name__
    lead = ' ' * (len(name) + 1)
    body = indent(',\n'.join(parts), lead)[len(lead):]
    return f'{name}({body})'
# Attach the Python-level helpers above to the class; this happens after the
# class body has been processed by `torch.jit.script`.
for _name, _fn in (('share_memory_', share_memory_),
                   ('is_shared', is_shared),
                   ('to', to),
                   ('__getitem__', __getitem__),
                   ('__repr__', __repr__)):
    setattr(SparseTensor, _name, _fn)
del _name, _fn

# Scipy Conversions ###########################################################

# SciPy sparse formats used by `from_scipy` (input) / `to_scipy` (output).
ScipySparseMatrix = Union[
    scipy.sparse.coo_matrix,
    scipy.sparse.csr_matrix,
    scipy.sparse.csc_matrix,
]


@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
    """Convert a SciPy COO/CSR/CSC matrix into a :class:`SparseTensor`.

    The matrix is routed through CSR (with sorted column indices) so that
    row pointers and row-major COO indices are available. For CSC input the
    original column pointers are kept as well. With ``has_value=False`` only
    the sparsity pattern is converted.

    NOTE: ``value`` is built with ``torch.from_numpy`` and may therefore
    share memory with a SciPy data array (possibly the caller's).
    """
    colptr = None
    if isinstance(mat, scipy.sparse.csc_matrix):
        colptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocsr()
    # `tocsr()` returns CSR input as-is, and CSR input may have unsorted
    # column indices within a row. The storage below claims `is_sorted=True`,
    # so sort them on a copy (the caller's matrix is not modified).
    if not mat.has_sorted_indices:
        mat = mat.sorted_indices()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long)
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long)
    col = torch.from_numpy(mat.col).to(torch.long)

    value = None
    if has_value:
        value = torch.from_numpy(mat.data)

    sparse_sizes = mat.shape[:2]

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=None,
                            colptr=colptr, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)

    return SparseTensor.from_storage(storage)


@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
             dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
    """Convert to a SciPy sparse matrix.

    Args:
        layout: target format, normalised by ``get_layout``. The results
            ``'coo'``, ``'csr'`` and ``'csc'`` are handled; for anything else
            ``None`` is returned.
        dtype: dtype of the all-ones data used when the tensor has no
            values; ignored otherwise.

    Index and value tensors are detached and moved to the CPU first. Only
    2-D tensors are supported (checked with ``assert``).
    """
    assert self.dim() == 2
    layout = get_layout(layout)

    ones = None
    if not self.has_value():
        ones = torch.ones(self.nnz(), dtype=dtype).numpy()

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    def data_of(value):
        # Stored values if present, otherwise the all-ones fallback.
        return to_numpy(value) if self.has_value() else ones

    if layout == 'coo':
        row, col, value = self.coo()
        return scipy.sparse.coo_matrix(
            (data_of(value), (to_numpy(row), to_numpy(col))), self.sizes())
    if layout == 'csr':
        rowptr, col, value = self.csr()
        return scipy.sparse.csr_matrix(
            (data_of(value), to_numpy(col), to_numpy(rowptr)), self.sizes())
    if layout == 'csc':
        colptr, row, value = self.csc()
        return scipy.sparse.csc_matrix(
            (data_of(value), to_numpy(row), to_numpy(colptr)), self.sizes())


# Expose the SciPy converters on the class.
for _name, _fn in (('from_scipy', from_scipy), ('to_scipy', to_scipy)):
    setattr(SparseTensor, _name, _fn)
del _name, _fn

# Hacky fixes #################################################################

# Fix standard operators of `torch.Tensor` for PyTorch<=1.3.
# https://github.com/pytorch/pytorch/pull/31769
_version_parts = torch.__version__.split('.')
TORCH_MAJOR = int(_version_parts[0])
TORCH_MINOR = int(_version_parts[1])
del _version_parts

if (TORCH_MAJOR, TORCH_MINOR) <= (1, 3):

    def add(self, other):
        # Return NotImplemented for operands `torch.Tensor.add` cannot handle
        # so Python falls back to the right operand's reflected method.
        if not (torch.is_tensor(other) or is_scalar(other)):
            return NotImplemented
        return self.add(other)

    def mul(self, other):
        # Same fallback logic as `add`, for multiplication.
        if not (torch.is_tensor(other) or is_scalar(other)):
            return NotImplemented
        return self.mul(other)

    torch.Tensor.__add__ = add
    torch.Tensor.__mul__ = mul