# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils for testing"""

import random
from typing import Union

import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

import transformer_engine    # pylint: disable=unused-import
from transformer_engine.paddle.constants import (
    TE_DType,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
from transformer_engine import transformer_engine_paddle as tex    # pylint: disable=wrong-import-order


def create_fp8_meta(num_gemms=1, amax_history_len=10):
    """
    Create and initialize FP8TensorMeta
    """
    fp8_meta = FP8TensorMeta(is_forward=True)
    fp8_meta.prepare(num_gemms, amax_history_len)
    return fp8_meta
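
# Example usage (an illustrative sketch; the argument values below are arbitrary):
#     fp8_meta = create_fp8_meta(num_gemms=2, amax_history_len=16)
#     # `fp8_meta` now holds initialized forward-pass FP8 scaling state for 2 GEMMs.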


def assert_allclose(actual,
                    desired,
                    rtol=1e-05,
                    atol=1e-08,
                    equal_nan=True,
                    err_msg='',
                    verbose=True):
    """Compare two input paddle tensors"""
    if isinstance(actual, paddle.Tensor):
        actual = paddle.cast(actual, 'float32')
    if isinstance(desired, paddle.Tensor):
        desired = paddle.cast(desired, 'float32')
    if len(actual.shape) == 0:
        actual = actual.item()
        desired = desired.item()
    else:
        actual = actual.numpy()
        desired = desired.numpy()
    np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
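
# Example usage (a minimal sketch; tensor values and tolerances are arbitrary):
#     x = paddle.ones([4, 4], dtype='bfloat16')
#     y = paddle.ones([4, 4], dtype='float32')
#     assert_allclose(x, y, rtol=1e-3, atol=1e-3)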


def assert_shape(inp, expected_shape):
    """Assert the shape of input tensor equals to expected shape"""
    assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \
        f"actual tensor shape: {inp.shape}"


def is_devices_enough(required):
    """If the number of device is enough"""
    return paddle.device.cuda.device_count() >= required
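
# Typical use is as a pytest skip condition in the test modules, e.g.
# (illustrative; assumes `pytest` is imported in the calling test file):
#     @pytest.mark.skipif(not is_devices_enough(2), reason="requires at least 2 GPUs")
#     def test_multi_gpu(): ...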


def set_random_seed(seed):
    """Set random seed for reproducability."""
    fleet.meta_parallel.model_parallel_random_seed(seed)

    hcg = fleet.get_hybrid_communicate_group()
    if paddle.distributed.get_world_size() > 1:
        # obtain rank info of the hybrid parallel groups

        mp_rank = hcg.get_model_parallel_rank()
        mp_size = hcg.get_model_parallel_world_size()

        pp_rank = hcg.get_stage_id()
        pp_size = hcg.get_pipe_parallel_world_size()

        dp_rank = hcg.get_data_parallel_rank()
        dp_size = hcg.get_data_parallel_world_size()

        sharding_rank = hcg.get_sharding_parallel_rank()
    else:
        mp_rank, mp_size = 0, 1
        pp_rank, pp_size = 0, 1
        dp_rank, dp_size = 0, 1
        sharding_rank, _ = 0, 1

    random.seed(seed + 100 * pp_rank)
    np.random.seed(seed + 100 * pp_rank)

    seed_offset = seed + 1024 + paddle.distributed.get_world_size()
    global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
                   sharding_rank * (mp_size * pp_size * dp_size))

    seed_offset += paddle.distributed.get_world_size()
    local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
                  sharding_rank * (mp_size * pp_size * dp_size))

    tracker = get_rng_state_tracker()
    # tracker.reset()
    if "global_seed" not in tracker.states_:
        tracker.add("global_seed", global_seed)
    if "local_seed" not in tracker.states_:
        tracker.add("local_seed", local_seed)

    paddle.seed(global_seed)
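
# Worked example of the seed derivation above (single process, seed=42):
#     world_size = 1, all ranks = 0, all group sizes = 1
#     seed_offset = 42 + 1024 + 1 = 1067  -> global_seed = 1067
#     seed_offset += 1                    -> local_seed  = 1068
# The RNG state tracker then keeps the "global_seed" and "local_seed" streams separate.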


def get_fused_attention_backend(
    num_heads: int,
    num_gqa_groups: int,
    q_seqlen: int,
    kv_seqlen: int,
    head_size: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> tex.NVTE_Fused_Attn_Backend:
    """Get cuDNN fused attention backend for attention config"""
    if isinstance(dtype, str):
        dtype = dict(
            float32=paddle.float32,
            bfloat16=paddle.bfloat16,
            float16=paddle.float16,
        )[dtype]
    return tex.get_fused_attn_backend(
        TE_DType[dtype],
        TE_DType[dtype],
        tex.get_nvte_qkv_layout(qkv_layout),
        AttnBiasType[bias_type],
        AttnMaskType[mask_type],
        dropout,
        num_heads,
        num_gqa_groups,
        q_seqlen,
        kv_seqlen,
        head_size,
    )


def is_fused_attention_supported(
    num_heads: int,
    num_gqa_groups: int,
    q_seqlen: int,
    kv_seqlen: int,
    head_size: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> bool:
    """Check if cuDNN fused attention is supported for attention config"""
    backend = get_fused_attention_backend(
        num_heads=num_heads,
        num_gqa_groups=num_gqa_groups,
        q_seqlen=q_seqlen,
        kv_seqlen=kv_seqlen,
        head_size=head_size,
        dtype=dtype,
        dropout=dropout,
        qkv_layout=qkv_layout,
        bias_type=bias_type,
        mask_type=mask_type,
    )
    return backend != FusedAttnBackend["No_Backend"]
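
# Example usage (a minimal sketch; the attention config values are arbitrary):
#     if is_fused_attention_supported(
#             num_heads=16, num_gqa_groups=16,
#             q_seqlen=512, kv_seqlen=512, head_size=64,
#             dtype='bfloat16', dropout=0.0,
#             qkv_layout='bs3hd', bias_type='no_bias', mask_type='causal'):
#         ...  # run the fused-attention test case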


def register_sequence_parallel_allreduce_hooks(model, accumulation_steps) -> None:
    """Register allreduce hooks for sequence parallel tensors"""

    def is_sequence_parallel_parameter(parameter):
        """If input tensor is marked as sequence parallel tensor"""
        out = getattr(parameter, "sequence_parallel", False)
        return out

    def create_allreduce_gradient_hook(param, accumulation_steps):
        """Create allreduce gradient hook"""
        hcg = fleet.get_hybrid_communicate_group()
        pg = hcg.get_model_parallel_group().process_group
        step = [0]

        @paddle.autograd.no_grad()
        def __impl__():
            step[0] += 1
            if (step[0] % accumulation_steps) == 0:
                if hasattr(param, "main_grad"):
                    pg.allreduce(param.main_grad).wait()
                else:
                    pg.allreduce(param.grad).wait()

        return __impl__

    if accumulation_steps <= 0 or not paddle.distributed.is_initialized():
        return

    hcg = fleet.get_hybrid_communicate_group()
    mp_group = hcg.get_model_parallel_group()
    if mp_group.nranks <= 1:
        return

    params = []
    for p in model.parameters():
        if is_sequence_parallel_parameter(p):
            params.append(p)

    for p in params:
        hook = create_allreduce_gradient_hook(p, accumulation_steps)
        p._register_backward_hook(hook)
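
# Example usage (illustrative sketch; assumes fleet has been initialized with a model-
# parallel group and that `model` has parameters marked with `sequence_parallel=True`):
#     register_sequence_parallel_allreduce_hooks(model, accumulation_steps=1)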