".github/workflows/auto_release_bdist.yml" did not exist on "7c2634f4b3c8ae5eeb1b89ad1c8530233dee5c92"
test_gemini_plugin.py 4.71 KB
Newer Older
1
from contextlib import nullcontext
2
from typing import Optional
3

4
5
6
7
8
9
import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
10
from colossalai.fx import is_compatible_with_meta
11
from colossalai.lazy.lazy_init import LazyInitContext
12
13
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
14
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
15
16
17
from tests.kit.model_zoo import model_zoo


18
19
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
    """Build, boost and train one model for a single step under ``GeminiPlugin``.

    Args:
        init_method (str): ``'lazy'`` builds the model inside ``LazyInitContext``;
            any other value builds it eagerly (inside ``nullcontext``).
        model_fn (Callable): Zero-argument factory returning the model.
        data_gen_fn (Callable): Zero-argument factory returning a dict of model inputs,
            which are passed to the model as keyword arguments.
        output_transform_fn (Callable): Maps the raw model output to a mapping; the value
            stored under its first key is reduced to a scalar loss with ``.mean()``.

    Returns:
        Optional[str]: ``None`` if boosting, forward, backward and the optimizer step all
        completed; otherwise ``repr`` of the exception that was raised. Errors are returned
        instead of raised so the caller can collect failures across many models.
    """
    try:
        ctx = LazyInitContext() if init_method == 'lazy' else nullcontext()
        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
        booster = Booster(plugin=plugin)
        with ctx:
            model = model_fn()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)

        def criterion(x):
            return x.mean()

        data = data_gen_fn()
        # Move tensor inputs to the GPU. The class-name check also catches tensor-like objects
        # that torch.is_tensor does not recognise (presumably wrapped/lazy tensors — verify).
        data = {
            k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
        }

        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

        # After boosting, every parameter of the wrapped model must be a ColoParameter.
        for n, p in model.named_parameters():
            assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter'

        output = output_transform_fn(model(**data))
        output_key = next(iter(output.keys()))
        loss = criterion(output[output_key])

        booster.backward(loss, optimizer)
        optimizer.step()
    except Exception as e:
        # Deliberately best-effort: report the failure to the caller rather than raising.
        return repr(e)
    return None


# TODO(ver217): CI does not support lazy init yet; once it does, switch to
# @parameterize('init_method', ['lazy', 'none', 'colo'])
@parameterize('subset', ['torchvision', 'transformers', 'diffusers'])
@parameterize('init_method', ['none'])
def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool = True):
    """Check the Gemini plugin over every model of one model-zoo sub-registry.

    Each non-skipped model is run through ``run_fn``. Rank 0 prints a pass/fail summary.

    Args:
        subset (str): Name of the model-zoo sub-registry to iterate.
        init_method (str, optional): Model construction mode forwarded to ``run_fn``.
            When it is ``'lazy'`` and meta tensors are unsupported, the check returns
            immediately. Defaults to 'none'.
        early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.

    Raises:
        AssertionError: If any model failed; the message lists each failing model with its error.
    """
    # These models lead to CUDA error
    cuda_error_models = frozenset({
        'diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
        'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
        'torchvision_convnext_base',
    })
    # These models are not compatible with gemini
    gemini_incompatible_models = frozenset({
        'timm_convit',
        'timm_dm_nfnet',
        'torchvision_vit_b_16',
        'transformers_t5',
        'transformers_t5_for_conditional_generation',
        'transformers_t5_encoder_model',    # does not support apex rmsnorm
        'transformers_chatglm',
        'transformers_sam',
        'transformers_vit',
    })
    # Additional models skipped only when init_method == 'lazy'
    lazy_skipped_models = frozenset({
        'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3',
        'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0',
        'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf',
        'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s',
    })

    is_support_meta = is_compatible_with_meta()
    if not is_support_meta and init_method == 'lazy':
        return

    skipped_models = cuda_error_models | gemini_incompatible_models
    if init_method == 'lazy':
        skipped_models |= lazy_skipped_models

    passed_models = []
    failed_info = {}    # model_name -> repr of the error returned by run_fn

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
        if name in skipped_models:
            continue
        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
        # Release cached GPU memory between models.
        torch.cuda.empty_cache()
        if err is None:
            passed_models.append(name)
        else:
            failed_info[name] = err
            if early_stop:
                break

    if dist.get_rank() == 0:
        print(f'Init method: {init_method}')
        print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
        print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
    assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])


def run_dist(rank, world_size, port, early_stop: bool = True):
    """Per-rank worker invoked through ``spawn``.

    Sets up the ColossalAI distributed environment for ``rank`` on localhost,
    then runs the Gemini plugin model-zoo check on that rank.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
    check_gemini_plugin(early_stop=early_stop)


@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    """Pytest entry point: run ``run_dist`` via ``spawn`` with a process count of 4.

    Args:
        early_stop (bool, optional): Forwarded to ``check_gemini_plugin``; whether to stop
            at the first failing model. Defaults to True.
    """
    spawn(run_dist, 4, early_stop=early_stop)


if __name__ == '__main__':
    # Direct invocation: do not stop at the first failing model, so the summary covers every model.
    test_gemini_plugin(early_stop=False)