# test_gemini_plugin.py (6.37 KB) -- header residue from the repository web view
1
from contextlib import nullcontext
2
from typing import Optional
3

4
import pytest
5
6
7
8
9
10
import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
11
from colossalai.fx import is_compatible_with_meta
12
from colossalai.lazy.lazy_init import LazyInitContext
13
14
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
15
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# (web-view residue: commit attribution -- Frank Lee)
16
from tests.kit.model_zoo import model_zoo, COMMON_MODELS, IS_FAST_TEST
17
18


19
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
    """Run one boosted forward/backward/step pass of a model under GeminiPlugin.

    Args:
        init_method: When equal to ``"lazy"`` the model is built inside a
            ``LazyInitContext``; any other value builds it eagerly.
        model_fn: Zero-argument callable returning the model.
        data_gen_fn: Zero-argument callable returning a dict of model inputs.
        output_transform_fn: Callable mapping the raw model output to a dict;
            the value under its first key is reduced with ``mean()`` as the loss.
        zero_size: Used with ``tp_size`` to derive ``extra_dp_size`` from the world size.
        tp_size: Tensor-parallel size handed to the plugin; values above 1
            also switch on ``enable_all_optimization``.

    Returns:
        ``None`` when the pass succeeds or when ``NotImplementedError`` is
        raised (only a notice is printed in that case); otherwise ``repr`` of
        the exception that was caught.
    """
    # Pre-bind so the NotImplementedError handler below can always format its
    # message, even when the error is raised before the model is constructed.
    model = None
    try:
        ctx = LazyInitContext() if init_method == "lazy" else nullcontext()
        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
        enable_all_optimization = tp_size > 1
        plugin = GeminiPlugin(
            max_norm=1.0,
            initial_scale=2**5,
            tp_size=tp_size,
            extra_dp_size=extra_dp_size,
            enable_all_optimization=enable_all_optimization,
        )
        booster = Booster(plugin=plugin)
        with ctx:
            model = model_fn()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        data = data_gen_fn()

        # Move tensors (and tensor-like objects, matched by class name) to the GPU.
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }

        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

        for n, p in model.named_parameters():
            assert isinstance(p, ColoParameter), f"{n} is not a ColoParameter"

        output = model(**data)
        output = output_transform_fn(output)
        output_key = next(iter(output.keys()))
        loss = criterion(output[output_key])

        booster.backward(loss, optimizer)
        optimizer.step()

    except NotImplementedError:
        print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")
    except Exception as e:
        # Report the failure to the caller instead of raising, so the caller
        # can decide whether to continue with the remaining models.
        return repr(e)
    return None


# TODO(ver217): CI does not support lazy init yet; re-enable the full list below once it does.
# @parameterize('init_method', ['lazy', 'none', 'colo'])


