# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils for testing"""

import random
from typing import Union

import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

import transformer_engine    # pylint: disable=unused-import
from transformer_engine.paddle.constants import (
    TE_DType,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
import transformer_engine_paddle as tex    # pylint: disable=wrong-import-order


def create_fp8_meta(num_gemms=1, amax_history_len=10):
    """Build a forward-pass FP8TensorMeta ready for use.

    Allocates the meta object and calls its ``prepare`` hook with the
    requested number of GEMMs and amax-history length.
    """
    meta = FP8TensorMeta(is_forward=True)
    meta.prepare(num_gemms, amax_history_len)
    return meta


def assert_allclose(actual,
                    desired,
                    rtol=1e-05,
                    atol=1e-08,
                    equal_nan=True,
                    err_msg='',
                    verbose=True):
    """Compare two input paddle tensors"""

    def _as_numpy(value):
        # Paddle tensors are cast to float32 before the NumPy conversion
        # (e.g. so low-precision dtypes can be compared by NumPy).
        if isinstance(value, paddle.Tensor):
            return paddle.cast(value, 'float32').numpy()
        return value

    np.testing.assert_allclose(_as_numpy(actual),
                               _as_numpy(desired),
                               rtol=rtol,
                               atol=atol,
                               equal_nan=equal_nan,
                               err_msg=err_msg,
                               verbose=verbose)


def assert_shape(inp, expected_shape):
    """Assert the shape of input tensor equals to expected shape"""
    actual_shape = inp.shape
    message = (f"Expected tensor shape: {expected_shape} != "
               f"actual tensor shape: {actual_shape}")
    assert actual_shape == expected_shape, message


def is_devices_enough(required):
    """If the number of device is enough"""
    available = paddle.device.cuda.device_count()
    return available >= required


def set_random_seed(seed):
    """Set random seed for reproducibility.

    Seeds Python's ``random``, NumPy, and Paddle, and registers per-rank
    "global_seed"/"local_seed" states in the hybrid-parallel RNG tracker so
    that dropout-style ops are reproducible across mp/pp/dp/sharding ranks.
    """
    if paddle.distributed.get_world_size() > 1:
        # Only query the hybrid communicate group in distributed runs:
        # it is populated by fleet initialization, which a single-process
        # run may not have performed.
        hcg = fleet.get_hybrid_communicate_group()

        # obtain rank message of hybrid parallel
        mp_rank = hcg.get_model_parallel_rank()
        mp_size = hcg.get_model_parallel_world_size()

        pp_rank = hcg.get_stage_id()
        pp_size = hcg.get_pipe_parallel_world_size()

        dp_rank = hcg.get_data_parallel_rank()
        dp_size = hcg.get_data_parallel_world_size()

        sharding_rank = hcg.get_sharding_parallel_rank()
    else:
        mp_rank, mp_size = 0, 1
        pp_rank, pp_size = 0, 1
        dp_rank, dp_size = 0, 1
        sharding_rank = 0

    # Python/NumPy seeds vary only with the pipeline stage.
    random.seed(seed + 100 * pp_rank)
    np.random.seed(seed + 100 * pp_rank)

    # Derive two distinct seeds: "global" (shared within a model-parallel
    # group) and "local" (additionally offset by the mp rank).
    seed_offset = seed + 1024 + paddle.distributed.get_world_size()
    global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
                   sharding_rank * (mp_size * pp_size * dp_size))

    seed_offset += paddle.distributed.get_world_size()
    local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
                  sharding_rank * (mp_size * pp_size * dp_size))

    # Register tracker states only once; re-adding an existing state raises.
    tracker = get_rng_state_tracker()
    if "global_seed" not in tracker.states_:
        tracker.add("global_seed", global_seed)
    if "local_seed" not in tracker.states_:
        tracker.add("local_seed", local_seed)

    paddle.seed(global_seed)


def get_fused_attention_backend(
    head_size: int,
    q_seqlen: int,
    kv_seqlen: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> tex.NVTE_Fused_Attn_Backend:
    """Get cuDNN fused attention backend for attention config"""
    # Accept dtype either as a paddle dtype or as its string name.
    if isinstance(dtype, str):
        name_to_dtype = {
            "float32": paddle.float32,
            "bfloat16": paddle.bfloat16,
            "float16": paddle.float16,
        }
        dtype = name_to_dtype[dtype]
    te_dtype = TE_DType[dtype]
    return tex.get_fused_attn_backend(
        te_dtype,
        te_dtype,
        tex.get_nvte_qkv_layout(qkv_layout),
        AttnBiasType[bias_type],
        AttnMaskType[mask_type],
        dropout,
        q_seqlen,
        kv_seqlen,
        head_size,
    )

def is_fused_attention_supported(
    head_size: int,
    q_seqlen: int,
    kv_seqlen: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> bool:
    """Check if cuDNN fused attention is supported for attention config"""
    # Supported means any backend other than the "No_Backend" sentinel.
    no_backend = FusedAttnBackend["No_Backend"]
    selected = get_fused_attention_backend(
        head_size=head_size,
        q_seqlen=q_seqlen,
        kv_seqlen=kv_seqlen,
        dtype=dtype,
        dropout=dropout,
        qkv_layout=qkv_layout,
        bias_type=bias_type,
        mask_type=mask_type,
    )
    return selected != no_backend