test_basic_ops.py 542 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from fastfold.model.fastnn.ops import Linear as FastLinear
from fastfold.model.nn.primitives import Linear


def test_linear():
    """Check that the fastnn ``Linear`` reproduces the reference ``Linear``.

    Both layers are created and moved to CUDA. The fastnn layer is then given
    the reference layer's own weight and bias Parameter objects. The two
    forward passes over one random input must agree under
    ``torch.allclose(..., atol=1e-8)``.
    """
    in_dim, out_dim, n_rows = 3, 4, 5

    candidate = FastLinear(in_dim, out_dim).cuda()
    reference = Linear(in_dim, out_dim).cuda()

    # Share the reference layer's parameters so that both layers run their
    # forward pass with identical weight and bias tensors.
    candidate.weight, candidate.bias = reference.weight, reference.bias

    # Sampled on the CPU, then copied to the GPU.
    inputs = torch.randn((n_rows, in_dim)).cuda()

    fast_out = candidate(inputs)
    ref_out = reference(inputs)
    assert torch.allclose(fast_out, ref_out, atol=1e-8)


if __name__ == "__main__":
    # Allow running the check directly as a script, without a test runner.
    test_linear()