test_distributed_softmax.py 3.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax

DTYPES = [jnp.float16, jnp.bfloat16]


class TestDistributedSoftmax:
    """Compare TE's fused softmax with a plain-JAX reference on a device mesh.

    Each case places sharded logits and a mask on a ``Mesh``. It then hands a
    target (``transformer_engine.jax.softmax``) and a reference
    (``jax.nn.softmax`` plus an additive mask bias) to ``compare_ops``, along
    with the expected collective counts.
    """

    def generate_collectives_count_ref(self):
        """Return the expected collective counts for the target function.

        The expectation is a single 4-byte all-reduce (one FP32 value, for the
        loss), with no all-gather and no other collectives.
        """
        # for loss
        all_reduce_loss_bytes = 4    # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, softmax_type, dtype):
        """Build the logits, the mask, and their partition specs.

        Args:
            shape: Logits shape. It is unpacked as ``(batch, _, seqlen, _)``, so it
                must be rank 4 (presumably ``(batch, heads, q_len, k_len)`` --
                TODO confirm).
            mesh_resource: Object exposing ``dp_resource`` / ``tp_resource`` mesh
                axis names.
            softmax_type: ``SCALED_UPPER_TRIANG_MASKED`` gets a causal mask;
                every other type gets a self mask.
            dtype: dtype of the random logits.

        Returns:
            ``((x, mask), (x_pspec, mask_pspec))``. The logits are partitioned on
            axis 0 over ``dp_resource`` and axis 1 over ``tp_resource``. The mask
            (rank 4 per its spec) is partitioned on axis 0 over ``dp_resource``
            only.
        """
        batch, _, seqlen, _ = shape

        # Fixed PRNG key so every parametrization sees reproducible logits.
        x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
            mask = make_causal_mask(batch, seqlen)
        else:
            mask = make_self_mask(batch, seqlen)

        x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource, None, None)
        mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)

        return (x, mask), (x_pspec, mask_pspec)

    @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
    @pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]])
    @pytest.mark.parametrize(
        'softmax_type',
        [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED])
    @pytest.mark.parametrize('scale_factor', [1.0, 3.0])
    @pytest.mark.parametrize('dtype', DTYPES)
    def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
                     softmax_type, scale_factor, dtype):
        """Run TE softmax and the JAX reference on a sharded mesh via ``compare_ops``.

        The test is parametrized over mesh configurations, logits shapes,
        softmax variants, scale factors, and half-precision dtypes.
        """

        def target_func(x, mask):
            # NOTE(review): each softmax row sums to 1, so the mean of the output
            # is ~1/last_dim regardless of x, and its gradient w.r.t. x is ~0.
            # This loss mostly exercises sharding/collectives rather than
            # softmax numerics; consider a less degenerate reduction.
            return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))

        def ref_func(x, mask):
            # Mirror the target's semantics. The SCALED variant does not consume
            # the mask, so the reference must not apply one either. The masked
            # variants fold the mask in as an additive bias where
            # mask > 0 means "masked out". Note that -1e10 is not representable
            # in float16 and becomes -inf. That is harmless for softmax as long
            # as no row is fully masked.
            if mask is not None and softmax_type != SoftmaxType.SCALED:
                bias = jax.lax.select(mask > 0,
                                      jnp.full(mask.shape, -1e10).astype(dtype),
                                      jnp.full(mask.shape, 0.).astype(dtype))
                x = x + bias
            output = jax.nn.softmax(x * scale_factor)
            return jnp.mean(output)

        (x, mask), (x_pspec, mask_pspec) = \
                self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype)
        collective_count_ref = self.generate_collectives_count_ref()
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))

            # grad_args=(0,): presumably only the logits are differentiated.
            # out_shardings presumably describes (loss, (grad_x,)): a replicated
            # scalar loss plus a gradient sharded like x -- TODO confirm
            # against compare_ops.
            compare_ops(target_func,
                        ref_func, [x_, mask_],
                        collective_count_ref,
                        grad_args=(0,),
                        metric_fwd_dtype=dtype,
                        metric_bwd_dtype=dtype,
                        in_shardings=(x_pspec, mask_pspec),
                        out_shardings=(None, (x_pspec,)))