from contextlib import nullcontext
from typing import Optional

import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

from tests.kit.model_zoo import model_zoo


def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
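    """Boost one model with the Gemini plugin and run a single forward/backward/step.

    Returns None on success (or when the model's TP policy is unimplemented), otherwise
    the ``repr`` of the raised exception.
    """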
    try:
        if init_method == "lazy":
            ctx = LazyInitContext()
        else:
            ctx = nullcontext()
        # enable shardformer's extra optimizations only when tensor parallelism is on
        enable_all_optimization = enable_tensor_parallelism
        plugin = GeminiPlugin(
            max_norm=1.0,
            initial_scale=2**5,
            enable_tensor_parallelism=enable_tensor_parallelism,
            enable_all_optimization=enable_all_optimization,
        )
        booster = Booster(plugin=plugin)
        with ctx:
            model = model_fn()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        data = data_gen_fn()

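        # move inputs to GPU; the class-name check is a duck-typing fallback for
        # tensor-like values that are not plain torch.Tensor instances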
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v
            for k, v in data.items()
        }

        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
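        # boosting wraps the model for Gemini (ZeRO DDP), so every parameter
        # should now have been sharded into a ColoParameter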

        for n, p in model.named_parameters():
            assert isinstance(p, ColoParameter), f"{n} is not a ColoParameter"

        output = model(**data)
        output = output_transform_fn(output)
        output_key = next(iter(output))
        loss = criterion(output[output_key])

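        # backward goes through the booster so the plugin can apply its loss scaling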
        booster.backward(loss, optimizer)
        optimizer.step()

    except NotImplementedError:
        # a missing TP policy is tolerated: the model still counts as passed
        print(f"Tensor parallelism policy for {model.__class__} is not implemented yet.\n")
    except Exception as e:
        # raise e  # uncomment to surface the full traceback when debugging locally
        return repr(e)


# TODO(ver217): CI does not support lazy now
# @parameterize('init_method', ['lazy', 'none', 'colo'])
@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("enable_tensor_parallelism", [True, False])
def check_gemini_plugin(
    subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True
):
    """check gemini plugin over model zoo

    Args:
        early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
    """
    is_support_meta = is_compatible_with_meta()
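    # lazy init relies on meta-tensor support, so skip it on incompatible torch versions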
    if not is_support_meta and init_method == "lazy":
        return

    passed_models = []
    failed_info = {}  # maps model name -> error repr

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
        # These models lead to CUDA error
        if name in (
            "diffusers_auto_encoder_kl",
            "diffusers_vq_model",
            "diffusers_unet2d_model",
            "timm_resmlp",
            "timm_gmixer_12_224",
            "timm_gmlp_b16_224",
            "timm_mixer_b16_224",
            "timm_convnext",
            "torchvision_convnext_base",
        ):
            continue
        # These models are not compatible with Gemini
        if name in [
            "timm_convit",
            "timm_dm_nfnet",
            "torchvision_vit_b_16",
            "transformers_t5",
            "transformers_t5_for_conditional_generation",
            "transformers_t5_encoder_model",  # does not support apex rmsnorm
            "transformers_chatglm",
            "transformers_sam",
            "transformers_vit",
            "transformers_gpt_double_heads",  # TODO check why does the model fail to run using Gemini
        ]:
            continue

        if init_method == "lazy" and name in [
            "timm_convmixer",
            "timm_vision_transformer",
            "timm_deit",
            "timm_deit3",
            "timm_inception_v3",
            "timm_tnt_b_patch16_224",
            "timm_rexnet",
            "torchvision_densenet121",
            "torchvision_efficientnet_b0",
            "torchvision_mobilenet_v2",
            "torchvision_mnasnet0_5",
            "torchvision_regnet_x_16gf",
            "torchvision_shufflenet_v2_x0_5",
            "torchvision_efficientnet_v2_s",
        ]:
            continue

        # TODO: debug blip2 with tp, something wrong with shift_logits's shape;
        # use a per-model flag so disabling TP for blip2 does not leak into later models
        tp_enabled = enable_tensor_parallelism and "transformers_blip2" not in name

        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, tp_enabled)
        # free cached GPU memory so the next model starts from a clean slate
        torch.cuda.empty_cache()
        if err is None:
            passed_models.append(name)
        else:
            failed_info[name] = err
            if early_stop:
                break

    if dist.get_rank() == 0:
        print(f"Init method: {init_method}")
        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


def run_dist(rank, world_size, port, early_stop: bool = True):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
    check_gemini_plugin(early_stop=early_stop)


@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
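    # launch 4 worker processes (one GPU each)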
    spawn(run_dist, 4, early_stop=early_stop)


if __name__ == "__main__":
    test_gemini_plugin(early_stop=False)