# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import operator
import re
from functools import reduce
7
8
from itertools import product
import pytest
9
10
11
12
13
14
15
16
17
18
19
20
21

import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED

from transformer_engine.jax.sharding import MeshResource

from utils import assert_allclose, is_devices_enough


def generate_configs():
    """Build the pytest parameter sets for the DP / TPSP mesh layouts.

    Each entry is a ``pytest.param`` holding, in order: the device count, the
    mesh shape, the mesh axis names and the matching ``MeshResource``.  A
    layout is added only when ``is_devices_enough(n)`` is truthy for its
    device count, so the returned list may be empty.

    Returns:
        list: the 4-device ``dp`` x ``tpsp`` layout first (when enabled),
        followed by the two 2-device layouts (``dp`` only, then ``tpsp`` only).
    """
    configs = []
    if is_devices_enough(4):
        configs.append(
            pytest.param(
                4,
                (2, 2),
                ("dp", "tpsp"),
                MeshResource(dp_resource="dp", tpsp_resource="tpsp"),
                id="n4_dp2_tp2",
            )
        )

    if is_devices_enough(2):
        configs.append(
            pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
        )
        configs.append(
            pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2"),
        )
    return configs


41
42
43
44
def generate_context_parallel_configs_for_attn():
    """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only.

    Every (dp, cp, tp) combination drawn from the size tuples below is turned
    into a ``pytest.param`` of ``(ndev, (dp, cp, tp), axes, mesh_resource)``
    with ``ndev = dp * cp * tp``.  Combinations for which
    ``is_devices_enough(ndev)`` is falsy are skipped.

    Returns:
        dict: ``{"L0": [], "L1": [...], "L2": [...]}`` where ``"L1"`` holds the
        combinations with ``cp != 1`` and ``"L2"`` those with ``cp == 1``.
        ``"L0"`` is always empty.
    """
    mr = MeshResource(dp_resource="dp", cp_resource="cp", tpsp_resource="tpsp")
    axes = ("dp", "cp", "tpsp")
    dp_sizes = (1, 2)
    cp_sizes = (1, 2, 4, 8)
    tp_sizes = (1, 2)

    configs = {"L0": [], "L1": [], "L2": []}
    for dp, cp, tp in product(dp_sizes, cp_sizes, tp_sizes):
        ndev = cp * tp * dp
        if not is_devices_enough(ndev):
            continue
        param = pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
        # Do not run the cp1 case in L1: it is already covered by
        # TestDistributedSelfAttn and TestDistributedCrossAttn, which have no
        # cp combinations.  Those cases go to L2 instead.
        level = "L1" if cp != 1 else "L2"
        configs[level].append(param)
    return configs


66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Bucket names used as dictionary keys when tallying collectives.  The first
# two are also matched as substrings of HLO op names by the helpers below.
COLL_AR_KEY = "all-reduce"
COLL_AG_KEY = "all-gather"
COLL_OTHER_KEY = "other"


def generate_collectives_count(allreduce, allgather, other):
    """Return a per-bucket tally dictionary.

    Args:
        allreduce: value stored under ``COLL_AR_KEY``.
        allgather: value stored under ``COLL_AG_KEY``.
        other: value stored under ``COLL_OTHER_KEY``.

    Returns:
        dict: the three values keyed by their bucket name, in the order
        all-reduce, all-gather, other.
    """
    tally = {}
    tally[COLL_AR_KEY] = allreduce
    tally[COLL_AG_KEY] = allgather
    tally[COLL_OTHER_KEY] = other
    return tally


