"git@developer.sourcefind.cn:modelzoo/gpt2_migraphx.git" did not exist on "816b3d5203b7d69f4cf1d457c78a82a4b69798ea"
test_elementwise_op_sp.py 1.76 KB
Newer Older
1
import sys
2

3
import backend as F
4
5
import pytest
import torch
6

7
from dgl.sparse import from_coo, power
8

Mufei Li's avatar
Mufei Li committed
9
# TODO(#4818): Skip this module on every non-Linux platform. Windows was the
# original motivation, but the check below also skips e.g. macOS, so the skip
# reason reports the actual platform instead of always saying "win".
if not sys.platform.startswith("linux"):
    pytest.skip(
        f"skipping tests on {sys.platform}: only Linux is supported (#4818)",
        allow_module_level=True,
    )

13

14
def all_close_sparse(A, row, col, val, shape):
    """Assert that sparse matrix ``A`` has the given COO indices, values and shape.

    Args:
        A: Sparse matrix exposing ``coo()`` (returning a ``(row, col)`` pair of
            index tensors), ``val`` and ``shape``.
        row: Expected row indices; must match exactly (same size and entries).
        col: Expected column indices; must match exactly.
        val: Expected stored values; must have the same shape and be
            ``torch.allclose`` to ``A.val``.
        shape: Expected shape of ``A``.

    Raises:
        AssertionError: If any of the above does not match.
    """
    rowA, colA = A.coo()
    valA = A.val
    # torch.allclose broadcasts its inputs, so e.g. an expected ``row`` of
    # length 1 would "match" any index tensor filled with that one value.
    # Integer indices are therefore compared exactly with torch.equal (which
    # also requires equal sizes), and the value shapes are checked explicitly.
    assert torch.equal(rowA, row)
    assert torch.equal(colA, col)
    assert valA.shape == val.shape
    assert torch.allclose(valA, val)
    assert A.shape == shape


23
24
25
26
27
28
@pytest.mark.parametrize("v_scalar", [2, 2.5])
def test_mul_scalar(v_scalar):
    """Check scalar multiplication of a sparse matrix from both sides.

    Both ``A * v`` and ``v * A`` must scale every stored value by ``v``
    while keeping the shape and the COO indices of ``A`` unchanged.
    """
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(len(row)).to(ctx)
    A1 = from_coo(row, col, val, shape=(3, 4))
    row1, col1 = A1.coo()
    expected_val = A1.val * v_scalar

    # ``v * A`` with a Python scalar on the left falls back to the matrix's
    # reflected multiplication, so both operand orders are exercised.
    for desc, A2 in (("A * v", A1 * v_scalar), ("v * A", v_scalar * A1)):
        assert torch.allclose(
            expected_val, A2.val, rtol=1e-4, atol=1e-4
        ), desc
        assert A1.shape == A2.shape, desc
        # The value comparison above is positional, so scaling must also keep
        # the same non-zero entries in the same order.
        row2, col2 = A2.coo()
        assert torch.equal(row1, row2), desc
        assert torch.equal(col1, col2), desc


42
43
44
45
46
47
48
@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
def test_pow(val_shape):
    """Check element-wise power through both ``A ** v`` and ``power(A, v)``.

    Each form must raise every stored value to the exponent and keep the
    shape and COO indices of the input matrix.
    """
    ctx = F.ctx()
    idx_row = torch.tensor([1, 0, 2]).to(ctx)
    idx_col = torch.tensor([0, 3, 2]).to(ctx)
    values = torch.randn(val_shape).to(ctx)
    mat = from_coo(idx_row, idx_col, values, shape=(3, 4))

    exponent = 2
    expected = values**exponent
    # The operator form (A ** v) and the functional form (power(A, v)) are
    # applied one after another and checked with identical assertions.
    for apply_pow in (lambda m, e: m**e, power):
        result = apply_pow(mat, exponent)
        assert torch.allclose(result.val, expected)
        assert result.shape == mat.shape
        out_row, out_col = result.coo()
        assert torch.allclose(out_row, idx_row)
        assert torch.allclose(out_col, idx_col)