# (web-view residue: commit attribution -- Frank Lee)
69
@parameterize("subset", [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(
    subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1
):
    """Check the Gemini plugin over the models of a model-zoo sub-registry.

    Each model that is not on a skip list is run through ``run_fn``; the
    names of passing models and the errors of failing ones are collected,
    printed on rank 0, and any failure makes the final assertion fail.

    Args:
        subset: Key passed to ``model_zoo.get_sub_registry``.
        init_method (str, optional): Forwarded to ``run_fn``. With ``"lazy"``
            the check returns immediately when meta is not supported, and an
            extra skip list applies. Defaults to "none".
        early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
        zero_size (int, optional): Forwarded to ``run_fn``. Defaults to 1.
        tp_size (int, optional): Forwarded to ``run_fn``, except for blip2
            models which always run with a tensor-parallel size of 1. Defaults to 1.

    Raises:
        AssertionError: If at least one model failed; the message lists every
            ``model_name: error`` pair.
    """
    is_support_meta = is_compatible_with_meta()
    if not is_support_meta and init_method == "lazy":
        return

    # The skip lists do not depend on the loop variable, so build them once.

    # These models lead to CUDA error
    cuda_error_models = frozenset(
        (
            "diffusers_auto_encoder_kl",
            "diffusers_vq_model",
            "diffusers_unet2d_model",
            "timm_resmlp",
            "timm_gmixer_12_224",
            "timm_gmlp_b16_224",
            "timm_mixer_b16_224",
            "timm_convnext",
            "torchvision_convnext_base",
        )
    )
    # These models are not compatible with gemini
    gemini_incompatible_models = frozenset(
        (
            "timm_convit",
            "timm_dm_nfnet",
            "torchvision_vit_b_16",
            "transformers_t5",
            "transformers_t5_for_conditional_generation",
            "transformers_t5_encoder_model",  # does not support apex rmsnorm
            "transformers_chatglm",
            "transformers_sam",
            "transformers_vit",
            "transformers_gpt_double_heads",  # TODO check why does the model fail to run using Gemini
            "transformers_falcon",  # TODO check why falcon fails to run Gemini
            "transformers_falcon_for_causal_lm",
            "transformers_falcon_for_sequence_classification",
            "transformers_falcon_for_token_classification",
            "transformers_falcon_for_question_answering",
            "transformers_gptj_lm",  # lead to OOM when running in ci
            "transformers_gptj_for_question_answering",
            "transformers_gptj_for_sequence_classification",
        )
    )
    # These models are skipped only when init_method is "lazy"
    lazy_skipped_models = frozenset(
        (
            "timm_convmixer",
            "timm_vision_transformer",
            "timm_deit",
            "timm_deit3",
            "timm_inception_v3",
            "timm_tnt_b_patch16_224",
            "timm_rexnet",
            "torchvision_densenet121",
            "torchvision_efficientnet_b0",
            "torchvision_mobilenet_v2",
            "torchvision_mnasnet0_5",
            "torchvision_regnet_x_16gf",
            "torchvision_shufflenet_v2_x0_5",
            "torchvision_efficientnet_v2_s",
        )
    )

    passed_models = []
    failed_info = {}  # (model_name, error) pair

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
        if name in cuda_error_models or name in gemini_incompatible_models:
            continue
        if init_method == "lazy" and name in lazy_skipped_models:
            continue

        # TODO debug blip2 when using tp, something wrong with shift_logits's shape
        # Use a per-model value rather than rebinding ``tp_size``: rebinding
        # would silently disable tensor parallelism for every model iterated
        # after the first blip2 entry.
        model_tp_size = 1 if "transformers_blip2" in name else tp_size

        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, model_tp_size)
        torch.cuda.empty_cache()
        if err is None:
            passed_models.append(name)
        else:
            failed_info[name] = err
            if early_stop:
                break

    if dist.get_rank() == 0:
        print(f"Init method: {init_method}")
        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
161
162
163
164


def run_dist(rank, world_size, port, early_stop: bool = True):
    """Launch the distributed environment for one process and run the plugin check.

    Args:
        rank: Rank passed to ``colossalai.launch``.
        world_size: World size passed to ``colossalai.launch``.
        port: Port passed to ``colossalai.launch`` (host is ``"localhost"``).
        early_stop: Forwarded to ``check_gemini_plugin``.
    """
    # init dist env
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_gemini_plugin(early_stop=early_stop)


@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    """Spawn ``run_dist`` with 4 as the process-count argument.

    Args:
        early_stop: Forwarded to ``run_dist``.
    """
    spawn(run_dist, 4, early_stop=early_stop)
172

173

174
175
176
177
178
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
    """Spawn ``run_dist`` with 8 as the process-count argument (marked ``largedist``).

    Args:
        early_stop: Forwarded to ``run_dist``.
    """
    spawn(run_dist, 8, early_stop=early_stop)

179

180
if __name__ == "__main__":
    # When run as a script, keep going after the first failing model so the
    # final report lists every failure.
    test_gemini_plugin(early_stop=False)