from textwrap import indent
from typing import Optional, List, Tuple, Dict, Union, Any

import torch
import scipy.sparse
from torch_scatter import segment_csr

from torch_sparse.storage import SparseStorage, get_layout


@torch.jit.script
class SparseTensor(object):
    # Thin wrapper around a `SparseStorage` object. The storage holds the
    # index/value data; most methods below delegate to it and wrap the
    # returned storage in a new `SparseTensor`.
    storage: SparseStorage

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[Tuple[int, int]] = None,
                 is_sorted: bool = False):
        """Create the underlying `SparseStorage` from the given index data.

        The auxiliary storage fields (`rowcount`, `colptr`, `colcount`,
        `csr2csc`, `csc2csr`) are always passed as `None` here.
        """
        self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
                                     value=value, sparse_sizes=sparse_sizes,
                                     rowcount=None, colptr=None, colcount=None,
                                     csr2csc=None, csc2csr=None,
                                     is_sorted=is_sorted)

    @classmethod
    def from_storage(cls, storage: SparseStorage):
        """Wrap an existing storage without going through `__init__`."""
        out = SparseTensor.__new__(SparseTensor)
        out.storage = storage
        return out

    @classmethod
    def from_edge_index(cls, edge_index: torch.Tensor,
                        edge_attr: Optional[torch.Tensor] = None,
                        sparse_sizes: Optional[Tuple[int, int]] = None,
                        is_sorted: bool = False):
        """Build a tensor using `edge_index[0]` as rows, `edge_index[1]` as
        columns and `edge_attr` as values."""
        return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
                            value=edge_attr, sparse_sizes=sparse_sizes,
                            is_sorted=is_sorted)

    @classmethod
    def from_dense(cls, mat: torch.Tensor, has_value: bool = True):
        """Build a tensor from the non-zero entries of the dense `mat`.

        For inputs with more than two dimensions, a (row, col) position is
        kept when the sum of absolute values over all trailing dimensions is
        non-zero. Values are only stored when `has_value` is true.
        """
        if mat.dim() > 2:
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()
        index = index.t()

        row = index[0]
        col = index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat[row, col]

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def from_torch_sparse_coo_tensor(cls, mat: torch.Tensor,
                                     has_value: bool = True):
        """Build a tensor from a `torch.sparse` COO tensor.

        The input is coalesced first; values are only stored when
        `has_value` is true.
        """
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat.values()

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def eye(cls, M: int, N: Optional[int] = None, has_value: bool = True,
            dtype: Optional[int] = None, device: Optional[torch.device] = None,
            fill_cache: bool = False):
        """Build an `M x N` tensor with entries on the main diagonal.

        `N` defaults to `M`. With `has_value`, the diagonal entries are ones
        of the given `dtype`; with `fill_cache`, the auxiliary storage fields
        (`rowcount`, `colptr`, `colcount`, `csr2csc`, `csc2csr`) are
        precomputed as well.
        """
        N = M if N is None else N

        row = torch.arange(min(M, N), device=device)
        col = row

        rowptr = torch.arange(M + 1, device=row.device)
        if M > N:
            # Rows past the diagonal are empty, so the pointer stays at N.
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            value = torch.ones(row.numel(), dtype=dtype, device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0

            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                # Columns past the diagonal are empty.
                colptr[M + 1:] = M
                colcount[M:] = 0
            # Both permutations are set to `row`, i.e. 0..min(M, N)-1.
            csr2csc = csc2csr = row

        storage: SparseStorage = SparseStorage(
            row=row, rowptr=rowptr, col=col, value=value, sparse_sizes=(M, N),
            rowcount=rowcount, colptr=colptr, colcount=colcount,
            csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True)

        out = SparseTensor.__new__(SparseTensor)
        out.storage = storage
        return out

    def copy(self):
        """Return a new `SparseTensor` that shares this tensor's storage."""
        return self.from_storage(self.storage)

    def clone(self):
        """Return a new `SparseTensor` wrapping a clone of the storage."""
        return self.from_storage(self.storage.clone())

    def type_as(self, tensor: torch.Tensor):
        """Return a tensor whose values have the dtype of `tensor`.

        Returns `self` unchanged when there are no values or the dtype
        already matches.
        """
        value = self.storage.value()
        if value is None or tensor.dtype == value.dtype:
            return self
        return self.from_storage(self.storage.type_as(tensor))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Return a tensor on the device of `tensor` (`self` if it already
        lives there)."""
        if tensor.device == self.device():
            return self
        return self.from_storage(self.storage.device_as(tensor, non_blocking))

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return `(row, col, value)` as held by the storage."""
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return `(rowptr, col, value)` as held by the storage."""
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return `(colptr, row, value)` with `row` and `value` reordered by
        the storage's `csr2csc` permutation."""
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        """Delegate to `storage.has_value()`."""
        return self.storage.has_value()

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """Set the values on the storage in place and return `self`."""
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Return a new tensor wrapping `storage.set_value(value, layout)`."""
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> Tuple[int, int]:
        """Return the sizes of the two sparse dimensions."""
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        """Return the size of sparse dimension `dim`."""
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        """Return a new tensor wrapping `storage.sparse_resize(...)`."""
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def sparse_reshape(self, num_rows: int, num_cols: int):
        """Return a new tensor wrapping `storage.sparse_reshape(...)`."""
        return self.from_storage(
            self.storage.sparse_reshape(num_rows, num_cols))

    def is_coalesced(self) -> bool:
        """Delegate to `storage.is_coalesced()`."""
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "sum"):
        """Return a new tensor wrapping `storage.coalesce(reduce)`."""
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        """Call `storage.fill_cache_()` and return `self`."""
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        """Call `storage.clear_cache_()` and return `self`."""
        self.storage.clear_cache_()
        return self

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
        """Replace the values in place by `nnz()` copies of `fill_value`."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float, dtype: Optional[int] = None):
        """Return a new tensor whose values are `nnz()` copies of
        `fill_value`."""
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        """Return the two sparse sizes followed by the trailing dimensions of
        the value tensor (everything after its first dimension), if any."""
        sparse_sizes = self.sparse_sizes()
        value = self.storage.value()
        if value is not None:
            return list(sparse_sizes) + list(value.size())[1:]
        else:
            return list(sparse_sizes)

    def size(self, dim: int) -> int:
        """Return entry `dim` of `sizes()`."""
        return self.sizes()[dim]

    def dim(self) -> int:
        """Return the number of dimensions, i.e. `len(sizes())`."""
        return len(self.sizes())

    def nnz(self) -> int:
        """Return the number of stored entries (length of `col`)."""
        return self.storage.col().numel()

    def numel(self) -> int:
        """Return the number of value elements, or `nnz()` without values."""
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        """Return `nnz()` divided by the product of the sparse sizes."""
        return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))

    def sparsity(self) -> float:
        """Return `1 - density()`."""
        return 1 - self.density()

    def avg_row_length(self) -> float:
        """Return `nnz()` divided by the number of rows."""
        return self.nnz() / self.sparse_size(0)

    def avg_col_length(self) -> float:
        """Return `nnz()` divided by the number of columns."""
        return self.nnz() / self.sparse_size(1)

    def bandwidth(self) -> int:
        """Return the maximum of `|row - col|` over all stored entries."""
        row, col, _ = self.coo()
        return int((row - col).abs_().max())

    def avg_bandwidth(self) -> float:
        """Return the mean of `|row - col|` over all stored entries."""
        row, col, _ = self.coo()
        return float((row - col).abs_().to(torch.float).mean())

    def bandwidth_proportion(self, bandwidth: int) -> float:
        """Return the fraction of stored entries with
        `|row - col| <= bandwidth`."""
        row, col, _ = self.coo()
        tmp = (row - col).abs_()
        return int((tmp <= bandwidth).sum()) / self.nnz()

    def is_quadratic(self) -> bool:
        """Return whether both sparse dimensions have the same size."""
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        """Return whether the CSR and CSC representations coincide.

        Non-square tensors are never symmetric. Values are only compared
        when present.
        """
        if not self.is_quadratic():
            return False

        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def to_symmetric(self, reduce: str = "sum"):
        """Return the `N x N` tensor holding every entry together with its
        transposed counterpart, where `N = max(size(0), size(1))`.

        Positions that occur more than once are merged; their values are
        combined by `segment_csr(..., reduce=reduce)`.
        """
        N = max(self.size(0), self.size(1))

        row, col, value = self.coo()
        # Linearised keys: slot 0 is a -1 sentinel, the next nnz slots hold
        # row * N + col, the remaining nnz slots hold col * N + row.
        idx = col.new_full((2 * col.numel() + 1, ), -1)
        idx[1:row.numel() + 1] = row
        idx[row.numel() + 1:] = col
        idx[1:] *= N
        idx[1:row.numel() + 1] += col
        idx[row.numel() + 1:] += row

        idx, perm = idx.sort()
        # Drop the sentinel and shift so that `perm` indexes the
        # concatenation [original entries, transposed entries].
        perm = perm[1:].sub_(1)

        # True wherever the sorted key differs from its predecessor, i.e. at
        # the first occurrence of every distinct position.
        mask = idx[1:] > idx[:-1]
        idx2 = perm[mask]

        if value is not None:
            # Group boundaries for the reduction over duplicate positions.
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))])
            value = torch.cat([value, value])[perm]
            value = segment_csr(value, ptr, reduce=reduce)

        # `perm` is no longer needed and is reused as the output buffer.
        new_row = torch.cat([row, col], dim=0, out=perm)[idx2]
        new_col = torch.cat([col, row], dim=0, out=perm)[idx2]

        out = SparseTensor(row=new_row, rowptr=None, col=new_col, value=value,
                           sparse_sizes=(N, N), is_sorted=True)
        return out

    def detach_(self):
        """Detach the value tensor in place (if any) and return `self`."""
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        """Return a new tensor whose value tensor (if any) is detached."""
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        """Return `value.requires_grad`, or `False` without values."""
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self, requires_grad: bool = True,
                       dtype: Optional[int] = None):
        """Set `requires_grad` on the value tensor in place.

        When gradients are requested and there are no values yet, the values
        are first filled with ones of the given `dtype`.
        """
        if requires_grad and not self.has_value():
            self.fill_value_(1., dtype)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        """Return a new tensor wrapping `storage.pin_memory()`."""
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        """Delegate to `storage.is_pinned()`."""
        return self.storage.is_pinned()

    def device(self):
        """Return the device of the `col` tensor."""
        return self.storage.col().device

    def cpu(self):
        """Return this tensor on the device of a default `torch.tensor(0)`."""
        return self.device_as(torch.tensor(0), non_blocking=False)

    def cuda(self):
        """Return a new tensor wrapping `storage.cuda()`."""
        return self.from_storage(self.storage.cuda())

    def is_cuda(self) -> bool:
        """Return whether the `col` tensor lives on a CUDA device."""
        return self.storage.col().is_cuda

    def dtype(self):
        """Return the dtype of the values, or `torch.float` without values."""
        value = self.storage.value()
        return value.dtype if value is not None else torch.float

    def is_floating_point(self) -> bool:
        """Return whether the values are floating point (`True` without
        values)."""
        value = self.storage.value()
        return torch.is_floating_point(value) if value is not None else True

    # NOTE(review): the methods below named `bool`, `float` and `int` shadow
    # the builtins inside the class body, so annotations written after them
    # (e.g. `Optional[int]` in `to_dense`) resolve to these methods when
    # evaluated by plain Python. The original definition order is kept.

    def bfloat16(self):
        """Cast the values to `torch.bfloat16` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.bfloat16, device=self.device()))

    def bool(self):
        """Cast the values to `torch.bool` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.bool, device=self.device()))

    def byte(self):
        """Cast the values to `torch.uint8` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.uint8, device=self.device()))

    def char(self):
        """Cast the values to `torch.int8` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.int8, device=self.device()))

    def half(self):
        """Cast the values to `torch.half` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.half, device=self.device()))

    def float(self):
        """Cast the values to `torch.float` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.float, device=self.device()))

    def double(self):
        """Cast the values to `torch.double` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.double, device=self.device()))

    def short(self):
        """Cast the values to `torch.short` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.short, device=self.device()))

    def int(self):
        """Cast the values to `torch.int` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.int, device=self.device()))

    def long(self):
        """Cast the values to `torch.long` via `type_as`."""
        return self.type_as(
            torch.tensor(0, dtype=torch.long, device=self.device()))

    # Conversions #############################################################

    def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a dense tensor of shape `sizes()`.

        With values, the result uses the value dtype and `dtype` is ignored.
        Without values, stored positions are set to one in the given `dtype`.
        """
        row, col, value = self.coo()

        if value is not None:
            mat = torch.zeros(self.sizes(), dtype=value.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(
            self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a `torch.sparse` COO tensor of shape `sizes()`.

        `dtype` is only used to create a vector of ones when this tensor
        holds no values.
        """
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)

        if value is None:
            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())

# Python Bindings #############################################################


def share_memory_(self: SparseTensor) -> SparseTensor:
    """Call `storage.share_memory_()` and return `self`.

    Returning `self` matches the declared return type and the other
    in-place (`..._`) methods of `SparseTensor`.
    """
    self.storage.share_memory_()
    return self


def is_shared(self: SparseTensor) -> bool:
    """Return the result of `is_shared()` on the underlying storage."""
    storage = self.storage
    return storage.is_shared()


def to(self: SparseTensor, *args: Any, **kwargs: Any) -> SparseTensor:
    """Convert dtype and/or device, accepting `torch.Tensor.to`-style
    arguments.

    The arguments are parsed by `torch._C._nn._parse_to`; a requested dtype
    is applied through `type_as`, a requested device through `device_as`.
    The input is returned unchanged when neither is given.
    """
    device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]

    out = self
    if dtype is not None:
        out = out.type_as(torch.tensor(0., dtype=dtype))
    if device is not None:
        out = out.device_as(torch.tensor(0., device=device), non_blocking)

    return out


def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
    """Index the tensor one dimension at a time.

    Supported items: `int` (`select`), `slice` without step (`narrow`),
    bool tensors (`masked_select`), long tensors and lists/tuples of ints
    (`index_select`), and at most one `Ellipsis`.

    Raises:
        SyntaxError: for more than one `Ellipsis`, a misplaced `Ellipsis`,
            or an unsupported item type.
        ValueError: for slices with a step, or index tensors whose dtype is
            neither `torch.bool` nor `torch.long`.
    """
    index = list(index) if isinstance(index, tuple) else [index]
    # More than one `Ellipsis` is not allowed. Identity is used so that
    # array-like items are never compared element-wise.
    if sum(1 for i in index if i is Ellipsis) > 1:
        raise SyntaxError

    dim = 0
    out = self
    for pos, item in enumerate(index):
        # Number of index items still to be processed after this one.
        remaining = len(index) - pos - 1

        if isinstance(item, (list, tuple)):
            item = torch.tensor(item, dtype=torch.long, device=self.device())

        if isinstance(item, int):
            out = out.select(dim, item)
            dim += 1
        elif isinstance(item, slice):
            if item.step is not None:
                raise ValueError('Step parameter not yet supported.')

            # Negative bounds count from the end of the dimension.
            start = 0 if item.start is None else item.start
            start = self.size(dim) + start if start < 0 else start

            stop = self.size(dim) if item.stop is None else item.stop
            stop = self.size(dim) + stop if stop < 0 else stop

            out = out.narrow(dim, start, max(stop - start, 0))
            dim += 1
        elif torch.is_tensor(item):
            if item.dtype == torch.bool:
                out = out.masked_select(dim, item)
                dim += 1
            elif item.dtype == torch.long:
                out = out.index_select(dim, item)
                dim += 1
            else:
                # Previously such tensors were skipped silently, leaving the
                # dimension unindexed.
                raise ValueError(
                    f'Index tensors must be of type torch.bool or '
                    f'torch.long (got {item.dtype}).')
        elif item is Ellipsis:
            # Skip ahead so the remaining items address the last dimensions.
            if self.dim() - remaining < dim:
                raise SyntaxError
            dim = self.dim() - remaining
        else:
            raise SyntaxError

    return out


def __repr__(self: SparseTensor) -> str:
    """Return a multi-line description listing `row`, `col`, the values (if
    present) and a summary line with size, nnz and density in percent."""
    # Continuation lines of a tensor repr are indented by six spaces.
    i = ' ' * 6
    row, col, value = self.coo()
    infos = []
    # `indent` also prefixes the first line; that prefix is cut off again.
    infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
    infos += [f'col={indent(col.__repr__(), i)[len(i):]}']

    if value is not None:
        infos += [f'val={indent(value.__repr__(), i)[len(i):]}']

    infos += [
        f'size={tuple(self.sizes())}, nnz={self.nnz()}, '
        f'density={100 * self.density():.02f}%'
    ]

    infos = ',\n'.join(infos)

    # Align all following lines with the character after "ClassName(".
    i = ' ' * (len(self.__class__.__name__) + 1)
    return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'


# Attach the module-level functions defined above as methods of
# `SparseTensor`.
SparseTensor.share_memory_ = share_memory_
SparseTensor.is_shared = is_shared
SparseTensor.to = to
SparseTensor.__getitem__ = __getitem__
SparseTensor.__repr__ = __repr__

# Scipy Conversions ###########################################################

# SciPy sparse matrix classes used in the annotations of the SciPy
# conversion functions.
ScipySparseMatrix = Union[
    scipy.sparse.coo_matrix,
    scipy.sparse.csr_matrix,
    scipy.sparse.csc_matrix,
]


@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
    """Build a `SparseTensor` from a SciPy sparse matrix.

    The matrix is converted to CSR to obtain `rowptr` and then to COO to
    obtain `row`, `col` and the data. For CSC input, the original `indptr`
    is additionally passed to the storage as `colptr`. Values are only
    stored when `has_value` is true.
    """
    colptr = None
    if isinstance(mat, scipy.sparse.csc_matrix):
        # NOTE(review): assumes the CSC input needs no further
        # canonicalisation by `tocsr()`, otherwise this `indptr` may not
        # match the converted data - confirm.
        colptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long)
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long)
    col = torch.from_numpy(mat.col).to(torch.long)
    value = None
    if has_value:
        value = torch.from_numpy(mat.data)
    sparse_sizes = mat.shape[:2]

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=None,
                            colptr=colptr, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)

    return SparseTensor.from_storage(storage)


@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
             dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
    """Convert a two-dimensional `SparseTensor` to a SciPy sparse matrix.

    `layout` is normalised by `get_layout` and selects between
    `coo_matrix`, `csr_matrix` and `csc_matrix`. When the tensor holds no
    values, ones of the given `dtype` are used as data.

    Raises:
        AssertionError: if the tensor does not have exactly two dimensions.
        ValueError: if the normalised layout is none of 'coo', 'csr', 'csc'.
    """
    assert self.dim() == 2
    layout = get_layout(layout)

    if not self.has_value():
        ones = torch.ones(self.nnz(), dtype=dtype).numpy()

    if layout == 'coo':
        row, col, value = self.coo()
        row = row.detach().cpu().numpy()
        col = col.detach().cpu().numpy()
        value = value.detach().cpu().numpy() if self.has_value() else ones
        return scipy.sparse.coo_matrix((value, (row, col)), self.sizes())
    elif layout == 'csr':
        rowptr, col, value = self.csr()
        rowptr = rowptr.detach().cpu().numpy()
        col = col.detach().cpu().numpy()
        value = value.detach().cpu().numpy() if self.has_value() else ones
        return scipy.sparse.csr_matrix((value, col, rowptr), self.sizes())
    elif layout == 'csc':
        colptr, row, value = self.csc()
        colptr = colptr.detach().cpu().numpy()
        row = row.detach().cpu().numpy()
        value = value.detach().cpu().numpy() if self.has_value() else ones
        return scipy.sparse.csc_matrix((value, row, colptr), self.sizes())

    # Previously this fell through and returned None. Presumably unreachable
    # if `get_layout` already rejects unknown layouts - confirm.
    raise ValueError(f'Unsupported layout: {layout}')


# Expose the SciPy conversion functions as attributes of `SparseTensor`.
setattr(SparseTensor, 'from_scipy', from_scipy)
setattr(SparseTensor, 'to_scipy', to_scipy)