# test_gemini_plugin.py
1
from contextlib import nullcontext
2
from typing import Optional
3

4
5
6
7
8
9
import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
10
from colossalai.fx import is_compatible_with_meta
11
from colossalai.lazy.lazy_init import LazyInitContext
12
13
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
14
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
15
16
17
from tests.kit.model_zoo import model_zoo


18
19
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
    """Boost one model-zoo model with ``GeminiPlugin`` and run a single training step.

    Args:
        init_method: ``"lazy"`` builds the model inside ``LazyInitContext``; any other
            value builds it under a no-op ``nullcontext``.
        model_fn: Zero-argument callable that returns the model.
        data_gen_fn: Zero-argument callable returning a mapping of keyword inputs
            for the model's forward call.
        output_transform_fn: Callable mapping the raw model output to a mapping;
            the loss is the mean of its first entry.

    Returns:
        ``None`` when the step completes, otherwise ``repr`` of the exception raised.
    """
    try:
        ctx = LazyInitContext() if init_method == "lazy" else nullcontext()
        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
        booster = Booster(plugin=plugin)
        with ctx:
            model = model_fn()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)

        def criterion(x):
            return x.mean()

        data = data_gen_fn()
        # Move tensors (and tensor-like objects whose class name contains "Tensor")
        # to the GPU; leave every other value untouched.
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }

        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

        # After boosting, Gemini is expected to have converted every parameter.
        for n, p in model.named_parameters():
            assert isinstance(p, ColoParameter), f"{n} is not a ColoParameter"

        output = model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])

        booster.backward(loss, optimizer)
        optimizer.step()
    except Exception as e:
        # Report the failure instead of raising so the caller can keep sweeping
        # the model zoo; re-raise here locally when a full traceback is needed.
        return repr(e)
    return None


# TODO(ver217): CI does not support lazy now
# @parameterize('init_method', ['lazy', 'none', 'colo'])
@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True):
    """Check the gemini plugin over one sub-registry of the model zoo.

    Every model that is not in a skip list is run through ``run_fn``. The names of
    passing models and the errors of failing ones are collected. Rank 0 prints a
    summary, and the function then asserts that no model failed.

    Args:
        subset (str): Name of the model-zoo sub-registry to iterate over.
        init_method (str, optional): Initialisation method forwarded to ``run_fn``.
            ``"lazy"`` returns immediately when meta tensors are not supported.
            Defaults to "none".
        early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
    """
    is_support_meta = is_compatible_with_meta()
    if not is_support_meta and init_method == "lazy":
        return

    # These models lead to CUDA error
    cuda_error_models = frozenset(
        {
            "diffusers_auto_encoder_kl",
            "diffusers_vq_model",
            "diffusers_unet2d_model",
            "timm_resmlp",
            "timm_gmixer_12_224",
            "timm_gmlp_b16_224",
            "timm_mixer_b16_224",
            "timm_convnext",
            "torchvision_convnext_base",
        }
    )
    # These models are not compatible with gemini
    gemini_incompatible_models = frozenset(
        {
            "timm_convit",
            "timm_dm_nfnet",
            "torchvision_vit_b_16",
            "transformers_t5",
            "transformers_t5_for_conditional_generation",
            "transformers_t5_encoder_model",  # does not support apex rmsnorm
            "transformers_chatglm",
            "transformers_sam",
            "transformers_vit",
            "transformers_gpt_double_heads",  # TODO check why does the model fail to run using Gemini
        }
    )
    # These models are additionally skipped when init_method is "lazy"
    lazy_incompatible_models = frozenset(
        {
            "timm_convmixer",
            "timm_vision_transformer",
            "timm_deit",
            "timm_deit3",
            "timm_inception_v3",
            "timm_tnt_b_patch16_224",
            "timm_rexnet",
            "torchvision_densenet121",
            "torchvision_efficientnet_b0",
            "torchvision_mobilenet_v2",
            "torchvision_mnasnet0_5",
            "torchvision_regnet_x_16gf",
            "torchvision_shufflenet_v2_x0_5",
            "torchvision_efficientnet_v2_s",
        }
    )
    # Build the skip set once rather than re-creating the lists on every iteration.
    skipped_models = cuda_error_models | gemini_incompatible_models
    if init_method == "lazy":
        skipped_models = skipped_models | lazy_incompatible_models

    passed_models = []
    failed_info = {}  # model_name -> error repr

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
        if name in skipped_models:
            continue
        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()
        if err is None:
            passed_models.append(name)
        else:
            failed_info[name] = err
            if early_stop:
                break

    if dist.get_rank() == 0:
        print(f"Init method: {init_method}")
        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
    assert len(failed_info) == 0, "\n".join(f"{k}: {v}" for k, v in failed_info.items())
133
134
135
136


def run_dist(rank, world_size, port, early_stop: bool = True):
    """Per-worker entry point: launch the distributed environment, then run the check.

    Args:
        rank: Rank of this worker, forwarded to ``colossalai.launch``.
        world_size: Total number of workers, forwarded to ``colossalai.launch``.
        port: Port on ``localhost`` used by ``colossalai.launch``.
        early_stop (bool, optional): Forwarded to ``check_gemini_plugin``. Defaults to True.
    """
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
    check_gemini_plugin(early_stop=early_stop)


@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    """Run ``run_dist`` through ``spawn`` with 4 workers.

    Args:
        early_stop (bool, optional): Forwarded to each worker's ``run_dist``. Defaults to True.
    """
    spawn(run_dist, 4, early_stop=early_stop)
144
145


146
if __name__ == "__main__":
    # When run as a script, sweep the whole model zoo instead of stopping at the first error.
    test_gemini_plugin(early_stop=False)