import torch
from fastfold.model.fastnn.kernel import LayerNorm as FastLayerNorm
from fastfold.model.fastnn.kernel.layer_norm import FusedLayerNormAffineFunction

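# The Triton LayerNorm kernel is optional: if it cannot be imported, only the fused CUDA kernel is tested.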
triton = True
try:
    from fastfold.model.fastnn.kernel.layer_norm import LayerNormTritonFunc
except ImportError:
    print("Skip triton layernorm test!")
    triton = False


def test_layernorm():
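    """Compare FastFold's fused CUDA (and, if available, Triton) LayerNorm kernels against torch.nn.LayerNorm."""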

    # [batch, dim]
    test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
    test_device = torch.device("cuda")

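    # Per-dtype tolerances: half-precision dtypes get a looser bound than float32.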
    tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}

    for shape in test_shape:
        for dtype in test_dtype:
            sample_input = torch.rand(shape).to(device=test_device,
                                                dtype=dtype).requires_grad_(False)

            dim_ = sample_input.size()[-1]
            torch_module = torch.nn.LayerNorm(normalized_shape=dim_).to(device=test_device,
                                                                        dtype=dtype)
            fastnn_cuda_module = FastLayerNorm(normalized_shape=dim_).to(device=test_device, dtype=dtype)
            if triton:
                fastnn_triton_module = FastLayerNorm(normalized_shape=dim_).to(device=test_device, dtype=dtype)

            # Forward
            torch_out = torch_module(sample_input)
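
            # Fused CUDA kernel: call the autograd function directly with the module's weight, bias, and eps.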
            fastnn_cuda_out = FusedLayerNormAffineFunction.apply(sample_input, fastnn_cuda_module.weight, fastnn_cuda_module.bias, 
                                                                 fastnn_cuda_module.normalized_shape, fastnn_cuda_module.eps)
            forward_error = torch.max(torch.abs(torch_out - fastnn_cuda_out)).cpu().item()
            assert forward_error < tolerance_eps[dtype], f"CUDA forward mismatch for shape {shape}, dtype {dtype}"
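
            # Triton kernel forward, only when the Triton import succeeded.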
            if triton:
                fastnn_triton_out = LayerNormTritonFunc.apply(sample_input, fastnn_triton_module.normalized_shape, fastnn_triton_module.weight, 
                                                            fastnn_triton_module.bias, fastnn_triton_module.eps)
                forward_error = torch.max(torch.abs(torch_out - fastnn_triton_out)).cpu().item()
                assert forward_error < tolerance_eps[dtype], f"Triton forward mismatch for shape {shape}, dtype {dtype}"

            # Backward
            out_grad = torch.rand_like(torch_out).requires_grad_(False)
            torch_out.backward(out_grad)
            fastnn_cuda_out.backward(out_grad)

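            # Parameter gradients from the fused CUDA kernel should match PyTorch's.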
            backward_weight_error = torch.max(
                torch.abs(torch_module.weight.grad - fastnn_cuda_module.weight.grad)).cpu().item()
            assert backward_weight_error < tolerance_eps[dtype], f"CUDA backward weight grad mismatch for shape {shape}, dtype {dtype}"
            backward_bias_error = torch.max(
                torch.abs(torch_module.bias.grad - fastnn_cuda_module.bias.grad)).cpu().item()
            assert backward_bias_error < tolerance_eps[dtype], f"CUDA backward bias grad mismatch for shape {shape}, dtype {dtype}"

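            # Parameter gradients from the Triton kernel should match PyTorch's as well.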
            if triton:
                fastnn_triton_out.backward(out_grad)
                backward_weight_error = torch.max(
                    torch.abs(torch_module.weight.grad - fastnn_triton_module.weight.grad)).cpu().item()
                assert backward_weight_error < tolerance_eps[dtype], f"Triton backward weight grad mismatch for shape {shape}, dtype {dtype}"
                backward_bias_error = torch.max(
                    torch.abs(torch_module.bias.grad - fastnn_triton_module.bias.grad)).cpu().item()
                assert backward_bias_error < tolerance_eps[dtype], f"Triton backward bias grad mismatch for shape {shape}, dtype {dtype}"


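# The test_ prefix also makes this function discoverable by pytest.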
if __name__ == "__main__":
    test_layernorm()