from textwrap import indent
from typing import Optional, List, Tuple, Dict, Union, Any

import torch
import scipy.sparse

from torch_sparse.storage import SparseStorage, get_layout


@torch.jit.script
class SparseTensor(object):
    # Sparse matrix whose index and value tensors are held by a
    # ``SparseStorage``. The first two dimensions are sparse
    # (``sparse_sizes()``); any further dimensions are the trailing
    # dimensions of the value tensor (see ``sizes()``). Most methods delegate
    # to ``self.storage`` and wrap the result with ``from_storage``.

    storage: SparseStorage

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[Tuple[int, int]] = None,
                 is_sorted: bool = False):
        """Build a sparse matrix from row indices (``row`` and/or
        ``rowptr``), column indices ``col`` and optional ``value``.

        All arguments are forwarded unchanged to ``SparseStorage``. No cached
        auxiliary tensors (``rowcount``, ``colptr``, ``colcount``,
        ``csr2csc``, ``csc2csr``) are supplied.
        """
        self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
                                     value=value, sparse_sizes=sparse_sizes,
                                     rowcount=None, colptr=None, colcount=None,
                                     csr2csc=None, csc2csr=None,
                                     is_sorted=is_sorted)

    @classmethod
    def from_storage(self, storage: SparseStorage):
        """Wrap ``storage`` in a new ``SparseTensor`` without copying it."""
        # ``__new__`` bypasses ``__init__``, so no new storage is built.
        self = SparseTensor.__new__(SparseTensor)
        self.storage = storage
        return self

    @classmethod
    def from_edge_index(self, edge_index: torch.Tensor,
                        edge_attr: Optional[torch.Tensor] = None,
                        sparse_sizes: Optional[Tuple[int, int]] = None,
                        is_sorted: bool = False):
        """Build from ``edge_index``.

        ``edge_index[0]`` holds the row indices and ``edge_index[1]`` the
        column indices. ``edge_attr``, if given, becomes the values.
        """
        return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
                            value=edge_attr, sparse_sizes=sparse_sizes,
                            is_sorted=is_sorted)

    @classmethod
    def from_dense(self, mat: torch.Tensor, has_value: bool = True):
        """Build from a dense tensor whose first two dimensions are sparse.

        For ``mat.dim() > 2``, a position ``(i, j)`` is stored when any entry
        of the slice ``mat[i, j]`` is non-zero. That whole slice then becomes
        its value. Indices come from ``nonzero`` and are flagged as already
        sorted.
        """
        if mat.dim() > 2:
            # Reduce the trailing (value) dimensions to one magnitude per
            # (row, col) position before looking for non-zeros.
            index = mat.abs().sum([i for i in range(2, mat.dim())
                                   ]).nonzero(as_tuple=False)
        else:
            index = mat.nonzero(as_tuple=False)
        index = index.t()

        row = index[0]
        col = index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat[row, col]

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
                                     has_value: bool = True):
        """Build from a PyTorch sparse COO tensor.

        ``mat`` is coalesced first, so its indices are flagged as sorted.
        """
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat.values()

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
            dtype: Optional[int] = None, device: Optional[torch.device] = None,
            fill_cache: bool = False):
        """Return an ``M x N`` sparse identity matrix (``N`` defaults to ``M``).

        The diagonal holds ``min(M, N)`` entries. If ``has_value`` is set,
        they are explicit ones of ``dtype``. If ``fill_cache`` is set,
        ``rowcount``, ``colptr``, ``colcount`` and the CSR<->CSC permutations
        are computed here and stored in the storage.
        """
        N = M if N is None else N

        row = torch.arange(min(M, N), device=device)
        col = row

        rowptr = torch.arange(M + 1, device=row.device)
        if M > N:
            # Rows past the diagonal are empty: their pointers stay at N.
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            value = torch.ones(row.numel(), dtype=dtype, device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0

            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                colptr[M + 1:] = M
                colcount[M:] = 0
            # The diagonal reads the same in CSR and CSC order, so both
            # permutations are the identity ``arange``.
            csr2csc = csc2csr = row

        storage: SparseStorage = SparseStorage(
            row=row, rowptr=rowptr, col=col, value=value, sparse_sizes=(M, N),
            rowcount=rowcount, colptr=colptr, colcount=colcount,
            csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True)

        self = SparseTensor.__new__(SparseTensor)
        self.storage = storage
        return self

    def copy(self):
        """Return a new ``SparseTensor`` sharing the same storage object."""
        return self.from_storage(self.storage)

    def clone(self):
        """Return a new ``SparseTensor`` backed by ``storage.clone()``."""
        return self.from_storage(self.storage.clone())

    def type_as(self, tensor: torch.Tensor):
        """Cast the values to ``tensor.dtype``.

        Returns ``self`` unchanged when there are no values or the dtype
        already matches.
        """
        value = self.storage.value()
        if value is None or tensor.dtype == value.dtype:
            return self
        return self.from_storage(self.storage.type_as(tensor))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Move to ``tensor.device``; returns ``self`` if already there."""
        if tensor.device == self.device():
            return self
        return self.from_storage(self.storage.device_as(tensor, non_blocking))

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(row, col, value)`` in COO layout."""
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(rowptr, col, value)`` in CSR layout."""
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(colptr, row, value)`` in CSC layout.

        Rows and values are permuted with ``storage.csr2csc()``.
        """
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        """Whether explicit values are stored."""
        return self.storage.has_value()

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """Replace the values in place and return ``self``.

        ``layout`` is forwarded to ``SparseStorage.set_value_``.
        """
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Return a new ``SparseTensor`` with ``value`` as its values."""
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> Tuple[int, int]:
        """Return ``(num_rows, num_cols)`` of the sparse dimensions."""
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        """Return the size of sparse dimension ``dim`` (0 or 1)."""
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        """Return a new ``SparseTensor`` wrapping
        ``storage.sparse_resize(sparse_sizes)``."""
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def sparse_reshape(self, num_rows: int, num_cols: int):
        """Return a new ``SparseTensor`` wrapping
        ``storage.sparse_reshape(num_rows, num_cols)``."""
        return self.from_storage(
            self.storage.sparse_reshape(num_rows, num_cols))

    def is_coalesced(self) -> bool:
        """Delegate to ``SparseStorage.is_coalesced``."""
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "sum"):
        """Return a coalesced copy; duplicates are merged with ``reduce``."""
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        """Fill the storage's cached index tensors in place; returns ``self``."""
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        """Clear the storage's cached index tensors in place; returns ``self``."""
        self.storage.clear_cache_()
        return self

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
        """Set every stored value to ``fill_value`` in place."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float, dtype: Optional[int] = None):
        """Return a copy whose stored values all equal ``fill_value``."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        """Full shape: the sparse sizes followed by the value's trailing dims."""
        sparse_sizes = self.sparse_sizes()
        value = self.storage.value()
        if value is not None:
            return list(sparse_sizes) + list(value.size())[1:]
        else:
            return list(sparse_sizes)

    def size(self, dim: int) -> int:
        """Size of dimension ``dim`` of ``sizes()``."""
        return self.sizes()[dim]

    def dim(self) -> int:
        """Number of dimensions of ``sizes()``."""
        return len(self.sizes())

    def nnz(self) -> int:
        """Number of stored (row, col) entries."""
        return self.storage.col().numel()

    def numel(self) -> int:
        """Number of stored scalars: ``value.numel()``, or ``nnz()`` without values."""
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        """Fraction of the ``rows * cols`` positions that are stored.

        Returns 0.0 for a matrix with zero rows or columns.
        """
        total = self.sparse_size(0) * self.sparse_size(1)
        if total == 0:
            # Avoid dividing by zero; ``__repr__`` calls this, so an
            # empty-size tensor must not raise here.
            return 0.0
        return self.nnz() / total

    def sparsity(self) -> float:
        """``1 - density()``."""
        return 1 - self.density()

    def avg_row_length(self) -> float:
        """Average number of stored entries per row (0.0 when there are no rows)."""
        num_rows = self.sparse_size(0)
        if num_rows == 0:
            return 0.0
        return self.nnz() / num_rows

    def avg_col_length(self) -> float:
        """Average number of stored entries per column (0.0 when there are no columns)."""
        num_cols = self.sparse_size(1)
        if num_cols == 0:
            return 0.0
        return self.nnz() / num_cols

    def bandwidth(self) -> int:
        """Largest ``|row - col|`` over the stored entries (0 when there are none)."""
        if self.nnz() == 0:
            # ``max()`` of an empty tensor raises; an empty matrix has
            # bandwidth 0.
            return 0
        row, col, _ = self.coo()
        return int((row - col).abs_().max())

    def avg_bandwidth(self) -> float:
        """Mean ``|row - col|`` over the stored entries.

        Without stored entries this is the mean of an empty tensor (NaN).
        """
        row, col, _ = self.coo()
        return float((row - col).abs_().to(torch.float).mean())

    def bandwidth_proportion(self, bandwidth: int) -> float:
        """Fraction of stored entries with ``|row - col| <= bandwidth``.

        Divides by ``nnz()``, so it fails for a matrix without stored entries.
        """
        row, col, _ = self.coo()
        tmp = (row - col).abs_()
        return int((tmp <= bandwidth).sum()) / self.nnz()

    def is_quadratic(self) -> bool:
        """Whether the sparse dimensions are equal (square matrix)."""
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        """Whether the matrix equals its transpose (structure and values)."""
        if not self.is_quadratic():
            return False

        # The CSC arrays are the CSR arrays of the transpose, so A == A^T
        # exactly when both layouts match entry for entry.
        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def to_symmetric(self, reduce: str = "sum"):
        """Symmetrize by adding the transposed entries.

        Each entry ``(i, j)`` is mirrored to ``(j, i)``. Coincident entries
        are merged with ``reduce``. Diagonal entries appear twice, so they
        are doubled with ``reduce="sum"``. The result is square with side
        ``max(rows, cols)``.
        """
        row, col, value = self.coo()

        row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
        if value is not None:
            value = torch.cat([value, value], dim=0)

        N = max(self.size(0), self.size(1))

        out = SparseTensor(row=row, rowptr=None, col=col, value=value,
                           sparse_sizes=(N, N), is_sorted=False)
        out = out.coalesce(reduce)
        return out

    def detach_(self):
        """Detach the stored values from autograd in place; returns ``self``."""
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        """Return a copy whose values are detached from autograd."""
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        """Whether the stored values require grad (False without values)."""
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self, requires_grad: bool = True,
                       dtype: Optional[int] = None):
        """Set ``requires_grad`` on the values in place; returns ``self``.

        If gradients are requested and there are no values, values of 1 with
        ``dtype`` are created first.
        """
        if requires_grad and not self.has_value():
            self.fill_value_(1., dtype)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        """Return a new ``SparseTensor`` wrapping ``storage.pin_memory()``."""
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        """Delegate to ``SparseStorage.is_pinned``."""
        return self.storage.is_pinned()

    def device(self):
        """Device of the column index tensor."""
        return self.storage.col().device

    def cpu(self):
        """Move to the CPU (no-op if already there)."""
        return self.device_as(torch.tensor(0), non_blocking=False)

    def cuda(self):
        """Return a new ``SparseTensor`` wrapping ``storage.cuda()``."""
        return self.from_storage(self.storage.cuda())

    def is_cuda(self) -> bool:
        """Whether the column index tensor lives on CUDA."""
        return self.storage.col().is_cuda

    def dtype(self):
        """Dtype of the values, or ``torch.float`` when there are none."""
        value = self.storage.value()
        return value.dtype if value is not None else torch.float

    def is_floating_point(self) -> bool:
        """Whether the values are floating point (True when there are none)."""
        value = self.storage.value()
        return torch.is_floating_point(value) if value is not None else True

    def bfloat16(self):
        """Cast the values to ``torch.bfloat16``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.bfloat16, device=self.device()))

    def bool(self):
        """Cast the values to ``torch.bool``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.bool, device=self.device()))

    def byte(self):
        """Cast the values to ``torch.uint8``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.uint8, device=self.device()))

    def char(self):
        """Cast the values to ``torch.int8``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.int8, device=self.device()))

    def half(self):
        """Cast the values to ``torch.half``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.half, device=self.device()))

    def float(self):
        """Cast the values to ``torch.float``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.float, device=self.device()))

    def double(self):
        """Cast the values to ``torch.double``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.double, device=self.device()))

    def short(self):
        """Cast the values to ``torch.short``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.short, device=self.device()))

    def int(self):
        """Cast the values to ``torch.int``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.int, device=self.device()))

    def long(self):
        """Cast the values to ``torch.long``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.long, device=self.device()))

    # Conversions #############################################################

    def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
        """Materialize as a dense tensor of shape ``sizes()``.

        ``dtype`` is only used when there are no stored values; those entries
        are filled with ones. Otherwise the value dtype is used.
        """
        row, col, value = self.coo()

        if value is not None:
            mat = torch.zeros(self.sizes(), dtype=value.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

        # Plain indexed assignment: duplicate (row, col) pairs are not summed.
        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(
            self, dtype: Optional[int] = None) -> torch.Tensor:
        """Convert to a PyTorch sparse COO tensor of shape ``sizes()``.

        ``dtype`` is only used for the ones created when there are no values.
        """
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)

        if value is None:
            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())

