test_pipe_schedule.py 5.73 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
'''Copyright The Microsoft DeepSpeed Team'''

import pytest
import deepspeed.runtime.pipe.schedule as schedule


def _count_type(cmds, classtype):
    return len(list(filter(lambda c: type(c) == classtype, cmds)))


def test_pipe_inference_schedule_singlestage():
    sched = schedule.InferenceSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4
    full = list(iter(sched))
    for idx, cmds in enumerate(full):
        assert len(cmds) == 2
        assert type(cmds[0]) == schedule.LoadMicroBatch
        assert type(cmds[1]) == schedule.ForwardPass
        assert cmds[0].buffer_id == cmds[1].buffer_id
    assert len(full) == sched.num_micro_batches


def test_pipe_train_schedule_singlestage():
    sched = schedule.TrainSchedule(micro_batches=4, stages=1, stage_id=0)
    assert sched.num_micro_batches == 4
    full = list(iter(sched))
    for idx, cmds in enumerate(full):
        if (idx % 2) != 0:
            assert (len(cmds) == 1) or (len(cmds) == 4)
            assert type(cmds[0]) == schedule.BackwardPass
        else:
            assert len(cmds) == 2
            assert type(cmds[0]) == schedule.LoadMicroBatch
            assert type(cmds[1]) == schedule.ForwardPass
            assert cmds[0].buffer_id == cmds[1].buffer_id
    assert len(full) == sched.num_micro_batches * 2


@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_firststage(micro_batches, stages=3):
    sched = schedule.InferenceSchedule(micro_batches=micro_batches,
                                       stages=stages,
                                       stage_id=0)
    assert sched.num_micro_batches == micro_batches
    full = list(iter(sched))
    for idx, cmds in enumerate(full):
        # Ensure we don't send an activation the first step
        if idx == 0:
            assert len(cmds) == 2
            assert type(cmds[0]) == schedule.LoadMicroBatch
            assert type(cmds[1]) == schedule.ForwardPass
            assert cmds[0].buffer_id == cmds[1].buffer_id
            continue

        # the last active step is only a send
        if idx == sched.num_micro_batches:
            assert len(cmds) == 1
            assert type(cmds[0]) == schedule.SendActivation
            continue

        # no work later on
        if idx > sched.num_micro_batches:
            assert len(cmds) == 0
            continue

        # Normally we need to load/forward/send
        assert len(cmds) == 3
        assert _count_type(cmds, schedule.LoadMicroBatch) == 1
        assert _count_type(cmds, schedule.ForwardPass) == 1
        assert _count_type(cmds, schedule.SendActivation) == 1
    assert len(full) == micro_batches + stages - 1


@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_midstage(micro_batches, stages=3):
    sched = schedule.InferenceSchedule(micro_batches=micro_batches,
                                       stages=stages,
                                       stage_id=1)

    full = list(iter(sched))
    for idx, cmds in enumerate(full):
        if idx < sched.stage:
            assert len(cmds) == 0
            continue
        if idx == sched.stage + sched.num_micro_batches:
            assert len(cmds) == 1
            assert type(cmds[0]) == schedule.SendActivation
            continue
        if idx > sched.stage + sched.num_micro_batches:
            assert len(cmds) == 0
            continue
        assert _count_type(cmds, schedule.LoadMicroBatch) == 0
        assert _count_type(cmds, schedule.ForwardPass) == 1
        assert _count_type(cmds, schedule.RecvActivation) == 1
        if idx > sched.stage:
            assert _count_type(cmds, schedule.SendActivation) == 1
    assert len(full) == micro_batches + stages - 1


@pytest.mark.parametrize('micro_batches', [1, 3, 8, 10])
def test_pipe_inference_schedule_laststage(micro_batches, stages=3):
    sched = schedule.InferenceSchedule(micro_batches=micro_batches,
                                       stages=stages,
                                       stage_id=2)
    full = list(iter(sched))
    for idx, cmds in enumerate(full):
        if idx < sched.stage or idx > sched.stage + sched.num_micro_batches:
            assert len(cmds) == 0
            continue
        assert _count_type(cmds, schedule.LoadMicroBatch) == 1
        assert _count_type(cmds, schedule.ForwardPass) == 1
        assert _count_type(cmds, schedule.RecvActivation) == 1
        assert _count_type(cmds, schedule.SendActivation) == 0
    assert len(full) == micro_batches + stages - 1


def test_pipe_schedule_firststage():
    sched = schedule.TrainSchedule(micro_batches=8, stages=3, stage_id=0)
    for cmds in sched:
        assert all(instr.__class__ != schedule.SendGrad for instr in cmds)
        assert all(instr.__class__ != schedule.RecvActivation for instr in cmds)
        for instr in cmds:
            if isinstance(instr, schedule.BufferOpInstruction):
                assert 0 <= instr.buffer_id < sched.num_pipe_buffers()


def test_pipe_schedule_laststage():
    sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)
    assert len(list(iter(sched))) == 2 * (sched.micro_batches + sched.stages - 1)
    for cmds in sched:
        assert all(instr.__class__ != schedule.SendActivation for instr in cmds)
        assert all(instr.__class__ != schedule.RecvGrad for instr in cmds)


def test_pipe_stagequery():
    sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=0)
    assert sched.is_first_stage
    assert not sched.is_last_stage

    sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=1)
    assert not sched.is_first_stage
    assert not sched.is_last_stage

    sched = schedule.TrainSchedule(stages=3, micro_batches=4, stage_id=2)
    assert not sched.is_first_stage
    assert sched.is_last_stage