test_gemini_plugin.py 5.86 KB
Newer Older
1
from contextlib import nullcontext
2
from typing import Optional
3
import pytest
4

5
6
7
8
9
10
import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
11
from colossalai.fx import is_compatible_with_meta
12
from colossalai.lazy.lazy_init import LazyInitContext
13
from colossalai.nn.optimizer import HybridAdam
14
15
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.shardformer.layer.utils import Randomizer
16
from colossalai.tensor.colo_parameter import ColoParameter
17
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
18
19
20
from tests.kit.model_zoo import model_zoo


21
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
    """Boost one model with ``GeminiPlugin`` and run a single training step.

    Args:
        init_method: ``"lazy"`` builds the model inside ``LazyInitContext``;
            any other value builds it inside a ``nullcontext``.
        model_fn: Zero-argument callable returning the model.
        data_gen_fn: Zero-argument callable returning a dict of keyword
            arguments for the model's forward call.
        output_transform_fn: Callable turning the raw model output into a
            dict; the value under its first key is averaged to form the loss.
        zero_size: Together with ``tp_size`` it divides the world size to
            obtain the plugin's ``extra_dp_size``.
        tp_size: Forwarded to the plugin as ``tp_size``; values above 1 also
            switch on ``enable_all_optimization``.

    Returns:
        ``None`` when the step completed, or when a ``NotImplementedError``
        was raised (reported on stdout and treated as a skip). For any other
        exception, ``repr()`` of that exception.
    """
    # Bound up front so the NotImplementedError handler below can always
    # reference it, even when the error fires before model_fn() has returned.
    model = None
    try:
        if init_method == "lazy":
            ctx = LazyInitContext()
        else:
            ctx = nullcontext()

        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
        enable_all_optimization = tp_size > 1
        plugin = GeminiPlugin(
            max_norm=1.0,
            initial_scale=2**5,
            tp_size=tp_size,
            extra_dp_size=extra_dp_size,
            enable_all_optimization=enable_all_optimization,
        )
        booster = Booster(plugin=plugin)

        with ctx:
            model = model_fn()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()

        # Move every tensor-like input onto the GPU; other values pass through.
        data = data_gen_fn()
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }

        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

        for n, p in model.named_parameters():
            assert isinstance(p, ColoParameter), f"{n} is not a ColoParameter"

        output = model(**data)
        output = output_transform_fn(output)
        output_key = next(iter(output))
        loss = criterion(output[output_key])

        booster.backward(loss, optimizer)
        optimizer.step()

    except NotImplementedError:
        # Treated as a skip, not a failure. ``model`` is still None here when
        # the error was raised before the model had been constructed.
        print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")
        return None
    except Exception as e:
        # Report the failure to the caller as text instead of re-raising, so
        # the caller decides whether to stop or to keep collecting errors.
        return repr(e)
    return None