# Python Bindings #############################################################


def share_memory_(self: SparseTensor) -> SparseTensor:
    """Call ``SparseStorage.share_memory_`` on the underlying storage.

    Returns ``self``, as the annotation promises and as
    ``torch.Tensor.share_memory_`` does, so the call can be chained.
    """
    self.storage.share_memory_()
    return self


def is_shared(self: SparseTensor) -> bool:
    """Report whether the underlying storage is shared.

    Delegates to ``SparseStorage.is_shared``.
    """
    storage = self.storage
    shared = storage.is_shared()
    return shared


def to(self, *args: Optional[List[Any]],
       **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
    """Change dtype and/or device, ``torch.Tensor.to``-style.

    Arguments are parsed by PyTorch's private ``torch._C._nn._parse_to``.
    Only the parsed device, dtype and ``non_blocking`` flag are used. The
    dtype conversion happens before the device move.
    """
    parsed = torch._C._nn._parse_to(*args, **kwargs)
    device, dtype, non_blocking = parsed[:3]

    out = self
    if dtype is not None:
        # ``type_as`` only needs a tensor carrying the target dtype.
        out = out.type_as(torch.tensor(0., dtype=dtype))
    if device is not None:
        # ``device_as`` only needs a tensor living on the target device.
        out = out.device_as(torch.tensor(0., device=device), non_blocking)
    return out


def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
    """Index the sparse tensor NumPy/PyTorch-style.

    Index items are consumed left to right. ``dim`` advances by one after
    every int, slice or tensor item:

    * ``int``: ``select`` one position. Negative values count from the end.
    * ``slice``: ``narrow``. Bounds are resolved and clamped exactly like
      Python slicing. Only a step of ``None`` or ``1`` is supported.
    * ``list`` / ``tuple``: converted to a ``long`` tensor and then indexed
      as a tensor.
    * bool tensor: ``masked_select``. Integer tensor: ``index_select``.
    * ``...``: moves ``dim`` forward so the remaining items address the
      last dimensions.

    Raises:
        SyntaxError: more than one ``...``, too many items after ``...``,
            or an unsupported item type.
        ValueError: a slice step other than 1.
        IndexError: a tensor index that is neither bool nor integer.
    """
    index = list(index) if isinstance(index, tuple) else [index]
    # At most one `Ellipsis` is allowed. Compare by identity so items such as
    # arrays never go through an element-wise `==`.
    if sum(1 for i in index if i is Ellipsis) > 1:
        raise SyntaxError('an index can only have a single ellipsis (...)')

    dim = 0
    out = self
    while len(index) > 0:
        item = index.pop(0)

        if isinstance(item, (list, tuple)):
            item = torch.tensor(item, dtype=torch.long, device=self.device())

        if isinstance(item, int):
            if item < 0:
                # Count negative positions from the end, as Python does.
                item += out.size(dim)
            out = out.select(dim, item)
            dim += 1
        elif isinstance(item, slice):
            if item.step is not None and item.step != 1:
                raise ValueError('Step parameter not yet supported.')
            # `slice.indices` resolves None/negative bounds and clamps them
            # to [0, size] like Python slicing. So `[:100]` on 10 rows
            # selects all rows instead of over-running `narrow`.
            start, stop, _ = item.indices(out.size(dim))
            out = out.narrow(dim, start, max(stop - start, 0))
            dim += 1
        elif torch.is_tensor(item):
            if item.dtype == torch.bool:
                out = out.masked_select(dim, item)
            elif item.dtype in (torch.long, torch.int, torch.short,
                                torch.int8):
                out = out.index_select(dim, item.to(torch.long))
            else:
                # Reject other dtypes (e.g. float) instead of skipping them.
                # Skipping would leave `dim` unchanged and apply later items
                # to the wrong dimension.
                raise IndexError(
                    'tensors used as indices must be bool or integer '
                    f'tensors (got {item.dtype})')
            dim += 1
        elif item is Ellipsis:
            if self.dim() - len(index) < dim:
                raise SyntaxError('too many indices after ellipsis (...)')
            dim = self.dim() - len(index)
        else:
            raise SyntaxError(f'unsupported index item: {item!r}')

    return out


def __repr__(self: SparseTensor) -> str:
    """Render the COO indices, values (if any), size, nnz and density.

    The output has a multi-line, constructor-like form.
    """
    field_pad = ' ' * 6

    def field(label: str, tensor: torch.Tensor) -> str:
        # Indent continuation lines of the tensor repr so they line up after
        # `label=`; the first line's padding is stripped again.
        return f'{label}={indent(tensor.__repr__(), field_pad)[len(field_pad):]}'

    row, col, value = self.coo()
    parts = [field('row', row), field('col', col)]
    if value is not None:
        parts.append(field('val', value))
    parts.append(f'size={tuple(self.sizes())}, nnz={self.nnz()}, '
                 f'density={100 * self.density():.02f}%')

    body = ',\n'.join(parts)
    cls_name = self.__class__.__name__
    outer_pad = ' ' * (len(cls_name) + 1)
    return f'{cls_name}({indent(body, outer_pad)[len(outer_pad):]})'


# Attach the Python-only helpers above as methods of the scripted class.
# Each one is registered under its own function name.
for _binding in (share_memory_, is_shared, to, __getitem__, __repr__):
    setattr(SparseTensor, _binding.__name__, _binding)
del _binding

# Scipy Conversions ###########################################################

# SciPy sparse matrix formats accepted by ``from_scipy`` and produced by
# ``to_scipy``.
ScipySparseMatrix = Union[
    scipy.sparse.coo_matrix,
    scipy.sparse.csr_matrix,
    scipy.sparse.csc_matrix,
]


@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
    """Convert a SciPy COO/CSR/CSC matrix into a ``SparseTensor``.

    Index arrays are converted to ``torch.long``. Values (if ``has_value``)
    are wrapped with ``torch.from_numpy``, so they may share memory with
    SciPy's arrays. For CSC input, its ``indptr`` is reused as the cached
    ``colptr``.

    The result is flagged ``is_sorted=True``. If the CSR form does not
    already have sorted column indices, a sorted copy is made with
    ``sorted_indices()``; the caller's matrix is left untouched.
    """
    colptr = None
    if isinstance(mat, scipy.sparse.csc_matrix):
        # Reordering entries never changes per-column counts, so this pointer
        # stays valid after the CSR conversion and sort below.
        colptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocsr()
    if not mat.has_sorted_indices:
        # `tocsr()` passes a CSR input through as-is, possibly with unsorted
        # column indices, which would contradict `is_sorted=True`.
        mat = mat.sorted_indices()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long)
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long)
    col = torch.from_numpy(mat.col).to(torch.long)

    value = None
    if has_value:
        value = torch.from_numpy(mat.data)

    sparse_sizes = (int(mat.shape[0]), int(mat.shape[1]))

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=None,
                            colptr=colptr, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)

    return SparseTensor.from_storage(storage)


