"transformer_engine/pytorch/cpu_offload.py" did not exist on "bacefdbb6815159c42d5ca501a1400697b98a1e3"
utils.py 4.86 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils for testing"""

import random
from typing import Union

import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

import transformer_engine    # pylint: disable=unused-import
from transformer_engine.paddle.constants import (
    TE_DType,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
import transformer_engine_paddle as tex    # pylint: disable=wrong-import-order


def create_fp8_meta(num_gemms=1, amax_history_len=10):
    """
    Create and initialize FP8TensorMeta
    """
    fp8_meta = FP8TensorMeta(is_forward=True)
    fp8_meta.prepare(num_gemms, amax_history_len)
    return fp8_meta


def assert_allclose(actual,
                    desired,
                    rtol=1e-05,
                    atol=1e-08,
                    equal_nan=True,
                    err_msg='',
                    verbose=True):
    """Compare two input paddle tensors"""
    if isinstance(actual, paddle.Tensor):
        actual = paddle.cast(actual, 'float32').numpy()
    if isinstance(desired, paddle.Tensor):
        desired = paddle.cast(desired, 'float32').numpy()
    np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)


def assert_shape(inp, expected_shape):
    """Assert the shape of input tensor equals to expected shape"""
    assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \
        f"actual tensor shape: {inp.shape}"


def is_devices_enough(required):
    """If the number of device is enough"""
    return paddle.device.cuda.device_count() >= required
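
# A typical (hypothetical) way `is_devices_enough` is used in a test module, assuming
# pytest is imported there:
#     @pytest.mark.skipif(not is_devices_enough(2), reason="Requires at least 2 GPUs")
#     def test_parallel_case():
#         ...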


def set_random_seed(seed):
    """Set random seed for reproducability."""

    hcg = fleet.get_hybrid_communicate_group()
    if paddle.distributed.get_world_size() > 1:
        # Obtain the rank and world size of each hybrid-parallel group.
        mp_rank = hcg.get_model_parallel_rank()
        mp_size = hcg.get_model_parallel_world_size()

        pp_rank = hcg.get_stage_id()
        pp_size = hcg.get_pipe_parallel_world_size()

        dp_rank = hcg.get_data_parallel_rank()
        dp_size = hcg.get_data_parallel_world_size()

        sharding_rank = hcg.get_sharding_parallel_rank()
    else:
        mp_rank, mp_size = 0, 1
        pp_rank, pp_size = 0, 1
        dp_rank, dp_size = 0, 1
        sharding_rank = 0

    random.seed(seed + 100 * pp_rank)
    np.random.seed(seed + 100 * pp_rank)

    # `global_seed` does not depend on the tensor-parallel rank, so every rank in the
    # same model-parallel group draws identical random numbers (e.g. for parameter
    # initialization that must match across that group).
    seed_offset = seed + 1024 + paddle.distributed.get_world_size()
    global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
                   sharding_rank * (mp_size * pp_size * dp_size))

    # `local_seed` additionally includes the tensor-parallel rank, so random ops such
    # as dropout differ across ranks of the same model-parallel group.
    seed_offset += paddle.distributed.get_world_size()
    local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
                  sharding_rank * (mp_size * pp_size * dp_size))

    tracker = get_rng_state_tracker()
    # tracker.reset()
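    # Register the seeds only if they are not already present; Paddle's RNG state
    # tracker raises an error when a state name is added twice, so this keeps
    # repeated calls to `set_random_seed` safe.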
    if "global_seed" not in tracker.states_:
        tracker.add("global_seed", global_seed)
    if "local_seed" not in tracker.states_:
        tracker.add("local_seed", local_seed)

    paddle.seed(global_seed)


def get_fused_attention_backend(
    num_heads: int,
    num_gqa_groups: int,
    q_seqlen: int,
    kv_seqlen: int,
    head_size: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> tex.NVTE_Fused_Attn_Backend:
    """Get cuDNN fused attention backend for attention config"""
    if isinstance(dtype, str):
        dtype = dict(
            float32=paddle.float32,
            bfloat16=paddle.bfloat16,
            float16=paddle.float16,
        )[dtype]
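    # Query Transformer Engine for the cuDNN fused-attention backend that supports
    # this configuration (the "No_Backend" enum value is returned if none does).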
    return tex.get_fused_attn_backend(
        TE_DType[dtype],
        TE_DType[dtype],
        tex.get_nvte_qkv_layout(qkv_layout),
        AttnBiasType[bias_type],
        AttnMaskType[mask_type],
        dropout,
        num_heads,
        num_gqa_groups,
        q_seqlen,
        kv_seqlen,
        head_size,
    )


def is_fused_attention_supported(
    num_heads: int,
    num_gqa_groups: int,
    q_seqlen: int,
    kv_seqlen: int,
    head_size: int,
    dtype: Union[paddle.dtype, str],
    dropout: float,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    mask_type: str = "causal",
) -> bool:
    """Check if cuDNN fused attention is supported for attention config"""
    backend = get_fused_attention_backend(
        num_heads=num_heads,
        num_gqa_groups=num_gqa_groups,
        q_seqlen=q_seqlen,
        kv_seqlen=kv_seqlen,
        head_size=head_size,
        dtype=dtype,
        dropout=dropout,
        qkv_layout=qkv_layout,
        bias_type=bias_type,
        mask_type=mask_type,
    )
    return backend != FusedAttnBackend["No_Backend"]
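

# A minimal (hypothetical) sketch of how the helper above is typically consumed in a
# test, assuming pytest is imported there:
#     if not is_fused_attention_supported(num_heads=16, num_gqa_groups=16,
#                                         q_seqlen=512, kv_seqlen=512, head_size=64,
#                                         dtype="bfloat16", dropout=0.0):
#         pytest.skip("cuDNN fused attention is not supported for this config")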