tensor.py 20.4 KB
Newer Older
rusty1s's avatar
repr  
rusty1s committed
1
from textwrap import indent
rusty1s's avatar
typing  
rusty1s committed
2
from typing import Optional, List, Tuple, Dict, Union, Any
rusty1s's avatar
rusty1s committed
3
4
5

import torch
import scipy.sparse
rusty1s's avatar
rusty1s committed
6
from torch_scatter import segment_csr
rusty1s's avatar
rusty1s committed
7

rusty1s's avatar
rusty1s committed
8
from torch_sparse.storage import SparseStorage, get_layout
rusty1s's avatar
rusty1s committed
9
10


rusty1s's avatar
rusty1s committed
11
@torch.jit.script
class SparseTensor(object):
    """Sparse tensor whose data lives in a :class:`SparseStorage`.

    The first two dimensions are sparse (see ``sparse_sizes()``); any
    further dimensions are the trailing, dense dimensions of ``value``.
    """

    storage: SparseStorage

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[Tuple[int, int]] = None,
                 is_sorted: bool = False):
        """Create a sparse tensor from COO row indices and/or a CSR row
        pointer, column indices ``col`` and optional ``value``.

        All cached index tensors (rowcount, colptr, colcount, csr2csc,
        csc2csr) are passed to :class:`SparseStorage` as ``None``.
        """
        self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
                                     value=value, sparse_sizes=sparse_sizes,
                                     rowcount=None, colptr=None, colcount=None,
                                     csr2csc=None, csc2csr=None,
                                     is_sorted=is_sorted)

    @classmethod
    def from_storage(cls, storage: SparseStorage):
        """Wrap ``storage`` in a new :class:`SparseTensor`.

        The storage's index/value tensors and caches are handed over as-is
        (no explicit clone here); the storage is assumed to be sorted.
        """
        out = SparseTensor(row=storage._row, rowptr=storage._rowptr,
                           col=storage._col, value=storage._value,
                           sparse_sizes=storage._sparse_sizes, is_sorted=True)
        # Carry over whatever caches `storage` has already computed.
        out.storage._rowcount = storage._rowcount
        out.storage._colptr = storage._colptr
        out.storage._colcount = storage._colcount
        out.storage._csr2csc = storage._csr2csc
        out.storage._csc2csr = storage._csc2csr
        return out

    @classmethod
    def from_edge_index(cls, edge_index: torch.Tensor,
                        edge_attr: Optional[torch.Tensor] = None,
                        sparse_sizes: Optional[Tuple[int, int]] = None,
                        is_sorted: bool = False):
        """Build from an edge index whose first two rows are used as row and
        column indices, with optional ``edge_attr`` as values."""
        return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
                            value=edge_attr, sparse_sizes=sparse_sizes,
                            is_sorted=is_sorted)

    @classmethod
    def from_dense(cls, mat: torch.Tensor, has_value: bool = True):
        """Build from a dense tensor whose first two dimensions are sparse.

        For ``mat.dim() > 2`` a position counts as non-zero if any entry of
        its trailing dense slice is non-zero. ``nonzero()`` returns indices
        in row-major (C-style) order, so the result is marked as sorted.
        """
        if mat.dim() > 2:
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()
        index = index.t()

        row = index[0]
        col = index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat[row, col]

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def from_torch_sparse_coo_tensor(cls, mat: torch.Tensor,
                                     has_value: bool = True):
        """Build from a ``torch.sparse_coo_tensor``.

        The input is coalesced first, which sorts its indices and merges
        duplicates, so the result is marked as sorted.
        """
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat.values()

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def eye(cls, M: int, N: Optional[int] = None, has_value: bool = True,
            dtype: Optional[int] = None, device: Optional[torch.device] = None,
            fill_cache: bool = False):
        """Return an ``M x N`` identity matrix (``N`` defaults to ``M``).

        Args:
            M: Number of rows.
            N: Number of columns; ``None`` means ``M``.
            has_value: Store explicit ones as values.
            dtype: dtype of those values.
            device: Device of the created tensors.
            fill_cache: Also precompute rowcount, colptr, colcount, csr2csc
                and csc2csr, which are trivial for a diagonal matrix.
        """
        N = M if N is None else N

        row = torch.arange(min(M, N), device=device)
        col = row

        # Rows past the diagonal (only when M > N) are empty, so their
        # pointers all stay at the total entry count N.
        rowptr = torch.arange(M + 1, device=row.device)
        if M > N:
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            value = torch.ones(row.numel(), dtype=dtype, device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0

            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                colptr[M + 1:] = M
                colcount[M:] = 0
            # The k-th diagonal entry is at position k in both CSR and CSC
            # order, so both permutations are the identity.
            csr2csc = csc2csr = row

        out = SparseTensor(row=row, rowptr=rowptr, col=col, value=value,
                           sparse_sizes=(M, N), is_sorted=True)
        out.storage._rowcount = rowcount
        out.storage._colptr = colptr
        out.storage._colcount = colcount
        out.storage._csr2csc = csr2csc
        out.storage._csc2csr = csc2csr
        return out

    def copy(self):
        """Return a new SparseTensor wrapping this tensor's storage."""
        return self.from_storage(self.storage)

    def clone(self):
        """Return a new SparseTensor built from ``storage.clone()``."""
        return self.from_storage(self.storage.clone())

    def type_as(self, tensor: torch.Tensor):
        """Cast values to ``tensor.dtype``.

        Returns ``self`` when there are no values or the dtype already
        matches.
        """
        value = self.storage.value()
        if value is None or tensor.dtype == value.dtype:
            return self
        return self.from_storage(self.storage.type_as(tensor))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Move to ``tensor.device``; returns ``self`` if already there."""
        if tensor.device == self.device():
            return self
        return self.from_storage(self.storage.device_as(tensor, non_blocking))

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(row, col, value)``."""
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(rowptr, col, value)``."""
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(colptr, row, value)``, with row indices and values
        permuted into column-major order via ``csr2csc``."""
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        """Return whether explicit values are stored."""
        return self.storage.has_value()

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """Replace the values in place (``layout`` gives their order)."""
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Return a new SparseTensor with the given values."""
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> Tuple[int, int]:
        """Return the sizes of the two sparse dimensions."""
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        """Return the size of sparse dimension ``dim``."""
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        """Return a copy with new sparse sizes (delegates to storage)."""
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def sparse_reshape(self, num_rows: int, num_cols: int):
        """Return a copy reshaped to ``num_rows x num_cols`` (delegates to
        storage)."""
        return self.from_storage(
            self.storage.sparse_reshape(num_rows, num_cols))

    def is_coalesced(self) -> bool:
        """Return whether the storage reports itself as coalesced."""
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "sum"):
        """Return a coalesced copy; duplicates are combined with ``reduce``."""
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        """Fill the storage's caches in place and return ``self``."""
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        """Clear the storage's caches in place and return ``self``."""
        self.storage.clear_cache_()
        return self

    def __eq__(self, other) -> bool:
        """Return whether ``other`` has the same sizes, CSR structure and
        values (values compared exactly with ``torch.equal``)."""
        if not isinstance(other, self.__class__):
            return False

        if self.sizes() != other.sizes():
            return False

        rowptrA, colA, valueA = self.csr()
        rowptrB, colB, valueB = other.csr()

        if valueA is None and valueB is not None:
            return False
        if valueA is not None and valueB is None:
            return False
        if not torch.equal(rowptrA, rowptrB):
            return False
        if not torch.equal(colA, colB):
            return False
        if valueA is None and valueB is None:
            return True
        return torch.equal(valueA, valueB)

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
        """Set every stored value to ``fill_value`` in place."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float, dtype: Optional[int] = None):
        """Return a copy whose stored values all equal ``fill_value``."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        """Return the full shape: sparse sizes followed by dense value dims."""
        sparse_sizes = self.sparse_sizes()
        value = self.storage.value()
        if value is not None:
            return list(sparse_sizes) + list(value.size())[1:]
        else:
            return list(sparse_sizes)

    def size(self, dim: int) -> int:
        """Return the size of dimension ``dim``."""
        return self.sizes()[dim]

    def dim(self) -> int:
        """Return the number of dimensions (sparse + dense)."""
        return len(self.sizes())

    def nnz(self) -> int:
        """Return the number of stored entries."""
        return self.storage.col().numel()

    def numel(self) -> int:
        """Return the number of stored scalars (``nnz`` if no values)."""
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        """Return ``nnz / (rows * cols)`` over the sparse dimensions."""
        return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))

    def sparsity(self) -> float:
        """Return ``1 - density()``."""
        return 1 - self.density()

    def avg_row_length(self) -> float:
        """Return the mean number of stored entries per row."""
        return self.nnz() / self.sparse_size(0)

    def avg_col_length(self) -> float:
        """Return the mean number of stored entries per column."""
        return self.nnz() / self.sparse_size(1)

    def bandwidth(self) -> int:
        """Return the maximum ``|row - col|`` over all stored entries.

        Returns 0 when no entries are stored (``max()`` is undefined on an
        empty tensor).
        """
        if self.nnz() == 0:
            return 0
        row, col, _ = self.coo()
        return int((row - col).abs_().max())

    def avg_bandwidth(self) -> float:
        """Return the mean ``|row - col|`` over all stored entries."""
        row, col, _ = self.coo()
        return float((row - col).abs_().to(torch.float).mean())

    def bandwidth_proportion(self, bandwidth: int) -> float:
        """Return the fraction of stored entries with
        ``|row - col| <= bandwidth``."""
        row, col, _ = self.coo()
        tmp = (row - col).abs_()
        return int((tmp <= bandwidth).sum()) / self.nnz()

    def is_quadratic(self) -> bool:
        """Return whether both sparse dimensions have the same size."""
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        """Return whether the matrix equals its transpose.

        The CSR form is compared against the CSC form; without stored values
        only the sparsity pattern is compared.
        """
        if not self.is_quadratic():
            return False

        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def to_symmetric(self, reduce: str = "sum"):
        """Return the symmetrized square matrix of size
        ``max(rows, cols)``.

        Every entry ``(r, c)`` is mirrored to ``(c, r)``. Entries landing on
        the same position (mirrored pairs, diagonal entries, duplicates) are
        combined with ``segment_csr(..., reduce=reduce)``; with
        ``reduce="sum"`` a diagonal entry is therefore doubled.
        """
        N = max(self.size(0), self.size(1))

        row, col, value = self.coo()
        # Linear indices of all entries and their mirrors: slot 0 holds a -1
        # sentinel, slots [1, nnz] encode (row, col) as row * N + col and the
        # remaining nnz slots encode (col, row) as col * N + row.
        idx = col.new_full((2 * col.numel() + 1, ), -1)
        idx[1:row.numel() + 1] = row
        idx[row.numel() + 1:] = col
        idx[1:] *= N
        idx[1:row.numel() + 1] += col
        idx[row.numel() + 1:] += row

        idx, perm = idx.sort()
        # True at the first occurrence of each distinct linear index; the
        # -1 sentinel guarantees the first real entry is kept.
        mask = idx[1:] > idx[:-1]
        # Drop the sentinel and shift so `perm` indexes cat([row, col]).
        perm = perm[1:].sub_(1)
        idx = perm[mask]

        if value is not None:
            # CSR-style pointer marking where each group of equal linear
            # indices starts, terminated by the total length.
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))])
            value = torch.cat([value, value])[perm]
            value = segment_csr(value, ptr, reduce=reduce)

        # `out=perm` reuses the permutation buffer, which is no longer needed;
        # the trailing `[idx]` indexing produces an independent tensor.
        new_row = torch.cat([row, col], dim=0, out=perm)[idx]
        new_col = torch.cat([col, row], dim=0, out=perm)[idx]

        out = SparseTensor(row=new_row, rowptr=None, col=new_col, value=value,
                           sparse_sizes=(N, N), is_sorted=True)
        return out

    def detach_(self):
        """Detach the stored values from autograd in place."""
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        """Return a copy whose values are detached from autograd."""
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        """Return whether the values require grad (``False`` if none)."""
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self, requires_grad: bool = True,
                       dtype: Optional[int] = None):
        """Set ``requires_grad`` on the values in place.

        When enabling grad on a tensor without values, the values are first
        filled with ones of ``dtype``.
        """
        if requires_grad and not self.has_value():
            self.fill_value_(1., dtype)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        """Return a copy whose storage is pinned (delegates to storage)."""
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        """Return whether the storage reports pinned memory."""
        return self.storage.is_pinned()

    def device(self):
        """Return the device of the column index tensor."""
        return self.storage.col().device

    def cpu(self):
        """Return this tensor on the CPU (``self`` if already there)."""
        # Name the CPU explicitly instead of relying on the default device.
        return self.device_as(torch.tensor(0, device=torch.device('cpu')),
                              non_blocking=False)

    def cuda(self):
        """Return a copy on CUDA (delegates to ``storage.cuda()``)."""
        return self.from_storage(self.storage.cuda())

    def is_cuda(self) -> bool:
        """Return whether the column index tensor lives on CUDA."""
        return self.storage.col().is_cuda

    def dtype(self):
        """Return the value dtype, or ``torch.float`` when no values."""
        value = self.storage.value()
        return value.dtype if value is not None else torch.float

    def is_floating_point(self) -> bool:
        """Return whether values are floating point (``True`` if none)."""
        value = self.storage.value()
        return torch.is_floating_point(value) if value is not None else True

    def bfloat16(self):
        """Cast values to ``torch.bfloat16``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.bfloat16, device=self.device()))

    def bool(self):
        """Cast values to ``torch.bool``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.bool, device=self.device()))

    def byte(self):
        """Cast values to ``torch.uint8``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.uint8, device=self.device()))

    def char(self):
        """Cast values to ``torch.int8``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.int8, device=self.device()))

    def half(self):
        """Cast values to ``torch.half``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.half, device=self.device()))

    def float(self):
        """Cast values to ``torch.float``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.float, device=self.device()))

    def double(self):
        """Cast values to ``torch.double``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.double, device=self.device()))

    def short(self):
        """Cast values to ``torch.short``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.short, device=self.device()))

    def int(self):
        """Cast values to ``torch.int``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.int, device=self.device()))

    def long(self):
        """Cast values to ``torch.long``."""
        return self.type_as(
            torch.tensor(0, dtype=torch.long, device=self.device()))

    # Conversions #############################################################

    def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a dense tensor of shape ``sizes()``.

        ``dtype`` is only used when no values are stored (stored entries
        then become ones); otherwise the values' dtype is used. Entries are
        written by index assignment, for which PyTorch leaves the result for
        duplicate (row, col) pairs undefined, so uncoalesced input should be
        coalesced first.
        """
        row, col, value = self.coo()

        if value is not None:
            mat = torch.zeros(self.sizes(), dtype=value.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(
            self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return an equivalent ``torch.sparse_coo_tensor``.

        Without stored values, ones of ``dtype`` are used.
        """
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)

        if value is None:
            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())