@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
             dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
    """Convert to a SciPy sparse matrix.

    ``layout`` is normalized by ``get_layout``. ``'coo'``, ``'csr'`` and
    ``'csc'`` produce the matching SciPy matrix class. Tensors are detached
    and copied to the CPU first. Without stored values, every entry becomes
    one of ``dtype``; ``dtype`` is ignored otherwise.
    """
    assert self.dim() == 2
    layout = get_layout(layout)

    def as_numpy(t):
        return t.detach().cpu().numpy()

    stored = self.has_value()
    fill = None
    if not stored:
        fill = torch.ones(self.nnz(), dtype=dtype).numpy()

    if layout == 'coo':
        row, col, value = self.coo()
        row, col = as_numpy(row), as_numpy(col)
        data = as_numpy(value) if stored else fill
        return scipy.sparse.coo_matrix((data, (row, col)), self.sizes())
    if layout == 'csr':
        rowptr, col, value = self.csr()
        rowptr, col = as_numpy(rowptr), as_numpy(col)
        data = as_numpy(value) if stored else fill
        return scipy.sparse.csr_matrix((data, col, rowptr), self.sizes())
    if layout == 'csc':
        colptr, row, value = self.csc()
        colptr, row = as_numpy(colptr), as_numpy(row)
        data = as_numpy(value) if stored else fill
        return scipy.sparse.csc_matrix((data, row, colptr), self.sizes())


# Expose the SciPy converters on the class. `from_scipy` is an alternate
# constructor with no `self`, so it is wrapped in `staticmethod`. That way
# `st.from_scipy(mat)` works through an instance too, instead of binding the
# instance to `mat`. `to_scipy` takes `self` and is bound as a regular method.
SparseTensor.from_scipy = staticmethod(from_scipy)
SparseTensor.to_scipy = to_scipy