# test_torch_softmax.py
import pytest
import torch

from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.utils import attention_mask_func, get_default_causal_mask


class TestTorchSoftmax:
    def setup_method(self, method):
        # These tests exercise the forward_torch_softmax path with a locally
        # generated causal mask passed to attention_mask_func:
        self.softmax = FusedScaleMaskSoftmax(
            input_in_fp16=False,
            input_in_bf16=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
            mask_func=attention_mask_func,
            softmax_in_fp32=True,
            scale=None,
        )

    def teardown_method(self):
        get_default_causal_mask.cache_clear()

    def test_output_shape(self):
        x = torch.randn(8, 2, 4, 4, device="cuda")
        y = self.softmax(x, None)
        assert x.shape == y.shape

    def test_causal_mask_input_shape_assert(self):
        x = torch.randn(1, 1, 4, 16, device="cuda")
        with pytest.raises(AssertionError):
            self.softmax(x, None)
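
    def test_explicit_mask_sketch(self):
        # Hypothetical extra check (a sketch, not part of the original suite):
        # the sq == sk assertion above is assumed to apply only when no mask is
        # supplied, so passing an explicit boolean mask should bypass it even
        # for a rectangular score matrix.
        sq, sk = 4, 16
        x = torch.randn(1, 1, sq, sk, device="cuda")
        mask = torch.triu(
            torch.ones(1, 1, sq, sk, dtype=torch.bool, device="cuda"), diagonal=1
        )
        y = self.softmax(x, mask)
        assert y.shape == x.shape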

    def test_causal_mask_equal_scores(self):
        # For equal input values (e.g. zero), a correctly masked softmax should
        # produce equal scores among the non-masked elements. For example, for
        # sq == sk == 2 the expected output is (ignoring the b and np dimensions):
        # [[1.0, 0.0],
        #  [0.5, 0.5]]
        b, np, sq, sk = 8, 2, 32, 32
        x = torch.zeros(b, np, sq, sk, device="cuda")
        y = self.softmax(x, None)
        y_expected = torch.tril(torch.ones(b, np, sq, sk, device="cuda"))
        y_expected /= torch.arange(1, sq + 1, device="cuda").reshape((-1, 1))
        assert torch.allclose(y, y_expected, rtol=1e-08, atol=1e-08)
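
    def test_matches_reference_softmax_sketch(self):
        # Hypothetical cross-check (a sketch, not part of the original suite):
        # the forward_torch_softmax output is assumed to agree with a manual
        # reference that masks the upper triangle with -inf before softmax,
        # since the large negative fill applied by attention_mask_func
        # underflows to zero probability in fp32 anyway.
        b, np, sq, sk = 4, 2, 16, 16
        x = torch.randn(b, np, sq, sk, device="cuda")
        y = self.softmax(x, None)
        causal_mask = torch.triu(
            torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=1
        )
        y_ref = torch.softmax(x.masked_fill(causal_mask, float("-inf")), dim=-1)
        assert torch.allclose(y, y_ref, rtol=1e-05, atol=1e-06)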