tensor.py 19.5 KB
Newer Older
rusty1s's avatar
repr  
rusty1s committed
1
from textwrap import indent
rusty1s's avatar
typing  
rusty1s committed
2
from typing import Optional, List, Tuple, Dict, Union, Any
rusty1s's avatar
rusty1s committed
3
4
5

import torch
import scipy.sparse
rusty1s's avatar
rusty1s committed
6
from torch_scatter import segment_csr
rusty1s's avatar
rusty1s committed
7

rusty1s's avatar
rusty1s committed
8
from torch_sparse.storage import SparseStorage, get_layout
rusty1s's avatar
rusty1s committed
9
10


rusty1s's avatar
rusty1s committed
11
@torch.jit.script
class SparseTensor(object):
    # The only state of a SparseTensor: index/value tensors and all cached
    # conversions are held by this SparseStorage.
    storage: SparseStorage

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[Tuple[int, int]] = None,
                 is_sorted: bool = False):
        """Wrap the given COO (``row``) and/or CSR (``rowptr``) indices.

        A fresh SparseStorage is created with every derived cache
        (rowcount, colptr, colcount, csr2csc, csc2csr) left empty.
        """
        self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
                                     value=value, sparse_sizes=sparse_sizes,
                                     rowcount=None, colptr=None, colcount=None,
                                     csr2csc=None, csc2csr=None,
                                     is_sorted=is_sorted)

    @classmethod
    def from_storage(self, storage: SparseStorage):
        """Build a SparseTensor that shares every tensor of ``storage``.

        Indices, values and all cached conversions are reused as-is (no
        copy). The first parameter receives the class (``@classmethod``)
        even though it is named ``self``.
        """
        out = SparseTensor(row=storage._row, rowptr=storage._rowptr,
                           col=storage._col, value=storage._value,
                           sparse_sizes=storage._sparse_sizes, is_sorted=True)
        out.storage._rowcount = storage._rowcount
        out.storage._colptr = storage._colptr
        out.storage._colcount = storage._colcount
        out.storage._csr2csc = storage._csr2csc
        out.storage._csc2csr = storage._csc2csr
        return out

    @classmethod
    def from_edge_index(self, edge_index: torch.Tensor,
                        edge_attr: Optional[torch.Tensor] = None,
                        sparse_sizes: Optional[Tuple[int, int]] = None,
                        is_sorted: bool = False):
        """Build from a ``[2, nnz]`` index tensor (rows in ``edge_index[0]``,
        columns in ``edge_index[1]``) with optional per-entry ``edge_attr``."""
        return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
                            value=edge_attr, sparse_sizes=sparse_sizes,
                            is_sorted=is_sorted)

    @classmethod
    def from_dense(self, mat: torch.Tensor, has_value: bool = True):
        """Build from a dense tensor, keeping its non-zero positions.

        For ``mat.dim() > 2`` a position (i, j) is kept when any entry of
        ``mat[i, j]`` is non-zero; trailing dims become the value shape.
        ``nonzero()`` yields row-major order, hence ``is_sorted=True``.
        """
        if mat.dim() > 2:
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()
        index = index.t()

        row = index[0]
        col = index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat[row, col]

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
                                     has_value: bool = True):
        """Build from a ``torch.sparse_coo_tensor`` (coalesced first)."""
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat.values()

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
            dtype: Optional[int] = None, device: Optional[torch.device] = None,
            fill_cache: bool = False):
        """Return an ``M x N`` identity matrix (``N`` defaults to ``M``).

        ``dtype`` is annotated as ``int``, presumably because TorchScript
        encodes dtypes as integers. With ``fill_cache`` the row/column
        counts, ``colptr`` and both permutations are precomputed.
        """
        N = M if N is None else N

        row = torch.arange(min(M, N), device=device)
        col = row

        # Rows beyond the diagonal (M > N) are empty: their pointer stays N.
        rowptr = torch.arange(M + 1, device=row.device)
        if M > N:
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            value = torch.ones(row.numel(), dtype=dtype, device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0

            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                colptr[M + 1:] = M
                colcount[M:] = 0
            # For a diagonal the CSR and CSC orders coincide.
            csr2csc = csc2csr = row

        out = SparseTensor(row=row, rowptr=rowptr, col=col, value=value,
                           sparse_sizes=(M, N), is_sorted=True)
        out.storage._rowcount = rowcount
        out.storage._colptr = colptr
        out.storage._colcount = colcount
        out.storage._csr2csc = csr2csc
        out.storage._csc2csr = csc2csr
        return out

    def copy(self):
        """Shallow copy: a new SparseTensor sharing this storage's tensors."""
        return self.from_storage(self.storage)

    def clone(self):
        """Deep copy via ``SparseStorage.clone``."""
        return self.from_storage(self.storage.clone())

    def type_as(self, tensor: torch.Tensor):
        """Cast values to ``tensor.dtype``; returns ``self`` when there are
        no values or the dtype already matches."""
        value = self.storage.value()
        if value is None or tensor.dtype == value.dtype:
            return self
        return self.from_storage(self.storage.type_as(tensor))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Move to ``tensor.device``; returns ``self`` if already there."""
        if tensor.device == self.device():
            return self
        return self.from_storage(self.storage.device_as(tensor, non_blocking))

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(row, col, value)``."""
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(rowptr, col, value)``."""
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(colptr, row, value)`` with row/value permuted into
        column-major order via the ``csr2csc`` permutation."""
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        return self.storage.has_value()

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """Replace the values in place; returns ``self``."""
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Return a new SparseTensor with replaced values."""
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> Tuple[int, int]:
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def sparse_reshape(self, num_rows: int, num_cols: int):
        return self.from_storage(
            self.storage.sparse_reshape(num_rows, num_cols))

    def is_coalesced(self) -> bool:
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "sum"):
        """Merge duplicate entries using ``reduce``."""
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        """Populate all cached conversions in place; returns ``self``."""
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        """Drop all cached conversions in place; returns ``self``."""
        self.storage.clear_cache_()
        return self

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
        """Set every value to ``fill_value`` in place; returns ``self``."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float, dtype: Optional[int] = None):
        """Return a copy whose values all equal ``fill_value``."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        """Sparse sizes followed by the trailing (dense) value dims."""
        sparse_sizes = self.sparse_sizes()
        value = self.storage.value()
        if value is not None:
            return list(sparse_sizes) + list(value.size())[1:]
        else:
            return list(sparse_sizes)

    def size(self, dim: int) -> int:
        return self.sizes()[dim]

    def dim(self) -> int:
        return len(self.sizes())

    def nnz(self) -> int:
        """Number of stored (row, col) entries."""
        return self.storage.col().numel()

    def numel(self) -> int:
        """Number of stored scalars (``nnz`` when there are no values)."""
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        """``nnz / (rows * cols)``; divides by zero for an empty shape."""
        return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))

    def sparsity(self) -> float:
        return 1 - self.density()

    def avg_row_length(self) -> float:
        return self.nnz() / self.sparse_size(0)

    def avg_col_length(self) -> float:
        return self.nnz() / self.sparse_size(1)

    def bandwidth(self) -> int:
        """Maximum ``|row - col|`` over stored entries.

        A matrix without stored entries has bandwidth 0 (reducing the empty
        difference tensor with ``max()`` would raise instead).
        """
        row, col, _ = self.coo()
        if row.numel() == 0:
            return 0
        return int((row - col).abs_().max())

    def avg_bandwidth(self) -> float:
        """Mean ``|row - col|`` (NaN when there are no stored entries)."""
        row, col, _ = self.coo()
        return float((row - col).abs_().to(torch.float).mean())

    def bandwidth_proportion(self, bandwidth: int) -> float:
        """Fraction of entries with ``|row - col| <= bandwidth``.

        Divides by ``nnz()``, so an empty matrix is not supported.
        """
        row, col, _ = self.coo()
        tmp = (row - col).abs_()
        return int((tmp <= bandwidth).sum()) / self.nnz()

    def is_quadratic(self) -> bool:
        """True when the sparse shape is square."""
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        """True when the CSR and CSC representations coincide (and, if
        values exist on both sides, the permuted values are equal)."""
        if not self.is_quadratic():
            return False

        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def to_symmetric(self, reduce: str = "sum"):
        """Return the union of the matrix and its transpose.

        Entries landing on the same position are combined with
        ``segment_csr(..., reduce=reduce)`` (``"sum"`` yields ``A + A^T``).
        The result is square with side ``max(size(0), size(1))``.
        """
        N = max(self.size(0), self.size(1))

        row, col, value = self.coo()
        # Encode each entry and its mirror as one flat key ``r * N + c``.
        # Slot 0 is a -1 sentinel so the first real key after sorting always
        # compares greater than its predecessor.
        idx = col.new_full((2 * col.numel() + 1, ), -1)
        idx[1:row.numel() + 1] = row
        idx[row.numel() + 1:] = col
        idx[1:] *= N
        idx[1:row.numel() + 1] += col
        idx[row.numel() + 1:] += row

        idx, perm = idx.sort()
        # mask[k] is True when sorted position k+1 starts a new key group.
        mask = idx[1:] > idx[:-1]
        # Drop the sentinel and shift so perm indexes cat([row, col]).
        perm = perm[1:].sub_(1)
        # One representative (the first) per unique key.
        idx = perm[mask]

        if value is not None:
            # Group boundaries: start of every unique key, plus the end.
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))])
            value = torch.cat([value, value])[perm]
            value = segment_csr(value, ptr, reduce=reduce)

        # `out=perm` reuses perm's buffer, which is no longer needed here;
        # the `[idx]` gather copies before the next cat overwrites it.
        new_row = torch.cat([row, col], dim=0, out=perm)[idx]
        new_col = torch.cat([col, row], dim=0, out=perm)[idx]

        out = SparseTensor(row=new_row, rowptr=None, col=new_col, value=value,
                           sparse_sizes=(N, N), is_sorted=True)
        return out

    def detach_(self):
        """Detach the values from autograd in place; returns ``self``."""
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        """Return a copy whose values are detached from autograd."""
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        """Whether the values require grad (False without values)."""
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self, requires_grad: bool = True,
                       dtype: Optional[int] = None):
        """Toggle ``requires_grad`` on the values in place.

        A tensor without values first gets all-ones values (of ``dtype``)
        when enabling gradients. Returns ``self``.
        """
        if requires_grad and not self.has_value():
            self.fill_value_(1., dtype)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        return self.storage.is_pinned()

    def device(self):
        return self.storage.col().device

    def cpu(self):
        return self.device_as(torch.tensor(0), non_blocking=False)

    def cuda(self):
        return self.from_storage(self.storage.cuda())

    def is_cuda(self) -> bool:
        return self.storage.col().is_cuda

    def dtype(self):
        """Value dtype, or ``torch.float`` when there are no values."""
        value = self.storage.value()
        return value.dtype if value is not None else torch.float

    def is_floating_point(self) -> bool:
        """True for floating values, and also when there are no values."""
        value = self.storage.value()
        return torch.is_floating_point(value) if value is not None else True

    def bfloat16(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.bfloat16, device=self.device()))

    def bool(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.bool, device=self.device()))

    def byte(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.uint8, device=self.device()))

    def char(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.int8, device=self.device()))

    def half(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.half, device=self.device()))

    def float(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.float, device=self.device()))

    def double(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.double, device=self.device()))

    def short(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.short, device=self.device()))

    def int(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.int, device=self.device()))

    def long(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.long, device=self.device()))

    # Conversions #############################################################

    def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a dense tensor of shape ``sizes()``.

        ``dtype`` is only used when there are no values (missing values are
        written as ones); otherwise the value dtype is kept.
        NOTE(review): the indexed assignment does not accumulate, so
        duplicate (row, col) pairs would not be summed — confirm callers
        only pass coalesced tensors.
        """
        row, col, value = self.coo()

        if value is not None:
            mat = torch.zeros(self.sizes(), dtype=value.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(
            self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a ``torch.sparse_coo_tensor``; missing values become ones
        of ``dtype``."""
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)

        if value is None:
            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())
rusty1s's avatar
rusty1s committed
431

rusty1s's avatar
rusty1s committed
432
433
434
435
436
437
438
439
440
441
442
443

# Python Bindings #############################################################


def share_memory_(self: SparseTensor) -> SparseTensor:
    """Move the underlying storage into shared memory, in place.

    Returns ``self`` (as the annotation promises) so the call can be chained,
    like the class's other in-place methods such as ``fill_cache_``.
    """
    self.storage.share_memory_()
    return self


def is_shared(self: SparseTensor) -> bool:
    """Report whether the underlying storage lives in shared memory."""
    storage = self.storage
    shared = storage.is_shared()
    return shared


rusty1s's avatar
typing  
rusty1s committed
444
445
446
def to(self, *args: Optional[List[Any]],
       **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
    """Mirror ``torch.Tensor.to``: change value dtype and/or device.

    Arguments are parsed with PyTorch's own ``_parse_to``. The dtype cast is
    applied before the device move; either step is skipped when not given.
    """
    parsed = torch._C._nn._parse_to(*args, **kwargs)
    device, dtype, non_blocking = parsed[0], parsed[1], parsed[2]

    result = self
    if dtype is not None:
        dtype_probe = torch.tensor(0., dtype=dtype)
        result = result.type_as(dtype_probe)
    if device is not None:
        device_probe = torch.tensor(0., device=device)
        result = result.device_as(device_probe, non_blocking)
    return result


rusty1s's avatar
typing  
rusty1s committed
457
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
    """Index a SparseTensor with dense-tensor-like syntax.

    Items are consumed left to right, one dimension each:

    * ``int`` -- ``select`` along the current dimension;
    * ``slice`` (no step, or ``step=1``) -- ``narrow``; bounds are normalised
      with ``slice.indices`` so negative and out-of-range bounds behave like
      Python slicing (clamped to the dimension size);
    * ``list``/``tuple`` -- converted to a ``long`` tensor;
    * ``bool`` tensor -- ``masked_select``; ``long`` tensor -- ``index_select``;
    * ``...`` -- skip ahead so the remaining items address trailing dims.

    Raises:
        SyntaxError: more than one ``...``, a ``...`` that would move
            backwards, or an unsupported item type.
        ValueError: a slice step other than 1.
        IndexError: a tensor index whose dtype is neither bool nor long.
    """
    index = list(index) if isinstance(index, tuple) else [index]
    # More than one `Ellipsis` is not allowed. Identity check avoids calling
    # `==` on arbitrary objects (tensors/arrays compare elementwise).
    if sum(1 for i in index if i is Ellipsis) > 1:
        raise SyntaxError

    dim = 0
    out = self
    while len(index) > 0:
        item = index.pop(0)

        if isinstance(item, (list, tuple)):
            item = torch.tensor(item, dtype=torch.long, device=self.device())

        if isinstance(item, int):
            out = out.select(dim, item)
            dim += 1
        elif isinstance(item, slice):
            if item.step is not None and item.step != 1:
                raise ValueError('Step parameter not yet supported.')
            # Resolves negative bounds and clamps to [0, size].
            start, stop, _ = item.indices(out.size(dim))
            out = out.narrow(dim, start, max(stop - start, 0))
            dim += 1
        elif torch.is_tensor(item):
            if item.dtype == torch.bool:
                out = out.masked_select(dim, item)
            elif item.dtype == torch.long:
                out = out.index_select(dim, item)
            else:
                # Previously such tensors were silently ignored.
                raise IndexError(
                    f'Index tensors must be of dtype torch.bool or '
                    f'torch.long (got {item.dtype}).')
            dim += 1
        elif item is Ellipsis:
            if self.dim() - len(index) < dim:
                raise SyntaxError
            dim = self.dim() - len(index)
        else:
            raise SyntaxError

    return out


rusty1s's avatar
typing  
rusty1s committed
501
def __repr__(self: SparseTensor) -> str:
    """Render row/col/value tensors plus size, nnz and density.

    Multi-line tensor reprs are indented so continuation lines align under
    the first line, and the whole body is aligned under the class name.
    """
    field_pad = ' ' * 6
    row, col, value = self.coo()

    def _field(label: str, tensor: torch.Tensor) -> str:
        # Indent every line, then strip the pad from the first one.
        return f'{label}={indent(repr(tensor), field_pad)[len(field_pad):]}'

    parts = [_field('row', row), _field('col', col)]
    if value is not None:
        parts.append(_field('val', value))
    parts.append(f'size={tuple(self.sizes())}, nnz={self.nnz()}, '
                 f'density={100 * self.density():.02f}%')

    body = ',\n'.join(parts)
    cls_name = self.__class__.__name__
    body_pad = ' ' * (len(cls_name) + 1)
    return f'{cls_name}({indent(body, body_pad)[len(body_pad):]})'
rusty1s's avatar
repr  
rusty1s committed
520
521


rusty1s's avatar
rusty1s committed
522
523
524
# Attach the module-level helpers above as SparseTensor methods. They live
# outside the class body, presumably so `torch.jit.script` does not try to
# compile them — TODO confirm.
SparseTensor.share_memory_, SparseTensor.is_shared = share_memory_, is_shared
SparseTensor.to = to
SparseTensor.__getitem__, SparseTensor.__repr__ = __getitem__, __repr__
rusty1s's avatar
rusty1s committed
527
528
529

# Scipy Conversions ###########################################################

rusty1s's avatar
typo  
rusty1s committed
530
531
# SciPy sparse layouts accepted by `from_scipy` and returned by `to_scipy`.
ScipySparseMatrix = Union[
    scipy.sparse.coo_matrix,
    scipy.sparse.csr_matrix,
    scipy.sparse.csc_matrix,
]
rusty1s's avatar
rusty1s committed
532
533
534


@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
    """Build a SparseTensor from a SciPy sparse matrix.

    The input is routed through CSR (for ``rowptr``) and then COO (for
    ``row``/``col``/values), so entries come out row-sorted. A CSC input's
    ``indptr`` is additionally kept as the cached ``colptr``.
    """
    def _as_long(array):
        return torch.from_numpy(array).to(torch.long)

    if isinstance(mat, scipy.sparse.csc_matrix):
        colptr = _as_long(mat.indptr)
    else:
        colptr = None

    csr = mat.tocsr()
    rowptr = _as_long(csr.indptr)
    coo = csr.tocoo()
    row = _as_long(coo.row)
    col = _as_long(coo.col)
    value = torch.from_numpy(coo.data) if has_value else None

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=coo.shape[:2], rowcount=None,
                            colptr=colptr, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return SparseTensor.from_storage(storage)


@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
             dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
    """Convert a 2-D SparseTensor into a SciPy COO/CSR/CSC matrix.

    ``layout`` is normalised by ``get_layout``. Tensors are detached and
    copied to CPU first. Without stored values, all-ones data of ``dtype``
    is used.
    """
    assert self.dim() == 2
    layout = get_layout(layout)

    def _np(t):
        return t.detach().cpu().numpy()

    fallback_ones = None
    if not self.has_value():
        fallback_ones = torch.ones(self.nnz(), dtype=dtype).numpy()

    def _data(value):
        return _np(value) if self.has_value() else fallback_ones

    if layout == 'coo':
        row, col, value = self.coo()
        row_np, col_np = _np(row), _np(col)
        return scipy.sparse.coo_matrix((_data(value), (row_np, col_np)),
                                       self.sizes())
    if layout == 'csr':
        rowptr, col, value = self.csr()
        rowptr_np, col_np = _np(rowptr), _np(col)
        return scipy.sparse.csr_matrix((_data(value), col_np, rowptr_np),
                                       self.sizes())
    if layout == 'csc':
        colptr, row, value = self.csc()
        colptr_np, row_np = _np(colptr), _np(row)
        return scipy.sparse.csc_matrix((_data(value), row_np, colptr_np),
                                       self.sizes())


# `from_scipy` is an alternate constructor with no `self` parameter, so it is
# attached as a staticmethod: `SparseTensor.from_scipy(mat)` behaves exactly
# as before, and calling it on an instance no longer binds that instance as
# `mat`. `to_scipy` takes `self` and is bound as a regular method.
SparseTensor.from_scipy = staticmethod(from_scipy)
SparseTensor.to_scipy = to_scipy