test_unary_op_diag.py 569 Bytes
Newer Older
Mufei Li's avatar
Mufei Li committed
1
2
3
import sys

import backend as F
4
5
import pytest
import torch
Mufei Li's avatar
Mufei Li committed
6

7
from dgl.sparse import diag
Mufei Li's avatar
Mufei Li committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27


def test_neg():
    """Unary negation of a diagonal matrix.

    Checks that ``-D`` keeps the shape of ``D``, holds the element-wise
    negated diagonal values, and keeps them on the same device.
    """
    device = F.ctx()
    # Diagonal entries 0.0, 1.0, 2.0 placed on the backend's test device.
    diag_vals = torch.arange(3).float()
    diag_vals = diag_vals.to(device)
    mat = diag(diag_vals)

    negated = -mat

    assert mat.shape == negated.shape
    expected = -mat.val
    assert torch.allclose(expected, negated.val)
    assert mat.val.device == negated.val.device


def test_inv():
    """Inverse of a diagonal matrix via ``DiagMatrix.inv()``.

    Checks that ``D.inv()`` keeps the shape of ``D``, holds the element-wise
    reciprocals of the diagonal values, and keeps them on the same device.
    """
    device = F.ctx()
    # Start the range at 1 so no diagonal entry is zero (1.0, 2.0, 3.0).
    diag_vals = torch.arange(1, 4).float()
    diag_vals = diag_vals.to(device)
    mat = diag(diag_vals)

    inverted = mat.inv()

    assert mat.shape == inverted.shape
    reciprocals = 1.0 / mat.val
    assert torch.allclose(reciprocals, inverted.val)
    assert mat.val.device == inverted.val.device