"""Tests for ``dgl.sparse.softmax`` against ``dgl.nn.functional.edge_softmax``."""
import sys

import backend as F

import dgl
import pytest
import torch
from dgl.sparse import from_coo, softmax


@pytest.mark.parametrize("val_D", [None, 2])
@pytest.mark.parametrize("csr", [True, False])
def test_softmax(val_D, csr):
    """Check ``dgl.sparse.softmax`` against ``dgl.nn.functional.edge_softmax``.

    Builds a small 2x3 sparse matrix from COO indices, optionally converts it
    to CSR first, and asserts that both the forward values and the gradients
    w.r.t. the non-zero values agree with ``edge_softmax`` computed on an
    equivalent graph.

    Args:
        val_D: ``None`` for 1-D non-zero values of shape ``(nnz,)``; otherwise
            the size of the trailing dimension, giving values of shape
            ``(nnz, val_D)``.
        csr: If True, call ``A.csr()`` before ``softmax`` to exercise the CSR
            path.
    """
    dev = F.ctx()
    row = torch.tensor([0, 0, 1, 1]).to(dev)
    col = torch.tensor([0, 2, 1, 2]).to(dev)
    nnz = len(row)
    if val_D is None:
        val = torch.randn(nnz).to(dev)
    else:
        val = torch.randn(nnz, val_D).to(dev)

    # Two independent leaf copies of the same values (one per code path) so
    # that each path accumulates its own gradient.
    val_sparse = val.clone().requires_grad_()
    A = from_coo(row, col, val_sparse)

    if csr:
        # Test CSR
        A.csr()

    A_max = softmax(A)
    # Edges are built col -> row, so the destination node of each edge is its
    # matrix row; edge_softmax then normalizes over edges that share a
    # destination (its default), which should match a row-wise softmax.
    g = dgl.graph((col, row), num_nodes=max(A.shape))
    val_g = val.clone().requires_grad_()
    score = dgl.nn.functional.edge_softmax(g, val_g)
    assert torch.allclose(A_max.val, score)

    # randn_like already allocates on score's device.
    grad = torch.randn_like(score)
    A_max.val.backward(grad)
    score.backward(grad)
    # Read the gradient from the leaf tensor created above rather than from
    # ``A.val``, which is not guaranteed to be that same leaf object.
    assert val_sparse.grad is not None
    assert torch.allclose(val_sparse.grad, val_g.grad)