utils.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils for testing"""

import random
import numpy as np

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

import transformer_engine    # pylint: disable=unused-import

from transformer_engine.paddle.fp8 import FP8TensorMeta


def create_fp8_meta(num_gemms=1, amax_history_len=10):
    """
    Create and initialize FP8TensorMeta
    """
    fp8_meta = FP8TensorMeta(is_forward=True)
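    # prepare() allocates the scale, scale_inv and amax-history buffers for
    # `num_gemms` GEMMs (the exact buffer layout is a TE implementation detail).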
    fp8_meta.prepare(num_gemms, amax_history_len)
    return fp8_meta
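
# Usage sketch for the helper above (buffer shapes inside FP8TensorMeta are a
# Transformer Engine implementation detail, so only the call is shown):
#
#     fp8_meta = create_fp8_meta(num_gemms=2, amax_history_len=16)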


def assert_allclose(actual,
                    desired,
                    rtol=1e-05,
                    atol=1e-08,
                    equal_nan=True,
                    err_msg='',
                    verbose=True):
    """Compare two input paddle tensors"""
    if isinstance(actual, paddle.Tensor):
        actual = paddle.cast(actual, 'float32').numpy()
    if isinstance(desired, paddle.Tensor):
        desired = paddle.cast(desired, 'float32').numpy()
    np.testing.assert_allclose(actual,
                               desired,
                               rtol=rtol,
                               atol=atol,
                               equal_nan=equal_nan,
                               err_msg=err_msg,
                               verbose=verbose)


def assert_shape(inp, expected_shape):
    """Assert the shape of input tensor equals to expected shape"""
    assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \
        f"actual tensor shape: {inp.shape}"


def is_devices_enough(required):
    """If the number of device is enough"""
    return paddle.device.cuda.device_count() >= required
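
# Typical use in tests (a sketch; assumes pytest is imported at the call site):
#
#     @pytest.mark.skipif(not is_devices_enough(2), reason="Requires at least 2 GPUs")
#     def test_model_parallel():
#         ...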


def set_random_seed(seed):
    """Set random seed for reproducability."""

    if paddle.distributed.get_world_size() > 1:
        # Obtain the ranks and sizes of the hybrid-parallel groups.
        hcg = fleet.get_hybrid_communicate_group()

        mp_rank = hcg.get_model_parallel_rank()
        mp_size = hcg.get_model_parallel_world_size()

        pp_rank = hcg.get_stage_id()
        pp_size = hcg.get_pipe_parallel_world_size()

        dp_rank = hcg.get_data_parallel_rank()
        dp_size = hcg.get_data_parallel_world_size()

        sharding_rank = hcg.get_sharding_parallel_rank()
    else:
        mp_rank, mp_size = 0, 1
        pp_rank, pp_size = 0, 1
        dp_rank, dp_size = 0, 1
        sharding_rank = 0

    random.seed(seed + 100 * pp_rank)
    np.random.seed(seed + 100 * pp_rank)

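    # `global_seed` is shared across all model-parallel ranks (identical weight
    # init), while `local_seed` also mixes in mp_rank so that per-rank
    # randomness (e.g. dropout in tensor-parallel regions) differs.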
    seed_offset = seed + 1024 + paddle.distributed.get_world_size()
    global_seed = (seed_offset + pp_rank * mp_size + dp_rank * mp_size * pp_size +
                   sharding_rank * mp_size * pp_size * dp_size)

    seed_offset += paddle.distributed.get_world_size()
    local_seed = (seed_offset + mp_rank + pp_rank * mp_size + dp_rank * mp_size * pp_size +
                  sharding_rank * mp_size * pp_size * dp_size)

    tracker = get_rng_state_tracker()
    # tracker.reset()
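    # RNGStatesTracker.add() raises on duplicate names, so guard each registration.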
    if "global_seed" not in tracker.states_:
        tracker.add("global_seed", global_seed)
    if "local_seed" not in tracker.states_:
        tracker.add("local_seed", local_seed)

    paddle.seed(global_seed)
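
# The tracked states are typically consumed via the tracker's context manager
# (a sketch; tensor names and shapes here are illustrative only):
#
#     with get_rng_state_tracker().rng_state("local_seed"):
#         mask = paddle.rand([batch_size, hidden_size])  # differs per mp rank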