import numpy as np
import deepspeed
import pytest
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from common import distributed_test
from simple_model import SimpleModel, PLD_SimpleModel, random_dataloader, args_from_dict

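# Progressive Layer Drop (PLD) anneals its theta parameter over training steps t:
#     theta(t) = (1 - theta) * exp(-gamma * t) + theta
# starting at 1.0 and decaying toward the configured floor `theta` at a rate set
# by `gamma`. The tests below check this schedule both directly on the scheduler
# and through the DeepSpeed engine.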

@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
def test_pld_schedule(theta):
    gamma = 0.001

    pld_scheduler = ProgressiveLayerDrop(theta, gamma)
    for i in range(10):
        pld_scheduler.update_state(i)
        expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
        actual_theta = pld_scheduler.get_theta()
        assert expected_theta == actual_theta
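
# For reference, with theta=0.5 and gamma=0.001 the schedule decays slowly:
# theta(0) = 1.0, theta(10) ~= 0.9950, theta(1000) ~= 0.6839.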


@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
def test_pld_model(tmpdir, theta):
    gamma = 0.001
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": 'Adam',
            "params": {
                "lr": 0.0001
            }
        },
        "fp16": {
            "enabled": True
        },
        "progressive_layer_drop": {
            "enabled": True,
            "theta": theta,
            "gamma": gamma
        }
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = PLD_SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_pld_model(args, model, hidden_dim, theta, gamma):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

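            # The engine advances the PLD schedule once per training step, so
            # the theta it reports should track the closed-form schedule.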
            expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
            actual_theta = model.get_pld_theta()
            assert expected_theta == actual_theta

    _test_pld_model(args=args,
                    model=model,
                    hidden_dim=hidden_dim,
                    theta=theta,
                    gamma=gamma)


def test_non_pld_model(tmpdir):
    gamma = 0.001
    theta = 0.5
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": 'Adam',
            "params": {
                "lr": 0.0001
            }
        },
        "fp16": {
            "enabled": True
        },
        "progressive_layer_drop": {
            "enabled": True,
            "theta": theta,
            "gamma": gamma
        }
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_non_pld_model(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=1,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for batch in data_loader:
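            # SimpleModel.forward does not accept the extra keyword arguments
            # the engine injects when progressive_layer_drop is enabled
            # (PLD_SimpleModel does), so the forward call is expected to raise
            # TypeError.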
            with pytest.raises(TypeError):
                model(batch[0], batch[1])

    _test_non_pld_model(args=args, model=model, hidden_dim=hidden_dim)