def assert_equal_collectives(target_hlo, coll_count_ref):
    """Assert that the collectives in an HLO dump produce the expected byte totals.

    Every line of ``target_hlo`` whose first whitespace-separated token
    contains ``"-start"`` is treated as a collective.  The byte size of the
    result shape(s) printed on that line is added to the all-reduce,
    all-gather or "other" bucket, depending on which key appears in that first
    token.  Note that the tallies are byte totals, not op counts.

    Args:
        target_hlo: HLO text as a single string (one instruction per line).
        coll_count_ref: expected totals, a dict shaped like the result of
            ``generate_collectives_count``.

    Raises:
        AssertionError: if the measured totals differ from ``coll_count_ref``.
        ValueError: if a result-shape token of a collective cannot be parsed.
    """
    start_symb = "-start"
    # Element type prefix with its bit width (e.g. the "f16" in "bf16", the
    # "f8" in "f8E4M3FN") followed by a bracketed, comma-separated shape.
    shape_pattern = re.compile(r"(i|f|u)(\d+).*\[([0-9,]*)\]")

    def get_bytes_per_txt(t):
        """
        Return the byte size described by one typed-shape token.

        The pattern of t would be like:
            'f32[]',
            '(f32[1024]{0}',
            'f32[1024]{0})',
            'f8E4M3FN[1024]{0}',
            'i32[1024]{0}',
            'bf16[1024,1024]{0}'
        """
        match = shape_pattern.search(t)
        if match is None:
            # Fail with the offending token instead of an AttributeError on None.
            raise ValueError(f"Cannot parse a typed shape from HLO token {t!r}.")

        _, bits_of_type, shape = match.groups()
        bytes_of_type = int(bits_of_type) // 8
        if shape == "":
            # Empty brackets denote a scalar.
            num_of_elements = 1
        else:
            num_of_elements = reduce(operator.mul, map(int, shape.split(",")))

        return bytes_of_type * num_of_elements

    def count_bytes(hlo_text):
        """Sum the bytes of the result shape(s) of one tokenised HLO line."""
        # ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
        if "(" in hlo_text[2]:
            bytes_count = 0
            for txt in hlo_text[2:]:
                bytes_count += get_bytes_per_txt(txt)
                # The token carrying ")" closes the result tuple.
                if ")" in txt:
                    break
            return bytes_count

        # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
        return get_bytes_per_txt(hlo_text[2])

    def count_collectives(splitted_hlo):
        """Accumulate per-bucket byte totals over all HLO lines."""
        result = generate_collectives_count(0, 0, 0)

        for line in splitted_hlo:
            txt = line.split()
            if not txt or start_symb not in txt[0]:
                continue
            if COLL_AR_KEY in txt[0]:
                result[COLL_AR_KEY] += count_bytes(txt)
            elif COLL_AG_KEY in txt[0]:
                result[COLL_AG_KEY] += count_bytes(txt)
            else:
                result[COLL_OTHER_KEY] += count_bytes(txt)
        return result

    target_result = count_collectives(target_hlo.splitlines())
    assert (
        target_result == coll_count_ref
    ), f"Expected collective count is {coll_count_ref}, but got {target_result}."


def compare_ops(
    target_func,
    ref_func,
    inputs,
    coll_count_ref,
    *,
    grad_args=None,
    metric_fwd_dtype=None,
    metric_bwd_dtype=None,
    in_shardings=_UNSPECIFIED,
    out_shardings=_UNSPECIFIED,
    **kwargs,
):
    """Compare a target function against a reference, forward and backward.

    Both functions are wrapped with ``jax.value_and_grad`` and ``pjit`` using
    the same shardings, then called with ``inputs`` and ``kwargs``.  The
    forward values and each gradient are compared with ``assert_allclose``,
    and the compiled HLO of the target is checked with
    ``assert_equal_collectives``.

    Args:
        target_func: function under test.
        ref_func: reference function.
        inputs: non-empty sequence of positional arguments for both functions.
        coll_count_ref: expected collective totals for the target's HLO, as
            accepted by ``assert_equal_collectives``.
        grad_args: ``argnums`` passed to ``jax.value_and_grad``; defaults to
            every index of ``inputs``.
        metric_fwd_dtype: dtype passed to ``assert_allclose`` for the forward
            value; defaults to ``inputs[0].dtype``.
        metric_bwd_dtype: dtype passed to ``assert_allclose`` for the
            gradients; defaults to ``inputs[0].dtype``.
        in_shardings: forwarded to ``pjit`` for both functions.
        out_shardings: forwarded to ``pjit`` for both functions.
        **kwargs: extra keyword arguments forwarded to both functions.

    Raises:
        AssertionError: if ``inputs`` is empty, or if any comparison fails.
    """
    assert len(inputs) >= 1

    if metric_fwd_dtype is None:
        metric_fwd_dtype = inputs[0].dtype
    if metric_bwd_dtype is None:
        metric_bwd_dtype = inputs[0].dtype

    if grad_args is None:
        grad_args = tuple(range(len(inputs)))

    target_grad_func = jax.value_and_grad(target_func, argnums=grad_args)
    target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
    target_fwd, target_grads = target_pjitter(*inputs, **kwargs)
    # Lower and compile again to obtain the optimised HLO text for the
    # collective check at the end.
    target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text()

    ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args)
    ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
    ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs)

    assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype)

    # Both gradient sets come from the same argnums; a length mismatch would
    # let zip() silently drop entries, so check it explicitly.
    assert len(target_grads) == len(ref_grads)
    for target_grad, ref_grad in zip(target_grads, ref_grads):
        assert_allclose(target_grad, ref_grad, dtype=metric_bwd_dtype)

    assert_equal_collectives(target_hlo, coll_count_ref)