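"""Numerical tests for fastfold's fused softmax kernel, comparing it against a
plain PyTorch reference in both the forward and backward passes."""
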
import torch

from fastfold.model.fastnn.kernel import fused_softmax, softmax


def _test_softmax_core():
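    """Compare fused_softmax against the PyTorch reference for every tested
    (sequence length, dtype) combination, forward and backward."""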
    batch_, chunk_, head_ = 1, 8, 4
    test_seq_ = [31, 32, 128, 129, 256, 259, 512, 700, 1024]
    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
    test_device = torch.device("cuda")

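    # Per-dtype absolute tolerances; looser for the reduced-precision dtypes.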
    tolerance_eps = {torch.float32: 1e-6, torch.float16: 2e-4, torch.bfloat16: 1e-3}

    for seq_ in test_seq_:
        for dtype in test_dtype:
            # Random attention logits; gradients enabled for the backward check.
            sample_input = torch.rand(batch_, chunk_, head_, seq_, seq_,
                                      device=test_device, dtype=dtype).requires_grad_(True)
            # Binary key mask (1 = keep, 0 = mask out); roughly half of the
            # positions are masked so that the masking path is exercised.
            sample_mask = (torch.rand(batch_, chunk_, seq_, device=test_device) > 0.5).to(dtype=dtype)
            # Additive pair bias, broadcast across chunks.
            sample_bias = torch.rand(batch_, 1, head_, seq_, seq_,
                                     device=test_device, dtype=dtype).requires_grad_(True)

            # Detached copies so the fused kernel accumulates its own gradients.
            sample_input_fastnn = sample_input.detach().clone().requires_grad_(True)
            sample_mask_fastnn = sample_mask.detach().clone()
            sample_bias_fastnn = sample_bias.detach().clone().requires_grad_(True)

            # Forward reference: emulate masking with a large negative additive
            # term broadcast over the head and query dimensions, i.e.
            #   softmax(input + 1e9 * (mask - 1) + bias) over the key dimension.
            sample_mask_torch = 1e9 * (sample_mask - 1)[:, :, None, None, :]
            torch_out = torch.nn.functional.softmax(sample_input + sample_mask_torch + sample_bias,
                                                    dim=-1)

            # Kernel under test: mask, bias, and softmax fused into one call.
            fastnn_out = fused_softmax(sample_input_fastnn, sample_mask_fastnn, sample_bias_fastnn)

            fwd_fastnn_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
            assert fwd_fastnn_error < tolerance_eps[dtype], \
                f"fastnn fwd kernel error at seq={seq_}, dtype={dtype}: {fwd_fastnn_error}"

            # Backward: the outputs are non-scalar, so backpropagate an explicit
            # upstream gradient, identical for both implementations.
            out_grad = torch.rand_like(torch_out)
            torch_out.backward(out_grad)
            fastnn_out.backward(out_grad)

            # Compare the gradients accumulated on the two sets of leaf tensors.
            grad_input_error = torch.max(torch.abs(sample_input.grad -
                                                   sample_input_fastnn.grad)).cpu().item()
            assert grad_input_error < tolerance_eps[dtype], \
                f"fastnn bwd kernel input grad error at seq={seq_}, dtype={dtype}: {grad_input_error}"

            grad_bias_error = torch.max(torch.abs(sample_bias.grad -
                                                  sample_bias_fastnn.grad)).cpu().item()
            assert grad_bias_error < tolerance_eps[dtype], \
                f"fastnn bwd kernel bias grad error at seq={seq_}, dtype={dtype}: {grad_bias_error}"


def test_softmax():
    # Run once with the Triton kernel (when available), then force the
    # fallback implementation and repeat the same checks.
    _test_softmax_core()
    if softmax._triton_available:
        softmax._triton_available = False
        _test_softmax_core()
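
# Illustrative usage sketch (not executed by the tests; assumes a CUDA device
# and mirrors the shapes used above). `_example_usage` is ours, not a fastfold API.
def _example_usage():
    logits = torch.rand(1, 8, 4, 64, 64, device="cuda")  # (batch, chunk, head, seq, seq)
    mask = torch.ones(1, 8, 64, device="cuda")           # (batch, chunk, seq); 1 = keep
    bias = torch.rand(1, 1, 4, 64, 64, device="cuda")    # pair bias, broadcast over chunks
    return fused_softmax(logits, mask, bias)             # probabilities, same shape as logits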

if __name__ == "__main__":
    test_softmax()