"src/diffusers/schedulers/scheduling_pndm_flax.py" did not exist on "ad9d252596c33ce80275a866d970dd4242fd56f0"
test_storage.py 5.88 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from itertools import product

import pytest
import torch
rusty1s's avatar
rusty1s committed
5
from torch_sparse.storage import SparseStorage
rusty1s's avatar
rusty1s committed
6
7
8
9

from .utils import dtypes, devices, tensor


10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
@pytest.mark.parametrize('device', devices)
def test_ind2ptr(device):
    """Round-trip index vectors through ``ind2ptr`` and ``ptr2ind``.

    Each case converts a list of row indices into a pointer vector for 8
    rows, checks the pointer vector, and then converts it back, expecting
    the original indices. The second case covers the empty input.
    """
    ops = torch.ops.torch_sparse
    num_rows = 8

    # (row indices, expected pointer vector of length num_rows + 1)
    cases = [
        ([2, 2, 4, 5, 5, 6], [0, 0, 0, 2, 2, 3, 5, 6, 6]),
        ([], [0, 0, 0, 0, 0, 0, 0, 0, 0]),
    ]

    for indices, expected_ptr in cases:
        row = tensor(indices, torch.long, device)

        rowptr = ops.ind2ptr(row, num_rows)
        assert rowptr.tolist() == expected_ptr

        # The second argument of ptr2ind is the number of indices to recover.
        restored = ops.ptr2ind(rowptr, len(indices))
        assert restored.tolist() == indices


rusty1s's avatar
rusty1s committed
27
28
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device):
    """Check the basic accessors of ``SparseStorage``.

    The first storage is built without values; the second is built from
    entries whose columns are out of order within each row and is expected
    to report rows, columns and values in sorted order.
    """
    # Storage without values: accessors return the inputs and a (2, 2) size.
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)

    storage = SparseStorage(row=row, col=col)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value() is None
    assert storage.sparse_sizes() == (2, 2)

    # Columns given as [1, 0] per row: the assertions below expect the
    # columns to come back sorted and the values permuted along with them.
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([2, 1, 4, 3], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value().tolist() == [1, 2, 3, 4]
    assert storage.sparse_sizes() == (2, 2)
rusty1s's avatar
rusty1s committed
44
45
46
47


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_caching(dtype, device):
    """Check the cache life cycle of ``SparseStorage``.

    Covers four stages: a fresh storage has no cached fields, ``fill_cache_``
    populates them, a storage rebuilt from those fields keeps them, and
    ``clear_cache_`` drops them again.
    """
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    assert storage._row.tolist() == row.tolist()
    assert storage._col.tolist() == col.tolist()
    assert storage._value is None

    # Nothing is cached right after construction.
    assert storage._rowcount is None
    assert storage._rowptr is None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.num_cached_keys() == 0

    # Filling the cache populates every derived field.
    storage.fill_cache_()
    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    # Rebuild a storage by handing the cached fields to the constructor.
    storage = SparseStorage(row=row, rowptr=storage._rowptr, col=col,
                            value=storage._value,
                            sparse_sizes=storage._sparse_sizes,
                            rowcount=storage._rowcount, colptr=storage._colptr,
                            colcount=storage._colcount,
                            csr2csc=storage._csr2csc, csc2csr=storage._csc2csr)

    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    storage.clear_cache_()
    assert storage._rowcount is None
    # Unlike the other derived fields, _rowptr is expected to survive
    # clear_cache_.
    assert storage._rowptr is not None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.num_cached_keys() == 0
rusty1s's avatar
rusty1s committed
93
94
95
96


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
    """Exercise the utility methods of ``SparseStorage``.

    Covers ``has_value``, in-place and out-of-place value replacement with
    the ``'csc'`` and ``'coo'`` layouts, ``sparse_resize``, ``copy`` and
    ``clone``.
    """
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    assert storage.has_value()

    # In-place variant: values passed with layout='csc' are expected to be
    # reordered, while layout='coo' values are stored as given.
    storage.set_value_(value, layout='csc')
    assert storage.value().tolist() == [1, 3, 2, 4]
    storage.set_value_(value, layout='coo')
    assert storage.value().tolist() == [1, 2, 3, 4]

    # Out-of-place variant: same expectations on the returned storage.
    storage = storage.set_value(value, layout='csc')
    assert storage.value().tolist() == [1, 3, 2, 4]
    storage = storage.set_value(value, layout='coo')
    assert storage.value().tolist() == [1, 2, 3, 4]

    storage = storage.sparse_resize((3, 3))
    assert storage.sparse_sizes() == (3, 3)

    # copy() yields a distinct storage whose column tensor has the same
    # data pointer as the original's.
    new_storage = storage.copy()
    assert new_storage != storage
    assert new_storage.col().data_ptr() == storage.col().data_ptr()

    # clone() yields a distinct storage whose column tensor has a
    # different data pointer.
    new_storage = storage.clone()
    assert new_storage != storage
    assert new_storage.col().data_ptr() != storage.col().data_ptr()
rusty1s's avatar
rusty1s committed
123
124
125
126


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesce(dtype, device):
    """Check ``is_coalesced`` and ``coalesce`` on a storage with a duplicate.

    The input lists the entry (0, 1) twice. After ``coalesce`` the storage
    is expected to hold each entry once, with the duplicate's values summed.
    """
    row, col = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
    value = tensor([1, 1, 1, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    # Construction alone keeps the duplicate entry untouched.
    assert storage.row().tolist() == row.tolist()
    assert storage.col().tolist() == col.tolist()
    assert storage.value().tolist() == value.tolist()

    assert not storage.is_coalesced()
    storage = storage.coalesce()
    assert storage.is_coalesced()

    # The two (0, 1) entries with value 1 each collapse into one of value 2.
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value().tolist() == [1, 2, 3, 4]
rusty1s's avatar
rusty1s committed
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_reshape(dtype, device):
    """Reshape a 4x4 diagonal storage through a chain of shapes.

    The reshapes are applied one after another to the same storage, so
    each step starts from the result of the previous one. A ``-1``
    dimension is expected to resolve to the size that keeps 16 elements.
    """
    diagonal = [0, 1, 2, 3]
    row, col = tensor([diagonal, diagonal], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    # (requested shape, expected sparse sizes, expected rows, expected cols)
    steps = [
        ((2, 8), (2, 8), [0, 0, 1, 1], [0, 5, 2, 7]),
        ((-1, 4), (4, 4), [0, 1, 2, 3], [0, 1, 2, 3]),
        ((2, -1), (2, 8), [0, 0, 1, 1], [0, 5, 2, 7]),
    ]

    for shape, expected_sizes, expected_row, expected_col in steps:
        storage = storage.sparse_reshape(*shape)
        assert storage.sparse_sizes() == expected_sizes
        assert storage.row().tolist() == expected_row
        assert storage.col().tolist() == expected_col