test_layernorm.py 1.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from fastfold.model.fastnn.kernel import LayerNorm as FastLayerNorm


def test_layernorm():

    # [batch, dim]
    test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
    test_device = torch.device("cuda")

    tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}

    for shape in test_shape:
        for dtype in test_dtype:
            sample_input = torch.rand(shape).to(device=test_device,
                                                dtype=dtype).requires_grad_(False)

            dim_ = sample_input.size()[-1]
            torch_module = torch.nn.LayerNorm(normalized_shape=dim_).to(device=test_device,
                                                                        dtype=dtype)
            fastnn_module = FastLayerNorm(normalized_shape=dim_).to(device=test_device, dtype=dtype)

            # Forward
            torch_out = torch_module(sample_input)
            fastnn_out = fastnn_module(sample_input)
            forward_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
            assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

            # Backward
            out_grad = torch.rand_like(torch_out).requires_grad_(False)
            torch_out.backward(out_grad)
            fastnn_out.backward(out_grad)

            backward_weight_error = torch.max(
                torch.abs(torch_module.weight.grad - fastnn_module.weight.grad)).cpu().item()
            assert backward_weight_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

            backward_bias_error = torch.max(
                torch.abs(torch_module.bias.grad - fastnn_module.bias.grad)).cpu().item()
            assert backward_bias_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"