# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils for testing"""

import random
from typing import Union

import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

import transformer_engine  # pylint: disable=unused-import
from transformer_engine.paddle.constants import (
    TE_DType,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
from transformer_engine import (
    transformer_engine_paddle as tex,
)  # pylint: disable=wrong-import-order


def create_fp8_meta(num_gemms=1, amax_history_len=10):
    """
    Create and initialize FP8TensorMeta
    """
    fp8_meta = FP8TensorMeta(is_forward=True)
    fp8_meta.prepare(num_gemms, amax_history_len)
    return fp8_meta
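
# Usage sketch (illustrative; argument values are arbitrary):
#   fp8_meta = create_fp8_meta(num_gemms=2, amax_history_len=16)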


def assert_allclose(
    actual, desired, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg="", verbose=True
):
    """Compare two input paddle tensors"""
    if isinstance(actual, paddle.Tensor):
        actual = paddle.cast(actual, "float32")
    if isinstance(desired, paddle.Tensor):
        desired = paddle.cast(desired, "float32")
    if len(actual.shape) == 0:
        actual = actual.item()
        desired = desired.item()
    else:
        actual = actual.numpy()
        desired = desired.numpy()
    np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
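
# Usage sketch (illustrative): low-precision inputs are cast to float32
# before the numpy comparison, so pick tolerances to match the inputs'
# native precision.
#   x = paddle.ones([2, 2], dtype="bfloat16")
#   y = paddle.ones([2, 2], dtype="float32")
#   assert_allclose(x, y, rtol=1e-2, atol=1e-2)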


def assert_shape(inp, expected_shape):
    """Assert the shape of input tensor equals to expected shape"""
    assert (
        inp.shape == expected_shape
    ), f"Expected tensor shape: {expected_shape} != actual tensor shape: {inp.shape}"
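
# Usage sketch: paddle reports tensor shapes as lists, so pass the
# expected shape as a list.
#   assert_shape(paddle.zeros([4, 8]), [4, 8])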


def is_devices_enough(required):
    """If the number of device is enough"""
    return paddle.device.cuda.device_count() >= required
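
# Usage sketch (assumes pytest, which this module does not import):
#   @pytest.mark.skipif(not is_devices_enough(2), reason="needs 2 GPUs")
#   def test_model_parallel():
#       ...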


def set_random_seed(seed):
    """Set random seed for reproducability."""
    fleet.meta_parallel.model_parallel_random_seed(seed)

    hcg = fleet.get_hybrid_communicate_group()
    if paddle.distributed.get_world_size() > 1:
        # Query the rank and world size of each hybrid-parallel dimension

        mp_rank = hcg.get_model_parallel_rank()
        mp_size = hcg.get_model_parallel_world_size()

        pp_rank = hcg.get_stage_id()
        pp_size = hcg.get_pipe_parallel_world_size()

        dp_rank = hcg.get_data_parallel_rank()
        dp_size = hcg.get_data_parallel_world_size()

        sharding_rank = hcg.get_sharding_parallel_rank()
    else:
        mp_rank, mp_size = 0, 1
        pp_rank, pp_size = 0, 1
        dp_rank, dp_size = 0, 1
        sharding_rank = 0

    random.seed(seed + 100 * pp_rank)
    np.random.seed(seed + 100 * pp_rank)

    seed_offset = seed + 1024 + paddle.distributed.get_world_size()
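    # global_seed is shared within a model-parallel group: it depends on the
    # pp/dp/sharding ranks but deliberately omits mp_rank.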
    global_seed = (
        seed_offset
        + pp_rank * (mp_size)
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )

    seed_offset += paddle.distributed.get_world_size()
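    # local_seed is unique per rank: it additionally mixes in mp_rank.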
    local_seed = (
        seed_offset
        + mp_rank
        + pp_rank * (mp_size)
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )

    tracker = get_rng_state_tracker()
    # tracker.reset()
    if "global_seed" not in tracker.states_:
        tracker.add("global_seed", global_seed)
    if "local_seed" not in tracker.states_:
        tracker.add("local_seed", local_seed)

    paddle.seed(global_seed)
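
# Usage sketch (illustrative): call once after fleet initialization and
# before constructing layers, so the RNG tracker states are registered.
#   set_random_seed(1234)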


def get_fused_attention_backend(
    num_heads: int,
    num_gqa_groups: int,
    q_seqlen: int,
    kv_seqlen: int,
    head_size: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> tex.NVTE_Fused_Attn_Backend:
    """Get cuDNN fused attention backend for attention config"""
    if isinstance(dtype, str):
        dtype = dict(
            float32=paddle.float32,
            bfloat16=paddle.bfloat16,
            float16=paddle.float16,
        )[dtype]
    return tex.get_fused_attn_backend(
        TE_DType[dtype],
        TE_DType[dtype],
        tex.get_nvte_qkv_layout(qkv_layout),
        AttnBiasType[bias_type],
        AttnMaskType[mask_type],
        dropout,
        num_heads,
        num_gqa_groups,
        q_seqlen,
        kv_seqlen,
        head_size,
    )
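
# Usage sketch (illustrative; values are arbitrary):
#   backend = get_fused_attention_backend(
#       num_heads=16, num_gqa_groups=16, q_seqlen=512, kv_seqlen=512,
#       head_size=64, dtype="bfloat16", dropout=0.0)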


def is_fused_attention_supported(
    num_heads: int,
    num_gqa_groups: int,
    q_seqlen: int,
    kv_seqlen: int,
    head_size: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> bool:
    """Check if cuDNN fused attention is supported for attention config"""
    backend = get_fused_attention_backend(
        num_heads=num_heads,
        num_gqa_groups=num_gqa_groups,
        q_seqlen=q_seqlen,
        kv_seqlen=kv_seqlen,
        head_size=head_size,
        dtype=dtype,
        dropout=dropout,
        qkv_layout=qkv_layout,
        bias_type=bias_type,
        mask_type=mask_type,
    )
    return backend != FusedAttnBackend["No_Backend"]
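
# Usage sketch (assumes pytest, which this module does not import):
#   if not is_fused_attention_supported(
#           num_heads=16, num_gqa_groups=16, q_seqlen=512, kv_seqlen=512,
#           head_size=64, dtype="bfloat16", dropout=0.0):
#       pytest.skip("cuDNN fused attention unsupported for this config")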


def register_sequence_parallel_allreduce_hooks(model, accumulation_steps) -> None:
    """Register allreduce hooks for sequence parallel tensors"""

    def is_sequence_parallel_parameter(parameter):
        """If input tensor is marked as sequence parallel tensor"""
        return getattr(parameter, "sequence_parallel", False)

    def create_allreduce_gradient_hook(param, accumulation_steps):
        """Create allreduce gradient hook"""
        hcg = fleet.get_hybrid_communicate_group()
        pg = hcg.get_model_parallel_group().process_group
        step = [0]
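        # Mutable counter captured by the closure below to count backward passes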

        @paddle.autograd.no_grad()
        def __impl__():
            step[0] += 1
            if (step[0] % accumulation_steps) == 0:
                if hasattr(param, "main_grad"):
                    pg.allreduce(param.main_grad).wait()
                else:
                    pg.allreduce(param.grad).wait()

        return __impl__

    if accumulation_steps <= 0 or not paddle.distributed.is_initialized():
        return

    hcg = fleet.get_hybrid_communicate_group()
    mp_group = hcg.get_model_parallel_group()
    if mp_group.nranks <= 1:
        return

    params = []
    for p in model.parameters():
        if is_sequence_parallel_parameter(p):
            params.append(p)

    for p in params:
        hook = create_allreduce_gradient_hook(p, accumulation_steps)
        p._register_backward_hook(hook)
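
# Usage sketch (assumes fleet hybrid parallelism is initialized and `model`
# contains parameters tagged with `sequence_parallel=True`):
#   register_sequence_parallel_allreduce_hooks(model, accumulation_steps=4)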