#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os.path as osp

import pytest

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import initialize
from colossalai.logging import get_dist_logger

# NOTE(review): these size constants are not referenced in the code visible in
# this file — presumably they are consumed by the config file loaded below or
# kept for parity with sibling tests; confirm before removing.
NUM_BATCH = 128

BATCH_SIZE = 32
SEQ_LENGTH = 128
HIDDEN_SIZE = 512

# Resolve the pipeline test config relative to this file so the test can be
# launched from any working directory.
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')


@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_schedule():
    """Run one pipeline-parallel forward/backward step and an optimizer step.

    Initializes the distributed environment from ``CONFIG_PATH``, drives the
    engine's schedule through a single training iteration, logs the loss on
    the last pipeline rank, and tears the process groups down.

    This is a distributed test: it must be launched via the provided
    ``test.sh`` (hence the skip marker), not by pytest directly.
    """
    engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
    logger = get_dist_logger()

    model = engine.model
    optimizer = engine.optimizer
    criterion = engine.criterion
    # NOTE(review): reaches into a private attribute of the engine — confirm
    # whether a public accessor for the schedule exists.
    schedule = engine._schedule

    # One full training step: forward_only=False requests the backward pass
    # as well, so gradients are populated before optimizer_step below.
    output, label, loss = schedule.forward_backward_step(
        data_iter=iter(train_dataloader),
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        forward_only=False
    )
    schedule.optimizer_step(model, optimizer)

    # In pipeline parallelism the loss is only meaningful on the last stage.
    if gpc.is_last_rank(ParallelMode.PIPELINE):
        logger.info(f'losses: {loss}')

    gpc.destroy()
    logger.info('training finished')


# Allow direct invocation (e.g. from test.sh via torch launch) in addition to
# pytest collection; the skip marker only applies under pytest.
if __name__ == '__main__':
    test_schedule()