# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils for testing"""

import random
from typing import Union

import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

import transformer_engine    # pylint: disable=unused-import
from transformer_engine.paddle.constants import (
    TE_DType,
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
import transformer_engine_paddle as tex


def create_fp8_meta(num_gemms=1, amax_history_len=10):
    """
    Create and initialize FP8TensorMeta
    """
    fp8_meta = FP8TensorMeta(is_forward=True)
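    # Allocate scale/scale_inv factors and the amax-history buffer for `num_gemms` GEMMs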
    fp8_meta.prepare(num_gemms, amax_history_len)
    return fp8_meta
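
# Example usage of `create_fp8_meta` (a minimal sketch; assumes an FP8-capable
# GPU and a working transformer_engine_paddle build):
#
#     meta = create_fp8_meta(num_gemms=2, amax_history_len=16)
#     # `meta` now carries scale/scale_inv factors and an amax history sized
#     # for two GEMMs, ready to pass to the FP8 kernels under test.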


def assert_allclose(actual,
                    desired,
                    rtol=1e-05,
                    atol=1e-08,
                    equal_nan=True,
                    err_msg='',
                    verbose=True):
    """Compare two input paddle tensors"""
    if isinstance(actual, paddle.Tensor):
        actual = paddle.cast(actual, 'float32').numpy()
    if isinstance(desired, paddle.Tensor):
        desired = paddle.cast(desired, 'float32').numpy()
    np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
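
# Example usage of `assert_allclose` (a minimal sketch): low-precision inputs
# are upcast to float32 before the numpy comparison, so mixed-dtype checks work.
#
#     x = paddle.ones([2, 2], dtype='bfloat16')
#     y = np.ones([2, 2], dtype=np.float32)
#     assert_allclose(x, y, rtol=1e-3, atol=1e-3)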


def assert_shape(inp, expected_shape):
    """Assert the shape of input tensor equals to expected shape"""
    assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \
        f"actual tensor shape: {inp.shape}"


def is_devices_enough(required):
    """If the number of device is enough"""
    return paddle.device.cuda.device_count() >= required
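
# Example usage of `is_devices_enough` (a minimal sketch; assumes pytest, which
# is how these helpers are typically consumed in the test files):
#
#     @pytest.mark.skipif(not is_devices_enough(2), reason="requires 2 GPUs")
#     def test_model_parallel_layer():
#         ...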


def set_random_seed(seed):
    """Set random seed for reproducability."""

    hcg = fleet.get_hybrid_communicate_group()
    if paddle.distributed.get_world_size() > 1:
        # obtain rank information for each hybrid-parallel dimension
        mp_rank = hcg.get_model_parallel_rank()
        mp_size = hcg.get_model_parallel_world_size()

        pp_rank = hcg.get_stage_id()
        pp_size = hcg.get_pipe_parallel_world_size()

        dp_rank = hcg.get_data_parallel_rank()
        dp_size = hcg.get_data_parallel_world_size()

        sharding_rank = hcg.get_sharding_parallel_rank()
    else:
        mp_rank, mp_size = 0, 1
        pp_rank, pp_size = 0, 1
        dp_rank, dp_size = 0, 1
        sharding_rank = 0

    random.seed(seed + 100 * pp_rank)
    np.random.seed(seed + 100 * pp_rank)

    seed_offset = seed + 1024 + paddle.distributed.get_world_size()
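    # `global_seed` omits mp_rank, so all ranks in a tensor-parallel group draw
    # the same random numbers; `local_seed` below adds mp_rank so that per-rank
    # randomness (e.g. dropout masks) differs across tensor-parallel shards.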
    global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
                   sharding_rank * (mp_size * pp_size * dp_size))

    seed_offset += paddle.distributed.get_world_size()
    local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
                  sharding_rank * (mp_size * pp_size * dp_size))

    tracker = get_rng_state_tracker()
    # tracker.reset()
    if "global_seed" not in tracker.states_:
        tracker.add("global_seed", global_seed)
    if "local_seed" not in tracker.states_:
        tracker.add("local_seed", local_seed)

    paddle.seed(global_seed)


def get_fused_attention_backend(
    head_size: int,
    q_seqlen: int,
    kv_seqlen: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "qkv_interleaved",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> tex.NVTE_Fused_Attn_Backend:
    """Get cuDNN fused attention backend for attention config"""
    if isinstance(dtype, str):
        dtype = dict(
            float32=paddle.float32,
            bfloat16=paddle.bfloat16,
            float16=paddle.float16,
        )[dtype]
    return tex.get_fused_attn_backend(
        TE_DType[dtype],
        TE_DType[dtype],
        QKVLayout[qkv_layout],
        AttnBiasType[bias_type],
        AttnMaskType[mask_type],
        dropout,
        q_seqlen,
        kv_seqlen,
        head_size,
    )


def is_fused_attention_supported(
    head_size: int,
    q_seqlen: int,
    kv_seqlen: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "qkv_interleaved",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> bool:
    """Check if cuDNN fused attention is supported for attention config"""
    backend = get_fused_attention_backend(
        head_size=head_size,
        q_seqlen=q_seqlen,
        kv_seqlen=kv_seqlen,
        dtype=dtype,
        dropout=dropout,
        qkv_layout=qkv_layout,
        bias_type=bias_type,
        mask_type=mask_type,
    )
    return backend != FusedAttnBackend["No_Backend"]
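

# Example usage of `is_fused_attention_supported` (a minimal sketch; assumes a
# pytest-based test that should be skipped on unsupported configurations):
#
#     if not is_fused_attention_supported(head_size=64, q_seqlen=512,
#                                         kv_seqlen=512, dtype='bfloat16',
#                                         dropout=0.0):
#         pytest.skip("cuDNN fused attention not supported for this config")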