test_partition.py 1.1 KB
Newer Older
zbian's avatar
zbian committed
1
2
3
4
5
6
7
8
9
import os.path as osp

import pytest
import torch
from torch.utils.data import DataLoader

from colossalai.builder import build_dataset, ModelInitializer
from colossalai.core import global_context
from colossalai.initialize import init_dist
Frank Lee's avatar
Frank Lee committed
10
from colossalai.logging import get_dist_logger
zbian's avatar
zbian committed
11
12
13
14
15
16
17
18
19

# Absolute directory containing this test file (symlinks resolved), used to
# locate the config relative to the test rather than the working directory.
DIR_PATH = osp.split(osp.realpath(__file__))[0]
# Pipeline ResNet config consumed by init_dist() in test_partition.
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')


@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_partition():
    """Smoke-test building a model and training data from the pipeline ResNet config.

    Initializes the distributed environment from ``CONFIG_PATH``, then builds
    the model (via ``ModelInitializer``) and the training dataset and
    ``DataLoader`` described by the loaded config. The test makes no
    assertions: it passes as long as none of these steps raises.

    Skipped under a plain ``pytest`` run; per the skip reason it is meant to
    be launched through the provided ``test.sh``.
    """
    init_dist(CONFIG_PATH)
    logger = get_dist_logger()
    logger.info('finished initialization')

    try:
        # build model
        # NOTE(review): the meaning of the positional ``1`` is defined by
        # ModelInitializer, which is not visible here -- confirm.
        model = ModelInitializer(global_context.config.model, 1, verbose=True).model_initialize()
        logger.info('model is created')

        dataset = build_dataset(global_context.config.train_data.dataset)
        # The DataLoader is never iterated; constructing it only checks that
        # the config's dataloader kwargs are accepted.
        dataloader = DataLoader(dataset=dataset, **global_context.config.train_data.dataloader)
        logger.info('train data is created')
    finally:
        # Run teardown even when a build step raises, instead of skipping it
        # on the failure path.
        global_context.destroy()

    torch.cuda.synchronize()
    # NOTE(review): no training is performed above; the message is kept as-is.
    logger.info('training finished')


if __name__ == '__main__':
    # Calling the function directly does not go through pytest, so the skip
    # marker above is not applied; launch it the way the skip reason
    # describes (test.sh).
    test_partition()