# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TransformerLayer encoder main_grad"""

import numpy as np
import pytest

import paddle
from paddle.distributed.fleet.utils import mix_precision_utils

import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available

is_fp8_supported, reason = is_fp8_available()


def create_optimizer(model, use_pure_bf16, use_main_grad):
    """Create an AdamW optimizer for *model*.

    Args:
        model: ``paddle.nn.Layer`` whose parameters are optimized.
        use_pure_bf16: passed to AdamW as ``multi_precision`` so FP32
            master weights are kept alongside bf16 parameters.
        use_main_grad: wrap the model and optimizer with Paddle's
            mix-precision utilities so gradients are accumulated into
            FP32 ``main_grad`` buffers. Requires ``use_pure_bf16``.

    Returns:
        The configured (and possibly wrapped) optimizer.
    """
    if use_main_grad:
        # main_grad accumulation only makes sense with pure-bf16 training
        assert use_pure_bf16
        model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
    optimizer = paddle.optimizer.AdamW(
        parameters=model.parameters(),
        learning_rate=0.0001,
        multi_precision=use_pure_bf16,
    )
    if use_main_grad:
        optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)

    return optimizer


class Net(paddle.nn.Layer):
    """Network used for main_grad testing.

    A single Transformer encoder layer (hidden=4096, FFN=16384, 32 heads)
    with optional fused wgrad accumulation.
    """

    def __init__(self, fuse_wgrad_accumulation):
        super().__init__()
        self.layer = te.TransformerLayer(
            4096,
            16384,
            32,
            layer_type="encoder",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        )

    def forward(self, inp):
        """Run the encoder layer on *inp* and return its output."""
        out = self.layer(inp)
        return out


def train(enable_master_grad, fuse_wgrad_accumulation=False):
    """Run a short fp8 training loop and return the per-step losses.

    Args:
        enable_master_grad: accumulate gradients into FP32 ``main_grad``
            buffers via Paddle's mix-precision utilities.
        fuse_wgrad_accumulation: fuse weight-gradient accumulation inside
            the TransformerLayer; requires ``enable_master_grad``.

    Returns:
        List of scalar loss tensors, one per training step.
    """
    # Fixed seed so runs with different main_grad settings are comparable.
    paddle.seed(10)

    accumulate_steps = 4

    if fuse_wgrad_accumulation:
        assert enable_master_grad, "fuse_wgrad_accumulation requires enable_master_grad"

    model = Net(fuse_wgrad_accumulation)

    optimizer = create_optimizer(model, use_pure_bf16=True, use_main_grad=enable_master_grad)

    loss_list = []
    for step_id in range(16):
        inp = paddle.uniform([2, 1024, 4096], dtype="float32")
        inp.stop_gradient = False
        with te.fp8_autocast(enabled=True):
            out = model(inp)
        loss = out.mean()
        loss_list.append(loss)
        loss.backward()

        # gradient accumulation: only step/clear every accumulate_steps
        if (step_id + 1) % accumulate_steps == 0:
            optimizer.step()
            optimizer.clear_grad()

    return loss_list


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
def test_master_grad():
    """Test main_grad.

    Losses from training with and without main_grad (and with fused
    wgrad accumulation) must agree to within a small tolerance.
    """
    paddle.set_default_dtype("float32")
    loss1 = train(enable_master_grad=False)
    loss2 = train(enable_master_grad=True)
    loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True)

    np.testing.assert_allclose(loss1, loss2, rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(loss1, loss3, rtol=1e-5, atol=1e-5)