rusty1s's avatar
rusty1s committed
453

rusty1s's avatar
rusty1s committed
454
455
456
457
458
459

# Python Bindings #############################################################


def share_memory_(self: SparseTensor) -> SparseTensor:
    """Call ``share_memory_()`` on the underlying storage; return ``self``."""
    storage = self.storage
    storage.share_memory_()
    return self
rusty1s's avatar
rusty1s committed
461
462
463
464
465
466


def is_shared(self: SparseTensor) -> bool:
    """Return what the underlying storage reports for ``is_shared()``."""
    shared = self.storage.is_shared()
    return shared


rusty1s's avatar
typing  
rusty1s committed
467
468
469
def to(self, *args: Optional[List[Any]],
       **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
    """Convert dtype and/or device, parsing arguments like ``Tensor.to``.

    The dtype cast is applied first, then the device move (which honours
    ``non_blocking``).
    """
    parsed = torch._C._nn._parse_to(*args, **kwargs)
    device, dtype, non_blocking = parsed[0], parsed[1], parsed[2]

    result = self
    if dtype is not None:
        dtype_anchor = torch.tensor(0., dtype=dtype)
        result = result.type_as(dtype_anchor)
    if device is not None:
        device_anchor = torch.tensor(0., device=device)
        result = result.device_as(device_anchor, non_blocking)
    return result


rusty1s's avatar
rusty1s committed
480
481
482
483
484
485
486
487
488
def cpu(self) -> SparseTensor:
    """Return this tensor on the CPU (``self`` itself if already there)."""
    cpu_anchor = torch.tensor(0., device='cpu')
    return self.device_as(cpu_anchor)


def cuda(self, device: Optional[Union[int, str]] = None,
         non_blocking: bool = False):
    """Move the tensor to a CUDA device.

    Args:
        device: Target device (index or string). ``None`` selects the
            current CUDA device.
        non_blocking: Forwarded to the device transfer.
    """
    # Compare against None explicitly: ``device or 'cuda'`` would turn the
    # valid index 0 into the *current* device rather than ``cuda:0``.
    target = 'cuda' if device is None else device
    return self.device_as(torch.tensor(0., device=target), non_blocking)


rusty1s's avatar
typing  
rusty1s committed
489
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
    """Index dimension by dimension.

    Supported items: ints (``select``), step-less slices (``narrow``),
    bool tensors (``masked_select``), long tensors (``index_select``),
    lists/tuples (bool lists act as masks, others as indices) and a single
    ``Ellipsis``.

    Raises:
        SyntaxError: On more than one ``Ellipsis``, an ``Ellipsis`` that
            cannot be placed, or an unsupported index item.
        ValueError: If a slice specifies a step.
        IndexError: If an index tensor is neither bool nor long.
    """
    index = list(index) if isinstance(index, tuple) else [index]
    # More than one `Ellipsis` is not allowed...
    if sum(1 for i in index if i is Ellipsis) > 1:
        raise SyntaxError('Only a single Ellipsis (...) is allowed.')

    dim = 0
    out = self
    while len(index) > 0:
        item = index.pop(0)
        if isinstance(item, (list, tuple)):
            # Let torch infer the dtype so Python bool lists become masks;
            # everything else is used as long indices.
            item = torch.tensor(item, device=self.device())
            if item.dtype != torch.bool:
                item = item.to(torch.long)

        if isinstance(item, int):
            out = out.select(dim, item)
            dim += 1
        elif isinstance(item, slice):
            if item.step is not None:
                raise ValueError('Step parameter not yet supported.')
            # `slice.indices` resolves negative bounds and clamps both to
            # [0, size], matching Python slicing (e.g. `mat[:100]`).
            start, stop, _ = item.indices(self.size(dim))
            out = out.narrow(dim, start, max(stop - start, 0))
            dim += 1
        elif torch.is_tensor(item):
            if item.dtype == torch.bool:
                out = out.masked_select(dim, item)
            elif item.dtype == torch.long:
                out = out.index_select(dim, item)
            else:
                # Previously such tensors were silently ignored.
                raise IndexError(
                    f'Index tensors must be bool or long, got {item.dtype}.')
            dim += 1
        elif item is Ellipsis:
            if self.dim() - len(index) < dim:
                raise SyntaxError('Ellipsis cannot be placed: too many '
                                  'indices for this tensor.')
            dim = self.dim() - len(index)
        else:
            raise SyntaxError(f'Unsupported index item: {item!r}.')

    return out


rusty1s's avatar
typing  
rusty1s committed
533
def __repr__(self: SparseTensor) -> str:
    """Render indices, optional values and a size/nnz/density summary,
    with continuation lines aligned under the opening parenthesis."""

    def field(label: str, tensor: torch.Tensor) -> str:
        # Indent continuation lines of the tensor repr, but not the first.
        pad = ' ' * 6
        return f'{label}={indent(repr(tensor), pad)[len(pad):]}'

    row, col, value = self.coo()
    fields = [field('row', row), field('col', col)]
    if value is not None:
        fields.append(field('val', value))
    fields.append(f'size={tuple(self.sizes())}, nnz={self.nnz()}, '
                  f'density={100 * self.density():.02f}%')

    body = ',\n'.join(fields)
    cls_name = self.__class__.__name__
    pad = ' ' * (len(cls_name) + 1)
    return f'{cls_name}({indent(body, pad)[len(pad):]})'
rusty1s's avatar
repr  
rusty1s committed
552
553


rusty1s's avatar
rusty1s committed
554
555
556
# Attach the plain-Python helpers defined above as SparseTensor methods,
# each under its own function name.
for _binding in (share_memory_, is_shared, to, cpu, cuda, __getitem__,
                 __repr__):
    setattr(SparseTensor, _binding.__name__, _binding)
del _binding
rusty1s's avatar
rusty1s committed
561
562
563

# Scipy Conversions ###########################################################

rusty1s's avatar
typo  
rusty1s committed
564
565
# SciPy sparse formats accepted by `from_scipy` and produced by `to_scipy`.
ScipySparseMatrix = Union[
    scipy.sparse.coo_matrix,
    scipy.sparse.csr_matrix,
    scipy.sparse.csc_matrix,
]
rusty1s's avatar
rusty1s committed
566
567
568


@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
    """Convert a SciPy COO/CSR/CSC matrix into a :class:`SparseTensor`.

    Args:
        mat: The SciPy sparse matrix to convert.
        has_value: If ``False``, values are dropped and only the sparsity
            pattern is kept.

    The result is flagged as sorted, so column indices within each row are
    sorted (on a copy) when the CSR form does not already guarantee it. A
    CSC input additionally seeds the ``colptr`` cache from its ``indptr``.
    """
    colptr = None
    if isinstance(mat, scipy.sparse.csc_matrix):
        colptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocsr()
    # `tocsr()` returns a CSR input as-is, and CSR matrices may carry
    # unsorted column indices; `sorted_indices()` sorts a copy so the
    # caller's matrix is left untouched.
    if not mat.has_sorted_indices:
        mat = mat.sorted_indices()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long)
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long)
    col = torch.from_numpy(mat.col).to(torch.long)

    value = None
    if has_value:
        value = torch.from_numpy(mat.data)

    sparse_sizes = mat.shape[:2]

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=None,
                            colptr=colptr, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)

    return SparseTensor.from_storage(storage)


