"tests/test_structures/test_bbox/test_bbox_coders.py" did not exist on "d07bd8bd20abacb9b8270454eb94b70e3876081c"
test_transpose.py 1.53 KB
Newer Older
1
2
3
4
5
6
7
8
import pytest
import torch
import sys

from dgl.mock_sparse2 import diag, create_from_coo

import backend as F

Mufei Li's avatar
Mufei Li committed
9
# TODO(#4818): Skipping tests on win.
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# These tests only run on Linux. The check below skips the whole module on
# every platform whose ``sys.platform`` does not start with "linux" (Windows,
# macOS/"darwin", ...), so the skip reason names all of them, not just Windows.
if not sys.platform.startswith("linux"):
    pytest.skip(
        "skipping tests on non-Linux platforms", allow_module_level=True
    )

@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag_matrix_transpose(val_shape, mat_shape):
    """Transposing a diagonal matrix keeps its values and reverses its shape.

    ``val_shape`` covers diagonal values with and without a trailing extra
    dimension; ``mat_shape`` covers the default (``None``) as well as
    non-square shapes in both orientations.
    """
    ctx = F.ctx()
    val = torch.randn(val_shape).to(ctx)
    mat = diag(val, mat_shape).transpose()

    # torch.allclose broadcasts its operands, so check the shape explicitly
    # to catch a transpose that reshapes the stored values.
    assert mat.val.shape == val.shape
    assert torch.allclose(mat.val, val)

    # With no explicit shape, the test expects diag() to build a square
    # matrix whose side equals the number of diagonal entries.
    if mat_shape is None:
        expected_shape = (val_shape[0], val_shape[0])
    else:
        expected_shape = mat_shape
    assert mat.shape == expected_shape[::-1]


@pytest.mark.parametrize("dense_dim", [None, 2])
@pytest.mark.parametrize("row", [[0, 0, 1, 2], (0, 1, 2, 4)])
@pytest.mark.parametrize("col", [(0, 1, 2, 2), (1, 3, 3, 4)])
@pytest.mark.parametrize("extra_shape", [(0, 1), (2, 1)])
def test_sparse_matrix_transpose(dense_dim, row, col, extra_shape):
    """Transposing a COO sparse matrix swaps row/col indices and its shape.

    ``extra_shape`` pads the matrix beyond the largest row/col index so that
    non-square shapes are exercised; ``dense_dim`` optionally gives every
    non-zero value a trailing dimension.
    """
    mat_shape = (max(row) + 1 + extra_shape[0], max(col) + 1 + extra_shape[1])
    val_shape = (len(row),)
    if dense_dim is not None:
        val_shape += (dense_dim,)
    ctx = F.ctx()
    val = torch.randn(val_shape).to(ctx)
    row = torch.tensor(row).to(ctx)
    col = torch.tensor(col).to(ctx)
    mat = create_from_coo(row, col, val, mat_shape).transpose()
    mat_row, mat_col = mat.coo()
    mat_val = mat.val

    assert mat.shape == mat_shape[::-1]

    # torch.allclose broadcasts its operands, so compare shapes explicitly
    # before comparing the (floating-point) values.
    assert mat_val.shape == val.shape
    assert torch.allclose(mat_val, val)

    # Indices are integers: require exact equality. torch.equal also fails on
    # a size mismatch that allclose's broadcasting could hide. The assertions
    # are positional, i.e. transpose is expected to keep the entry order.
    assert torch.equal(mat_row, col)
    assert torch.equal(mat_col, row)