# TODO(ver217): CI does not support lazy now
# @parameterize('init_method', ['lazy', 'none', 'colo'])
@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
    """check gemini plugin over model zoo

    Every model in the ``subset`` sub-registry of ``model_zoo`` that is not on
    one of the skip lists below is passed to ``run_fn``. Rank 0 prints a
    summary, and the check fails if any model reported an error.

    Args:
        subset (str): Name of the model-zoo sub-registry to iterate over.
        init_method (str, optional): Forwarded to ``run_fn``. ``"lazy"`` makes
            the function return immediately when ``is_compatible_with_meta()``
            is false and activates an extra skip list. Defaults to "none".
        early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
        zero_size (int, optional): Forwarded to ``run_fn``. Defaults to 1.
        tp_size (int, optional): Forwarded to ``run_fn``, except for models
            whose name contains ``"transformers_blip2"``, which always run
            with a value of 1. Defaults to 1.

    Raises:
        AssertionError: If at least one model failed; the message lists each
            failing model with the error text returned by ``run_fn``.
    """
    is_support_meta = is_compatible_with_meta()
    if not is_support_meta and init_method == "lazy":
        return

    # These models lead to CUDA error
    cuda_error_models = frozenset(
        (
            "diffusers_auto_encoder_kl",
            "diffusers_vq_model",
            "diffusers_unet2d_model",
            "timm_resmlp",
            "timm_gmixer_12_224",
            "timm_gmlp_b16_224",
            "timm_mixer_b16_224",
            "timm_convnext",
            "torchvision_convnext_base",
        )
    )
    # These models are not compatible with gemini
    gemini_incompatible_models = frozenset(
        (
            "timm_convit",
            "timm_dm_nfnet",
            "torchvision_vit_b_16",
            "transformers_t5",
            "transformers_t5_for_conditional_generation",
            "transformers_t5_encoder_model",  # does not support apex rmsnorm
            "transformers_chatglm",
            "transformers_sam",
            "transformers_vit",
            "transformers_gpt_double_heads",  # TODO check why does the model fail to run using Gemini
        )
    )
    # These models are skipped only when init_method is "lazy"
    lazy_init_skipped_models = frozenset(
        (
            "timm_convmixer",
            "timm_vision_transformer",
            "timm_deit",
            "timm_deit3",
            "timm_inception_v3",
            "timm_tnt_b_patch16_224",
            "timm_rexnet",
            "torchvision_densenet121",
            "torchvision_efficientnet_b0",
            "torchvision_mobilenet_v2",
            "torchvision_mnasnet0_5",
            "torchvision_regnet_x_16gf",
            "torchvision_shufflenet_v2_x0_5",
            "torchvision_efficientnet_v2_s",
        )
    )

    passed_models = []
    failed_info = {}  # (model_name, error) pair

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
        if name in cuda_error_models or name in gemini_incompatible_models:
            continue
        if init_method == "lazy" and name in lazy_init_skipped_models:
            continue

        # TODO debug blip2 when using tp, something wrong with shift_logits's shape
        # The override is kept in a per-model variable: rebinding ``tp_size``
        # itself would silently force tp_size=1 on every model visited later.
        model_tp_size = 1 if "transformers_blip2" in name else tp_size

        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, model_tp_size)
        torch.cuda.empty_cache()

        if err is None:
            passed_models.append(name)
        else:
            failed_info[name] = err
            if early_stop:
                break

    if dist.get_rank() == 0:
        print(f"Init method: {init_method}")
        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
147
148
149
150


def run_dist(rank, world_size, port, early_stop: bool = True):
    """Per-process entry point: launch ColossalAI, then run the plugin check.

    Args:
        rank: Forwarded to ``colossalai.launch`` as ``rank``.
        world_size: Forwarded to ``colossalai.launch`` as ``world_size``.
        port: Forwarded to ``colossalai.launch`` as ``port``; the host is
            always ``"localhost"``.
        early_stop (bool, optional): Forwarded to ``check_gemini_plugin``.
            Defaults to True.
    """
    # init dist env
    launch_kwargs = {
        "config": {},
        "rank": rank,
        "world_size": world_size,
        "port": port,
        "host": "localhost",
    }
    colossalai.launch(**launch_kwargs)

    check_gemini_plugin(early_stop=early_stop)


@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    """Run ``run_dist`` through ``spawn`` with a size argument of 4.

    Args:
        early_stop (bool, optional): Forwarded to ``run_dist``. Defaults to True.
    """
    # Presumably the number of worker processes handed to spawn -- confirm
    # against colossalai.testing.spawn.
    num_procs = 4
    spawn(run_dist, num_procs, early_stop=early_stop)
158

159
160
161
162
163
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
    """Run ``run_dist`` through ``spawn`` with a size argument of 8.

    Marked ``largedist`` for pytest.

    Args:
        early_stop (bool, optional): Forwarded to ``run_dist``. Defaults to True.
    """
    # Presumably the number of worker processes handed to spawn -- confirm
    # against colossalai.testing.spawn.
    num_procs = 8
    spawn(run_dist, num_procs, early_stop=early_stop)

164

165
if __name__ == "__main__":
    # Direct invocation: keep going after a failure so every failing model
    # is collected, rather than stopping at the first one.
    test_gemini_plugin(early_stop=False)