# test_gemini_checkpoint_io.py
import os

import pytest
import torch
import torch.distributed as dist
from transformers import LlamaForCausalLM
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
    check_state_dict_equal,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo

MODEL_PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
]

OPTIM_PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0},  # zero2-offload
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5},  # zero2-offload-half
]


@clear_cache_before_run()
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
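    # Boost a BERT model with the Gemini plugin, save a (sharded) checkpoint through
    # `booster.save_model`, and verify that plain HuggingFace `from_pretrained` can
    # reload it with an identical state dict.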
    from transformers import BertForSequenceClassification

    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    bert_model = model_fn()
    enable_all_optimization = tp_size > 1

    with shared_tempdir() as tempdir:
        pretrained_path = os.path.join(tempdir, "pretrained")
        bert_model.config.save_pretrained(save_directory=pretrained_path)

        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
        plugin = GeminiPlugin(
            **placement_config,
            tp_size=tp_size,
            enable_all_optimization=enable_all_optimization,
            extra_dp_size=extra_dp_size,
        )
        booster = Booster(plugin=plugin)
        bert_model, _, _, _, _ = booster.boost(bert_model)
        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2

        booster.save_model(
            bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors
        )
        dist.barrier()

        new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
        check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)


@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_gpt"])
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
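    # Run one optimizer step on a boosted GPT model, checkpoint both model and optimizer,
    # load them into a freshly boosted model/optimizer pair, compare state dicts, and make
    # sure the restored pair can still run a forward/backward/step cycle and be re-saved.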
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = lambda x: x.mean()
    enable_all_optimization = tp_size > 1
    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
    plugin = GeminiPlugin(
        **placement_config,
        precision="fp16",
        initial_scale=(2**14),
        tp_size=tp_size,
        extra_dp_size=extra_dp_size,
        enable_all_optimization=enable_all_optimization,
    )
    booster = Booster(plugin=plugin)

    model = model_fn()
    new_model = model_fn()
    optimizer = HybridAdam(model.parameters(), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    data = data_gen_fn()
    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
    output = model(**data)
    output = output_transform_fn(output)
    output_key = list(output.keys())[0]
    loss = criterion(output[output_key])

    booster.backward(loss, optimizer)
    optimizer.step()

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optimizer_ckpt_path = f"{tempdir}/optimizer"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)

        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        dist.barrier()

        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(
            model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
        )

        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(
            optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
        )

        # Check the new model/optimizer can successfully run.
        data = data_gen_fn()
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }
        output = new_model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])
        booster.backward(loss, new_optimizer)
        new_optimizer.step()
        booster.save_model(new_model, model_ckpt_path, shard=shard)
        booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)


def exam_lazy_from_pretrained():
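    # Load LLaMA lazily via LazyInitContext, boost it with Gemini, save an unsharded
    # checkpoint, and compare it against the eagerly loaded fp16 weights.
    # Requires the LLAMA_PATH environment variable to point at a local LLaMA checkpoint.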
    llama_path = os.environ["LLAMA_PATH"]
    plugin = GeminiPlugin()
    booster = Booster(plugin=plugin)
    orig_model = LlamaForCausalLM.from_pretrained(llama_path)
    orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
    with LazyInitContext():
        model = LlamaForCausalLM.from_pretrained(llama_path)
    model, *_ = booster.boost(model)
    with shared_tempdir() as tempdir:
        save_path = os.path.join(tempdir, "model.pt")
        booster.save_model(model, save_path, shard=False)
        dist.barrier()
        state_dict = torch.load(save_path, map_location="cpu")
        check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)


def run_dist(rank, world_size, port):
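    # Entry point for each spawned worker: initialize the process group, then run
    # every checkpoint-IO exam within it.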
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_state_dict()
    exam_state_dict_with_origin()
    exam_lazy_from_pretrained()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
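    # world_size=4 covers tp_size in {1, 2} with zero_size=2 (extra_dp_size of 2 or 1).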
    spawn(run_dist, world_size)

@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO_3d(world_size):
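    # world_size=8 adds an extra data-parallel dimension (extra_dp_size=2 when tp_size=2
    # and zero_size=2), exercising the 3D TP + ZeRO + DP layout.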
    spawn(run_dist, world_size)