#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from copy import deepcopy
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG


@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_zero_state_dict(shard_strategy_class):
    """Verify that a ShardedModelV2's state_dict equals that of an unsharded fp16 copy.

    Runs the check for each registered test model, with the shard strategy
    supplied by the ``@parameterize`` decorator.
    """
    strategy = shard_strategy_class()
    for model_name in ('repeated_computed_layers', 'resnet18'):
        components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, test_dataloader, optimizer, criterion = components_func()

        # Build the model inside the ZeRO init context so its parameters are
        # sharded at creation time.
        init_ctx = ZeroInitContext(convert_fp16=True,
                                   target_device=torch.cuda.current_device(),
                                   shard_strategy=strategy,
                                   shard_param=True,
                                   rm_torch_payload_on_the_fly=False)
        with init_ctx:
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(zero_model, strategy)

        # Reference model: an unsharded fp16 copy holding the same weights.
        model = model_builder(checkpoint=True).half()
        col_model_deepcopy(zero_model, model)
        model = model.cuda()

        # Every tensor gathered back by the sharded state_dict must match the
        # reference model's state_dict exactly.
        zero_state_dict = zero_model.state_dict()
        for key, val in model.state_dict().items():
            assert torch.equal(val, zero_state_dict[key])


def run_dist(rank, world_size, port):
    """Per-process entry point: set up the distributed environment, then run the check."""
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    run_zero_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_zero_state_dict(world_size):
    """Spawn ``world_size`` worker processes and check sharded vs. plain state dicts agree."""
    worker = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    # Bug fix: `test_zero_state_dict` takes only `world_size` — the shard
    # strategy is injected by the `@parameterize` decorator on
    # `run_zero_state_dict`, so passing `TensorShardStrategy` here raised a
    # TypeError when the file was executed as a script.
    test_zero_state_dict(2)