@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
             dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
    """Convert a 2-D sparse tensor into a SciPy sparse matrix.

    Args:
        layout: Target layout (``'coo'``, ``'csr'`` or ``'csc'``), resolved
            through ``get_layout``.
        dtype: dtype of the all-ones data used when no values are stored;
            ignored otherwise.
    """
    assert self.dim() == 2
    layout = get_layout(layout)

    def as_numpy(t: torch.Tensor):
        return t.detach().cpu().numpy()

    def data_of(value: Optional[torch.Tensor]):
        # Tensors without explicit values are exported with unit entries.
        if self.has_value():
            return as_numpy(value)
        return torch.ones(self.nnz(), dtype=dtype).numpy()

    if layout == 'coo':
        row, col, value = self.coo()
        coords = (as_numpy(row), as_numpy(col))
        return scipy.sparse.coo_matrix((data_of(value), coords), self.sizes())
    if layout == 'csr':
        rowptr, col, value = self.csr()
        return scipy.sparse.csr_matrix(
            (data_of(value), as_numpy(col), as_numpy(rowptr)), self.sizes())
    if layout == 'csc':
        colptr, row, value = self.csc()
        return scipy.sparse.csc_matrix(
            (data_of(value), as_numpy(row), as_numpy(colptr)), self.sizes())


# Expose the SciPy converters on SparseTensor under their function names.
for _binding in (from_scipy, to_scipy):
    setattr(SparseTensor, _binding.__name__, _binding)
del _binding