# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import operator
import re
from functools import reduce
from itertools import product
import pytest

import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED

from transformer_engine.jax.sharding import MeshResource

from utils import assert_allclose, is_devices_enough


def generate_configs():
    """Generate device-mesh test configurations that fit the available devices.

    Returns:
        A list of ``[num_devices, mesh_shape, mesh_axis_names, MeshResource]``
        entries: 2-device pure data-parallel ("dp") and pure tensor-parallel
        ("tp") meshes, plus a 4-device 2x2 ("dp", "tp") mesh. Each entry is
        appended only when `is_devices_enough` reports enough local devices,
        so the result may be empty on small hosts.
    """
    configs = []
    if is_devices_enough(2):
        configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
        configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])

    if is_devices_enough(4):
        TP_size = 2
        DP_size = 2
        configs.append(
            [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
        )

    return configs


def generate_context_parallel_configs():
    """Build pytest params for every (dp, cp, tp) mesh that fits on this host.

    Iterates the cross product of data-, context-, and tensor-parallel sizes
    and keeps only the combinations whose total device count is available,
    attaching a readable pytest id to each kept combination.
    """
    dp_candidates = (1, 2)
    cp_candidates = (1, 2, 4, 8)
    tp_candidates = (1, 2)

    params = []
    for dp_size, cp_size, tp_size in product(dp_candidates, cp_candidates, tp_candidates):
        num_devices = dp_size * cp_size * tp_size
        if not is_devices_enough(num_devices):
            continue
        params.append(
            pytest.param(
                num_devices,
                (dp_size, cp_size, tp_size),
                ("dp", "cp", "tp"),
                MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
                id=f"n{num_devices}_dp{dp_size}_cp{cp_size}_tp{tp_size}",
            )
        )

    return params


56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Categories of HLO collective operations whose byte counts the tests track.
COLL_AR_KEY = "all-reduce"
COLL_AG_KEY = "all-gather"
COLL_OTHER_KEY = "other"


def generate_collectives_count(allreduce, allgather, other):
    """Return a category -> byte-count mapping for expected collectives."""
    counts = {}
    counts[COLL_AR_KEY] = allreduce
    counts[COLL_AG_KEY] = allgather
    counts[COLL_OTHER_KEY] = other
    return counts


def assert_equal_collectives(target_hlo, coll_count_ref):
    """Assert that the collectives in a compiled HLO dump match a reference.

    Scans `target_hlo` (compiled HLO text) for collective start ops
    (lines whose first token contains "-start"), sums the output bytes per
    category (all-reduce / all-gather / other), and asserts the resulting
    mapping equals `coll_count_ref` (as produced by
    `generate_collectives_count`).

    Args:
        target_hlo: Compiled HLO text, e.g. from ``lowered.compile().as_text()``.
        coll_count_ref: Expected category -> byte-count mapping.

    Raises:
        AssertionError: If the counted collectives differ from the reference.
    """
    target_splitted_hlo = target_hlo.splitlines()
    start_symb = "-start"

    def count_bytes(hlo_text):
        """Sum the output bytes of one tokenized collective-start HLO line."""
        bytes_count = 0

        def get_bytes_per_txt(t):
            """Parse one type/shape token and return its size in bytes.

            The pattern of t would be like:
                'f32[]',
                '(f32[1024]{0}',
                'f32[1024]{0})',
                'f8E4M3FN[1024]{0}',
                'i32[1024]{0}',
                'bf16[1024,1024]{0}'
            """
            # NOTE(review): the regex only matches tokens containing an
            # 'i'/'f' followed by a bit width (e.g. i32, f16, bf16, f8...);
            # an unsigned type like 'u32' would make `match` None and raise
            # AttributeError — presumably such types never appear here.
            match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
            _, bits_of_type, shape = match.groups()
            bytes_of_type = int(bits_of_type) // 8
            if shape == "":
                # Scalar, e.g. 'f32[]'.
                num_of_elements = 1
            else:
                num_of_elements = reduce(operator.mul, map(int, shape.split(",")))

            return bytes_of_type * num_of_elements

        # Tuple output: ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
        if "(" in hlo_text[2]:
            # Accumulate each tuple element until the closing parenthesis.
            for txt in hlo_text[2:]:
                bytes_count += get_bytes_per_txt(txt)
                if ")" in txt:
                    break
        else:  # Single output: ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
            bytes_count = get_bytes_per_txt(hlo_text[2])

        return bytes_count

    def count_collectives(splitted_hlo):
        """Tally collective bytes per category over all HLO lines."""
        result = generate_collectives_count(0, 0, 0)

        for line in splitted_hlo:
            txt = line.split()
            # Only async collective start ops carry the '-start' marker
            # in their first token.
            if len(txt) > 0 and start_symb in txt[0]:
                if COLL_AR_KEY in txt[0]:
                    result[COLL_AR_KEY] += count_bytes(txt)
                elif COLL_AG_KEY in txt[0]:
                    result[COLL_AG_KEY] += count_bytes(txt)
                else:
                    result[COLL_OTHER_KEY] += count_bytes(txt)
        return result

    target_result = count_collectives(target_splitted_hlo)
    assert (
        target_result == coll_count_ref
    ), f"Expected collective count is {coll_count_ref}, but got {target_result}."


def compare_ops(
    target_func,
    ref_func,
    inputs,
    coll_count_ref,
    *,
    grad_args=None,
    metric_fwd_dtype=None,
    metric_bwd_dtype=None,
    in_shardings=_UNSPECIFIED,
    out_shardings=_UNSPECIFIED,
    **kwargs,
):
    """Compare a target op against a reference op under pjit sharding.

    Runs ``jax.value_and_grad`` of both functions through ``pjit`` with the
    given shardings, asserts forward outputs and gradients are allclose at the
    requested dtype tolerances, and asserts the target's compiled HLO contains
    exactly the expected collectives.

    Args:
        target_func: Function under test; must return a scalar loss.
        ref_func: Reference implementation with the same signature.
        inputs: Non-empty positional inputs shared by both functions.
        coll_count_ref: Expected collective byte counts for the target HLO
            (see `generate_collectives_count`).
        grad_args: Argnums to differentiate; defaults to all of `inputs`.
        metric_fwd_dtype: Dtype controlling forward-comparison tolerance;
            defaults to ``inputs[0].dtype``.
        metric_bwd_dtype: Dtype controlling gradient-comparison tolerance;
            defaults to ``inputs[0].dtype``.
        in_shardings: Passed through to ``pjit``.
        out_shardings: Passed through to ``pjit``.
        **kwargs: Extra keyword arguments forwarded to both functions.

    Raises:
        AssertionError: On numeric mismatch or unexpected collectives.
    """
    assert len(inputs) >= 1

    if metric_fwd_dtype is None:
        metric_fwd_dtype = inputs[0].dtype
    if metric_bwd_dtype is None:
        metric_bwd_dtype = inputs[0].dtype

    if grad_args is None:
        grad_args = tuple(range(len(inputs)))

    target_grad_func = jax.value_and_grad(target_func, argnums=grad_args)
    target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
    target_fwd, target_grads = target_pjitter(*inputs, **kwargs)
    # Compiled HLO text is inspected for the expected collective ops.
    target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text()

    ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args)
    ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
    ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs)

    assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype)

    for i in range(len(target_grads)):
        assert_allclose(target_grads[i], ref_grads[i], dtype=metric_bwd_dtype)

    assert_equal_collectives(target_hlo, coll_count_ref)