test_gemini_plugin.py 5.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.kit.model_zoo import model_zoo


def check_gemini_plugin(early_stop: bool = True):
    """check gemini plugin over model zoo

    Args:
        early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
    """
    plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
    booster = Booster(plugin=plugin)

    passed_models = []
    failed_info = {}    # (model_name, error) pair

    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
        # These models lead to CUDA error
        if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
                    'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
            continue
        # These models are not compatible with gemini
        if name in [
                'diffusers_clip_vision_model',
                'timm_resnet',
                'timm_beit',
                'timm_beitv2',
                'timm_eca_nfnet',
                'timm_efficientformer',
                'timm_hrnet_w18_small',
                'timm_nf_ecaresnet101',
                'timm_nf_regnet_b0',
                'timm_skresnet18',
                'timm_wide_resnet50_2',
                'timm_convit',
                'timm_dm_nfnet',
                'timm_swin_transformer',
                'torchaudio_conformer',
                'torchaudio_deepspeech',
                'torchaudio_wavernn',
                'torchaudio_tacotron',
                'deepfm_interactionarch',
                'deepfm_simpledeepfmnn',
                'dlrm',
                'dlrm_interactionarch',
                'torchvision_googlenet',
                'torchvision_inception_v3',
                'torchvision_mobilenet_v3_small',
                'torchvision_resnet18',
                'torchvision_resnext50_32x4d',
                'torchvision_wide_resnet50_2',
                'torchvision_vit_b_16',
                'torchvision_convnext_base',
                'torchvision_swin_s',
                'transformers_albert',
                'transformers_albert_for_pretraining',
                'transformers_bert',
                'transformers_bert_for_pretraining',
                'transformers_gpt_double_heads',
                'torchaudio_hubert_base',
        ]:
            continue
        try:
            model = model_fn()
            optimizer = HybridAdam(model.parameters(), lr=1e-3)
            criterion = lambda x: x.mean()
            data = data_gen_fn()

            data = {
                k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
                for k, v in data.items()
            }

            model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

            for n, p in model.named_parameters():
                assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter'

            output = model(**data)
            output = output_transform_fn(output)
            output_key = list(output.keys())[0]
            loss = criterion(output[output_key])

            booster.backward(loss, optimizer)
            optimizer.step()
            passed_models.append(name)
        except Exception as e:
            failed_info[name] = e
            if early_stop:
                raise e
    if dist.get_rank() == 0:
        print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
        print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
    assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])


def check_dataloader_sharding():
    """Check that GeminiPlugin's train dataloader gives ranks 0 and 1 different data.

    Every rank takes the first batch from a dataloader built over the integers 0..9.
    Rank 1 broadcasts its batch, and rank 0 asserts that the batch differs from its own.
    Broadcasting from ``src=1`` means at least two ranks must be running.
    """
    plugin = GeminiPlugin()

    # tiny dataset holding the integers 0..9
    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
    loader = plugin.prepare_train_dataloader(dataset, batch_size=2)

    # first batch this rank sees, placed on the current CUDA device
    first_batch = next(iter(loader))[0].cuda()
    rank = dist.get_rank()

    # Rank 0 receives into a copy so its own batch survives the broadcast;
    # every other rank passes its batch in place (rank 1 is the sender).
    received = first_batch.clone() if rank == 0 else first_batch
    dist.broadcast(received, src=1)

    # only rank 0 performs the comparison
    if rank != 0:
        return
    assert not torch.equal(first_batch,
                           received), 'Same number was found across ranks but expected it to be different'


def run_dist(rank, world_size, port, early_stop: bool = True):
    """Worker body executed in each process spawned by ``test_gemini_plugin``.

    Launches ColossalAI on localhost with an empty config, then runs the
    dataloader-sharding check followed by the model-zoo gemini check.

    Args:
        rank: this process's rank (supplied by ``mp.spawn``).
        world_size: total number of processes.
        port: port used for the distributed rendezvous on localhost.
        early_stop (bool, optional): forwarded to ``check_gemini_plugin``. Defaults to True.
    """
    colossalai.launch(config={}, host='localhost', port=port, rank=rank, world_size=world_size)

    check_dataloader_sharding()
    check_gemini_plugin(early_stop=early_stop)


@pytest.mark.skip(reason='Skip gemini plugin test due to OOM')
@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    """Spawn two processes that each run ``run_dist`` against a freshly picked free port.

    Args:
        early_stop (bool, optional): forwarded to every worker. Defaults to True.
    """
    nprocs = 2
    worker = partial(run_dist, world_size=nprocs, port=free_port(), early_stop=early_stop)
    mp.spawn(worker, nprocs=nprocs)


if __name__ == '__main__':
    # Script entry point: run the test directly with early_stop=False so that every
    # failing model is collected and reported instead of stopping at the first error.
    test_gemini_plugin(early_stop=False)