import tempfile
import pytest

import torch
from torch.optim import Adam
from torchvision.models import resnet18

from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
from colossalai.testing import clear_cache_before_run, parameterize

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs

# ========
# Note:
# 1. because checkpoint IO can be quite slow if tested with all models, we only test on resnet for now
# 2. we test on both sharded and unsharded checkpoints
# 3. TODO: implement sharded optimizer checkpoint and test it
# ========

@clear_cache_before_run()
@parameterize('use_safetensors', [True, False])
def test_unsharded_checkpoint(use_safetensors: bool):
    """Round-trip an unsharded resnet18 checkpoint through GeneralCheckpointIO.

    Saves model + optimizer state to temp files, loads them into freshly
    constructed instances, and asserts the state dicts match exactly.
    """
    # create a model and optimizer
    model = resnet18()
    optimizer = Adam(model.parameters(), lr=0.001)

    # create test data sample
    x = torch.randn(1, 3, 224, 224)

    # run fwd and bwd so the optimizer actually has state to checkpoint
    y = model(x)
    loss = y.sum()
    loss.backward()
    optimizer.step()

    # create temp files for the checkpoint; the handles were previously
    # leaked, so close them explicitly once the test is done (the files
    # are deleted on close)
    suffix = ".safetensors" if use_safetensors else ".bin"
    model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()

    try:
        # save the model and optimizer
        ckpt_io = GeneralCheckpointIO()
        ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors)
        ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)

        # create new model/optimizer to load into
        new_model = resnet18()
        new_optimizer = Adam(new_model.parameters(), lr=0.001)

        # load the model and optimizer
        ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
        ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)

        # check for model and optimizer state dict recursively
        recursive_check(model.state_dict(), new_model.state_dict())
        recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
    finally:
        model_ckpt_tempfile.close()
        optimizer_ckpt_tempfile.close()

@pytest.mark.parametrize('use_safetensors', [True, False])
def test_sharded_checkpoint(use_safetensors: bool):
    """Round-trip a sharded model checkpoint through GeneralCheckpointIO.

    The model is saved as multiple shards into a temp directory; the
    optimizer is saved unsharded to a temp file. Both are loaded back into
    fresh instances and compared against the originals.
    """
    # create a model and optimizer
    model = resnet18()
    optimizer = Adam(model.parameters(), lr=0.001)

    # create test data sample
    x = torch.randn(1, 3, 224, 224)

    # run fwd and bwd so the optimizer actually has state to checkpoint
    y = model(x)
    loss = y.sum()
    loss.backward()
    optimizer.step()

    # create temp locations for the checkpoint; the unused
    # suffix / *_WEIGHTS_INDEX_NAME locals were dead code and were removed
    model_ckpt_dir = tempfile.TemporaryDirectory()
    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()

    try:
        # save the model (sharded) and the optimizer (unsharded)
        ckpt_io = GeneralCheckpointIO()
        ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors)
        ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False)

        # create new model/optimizer to load into
        new_model = resnet18()
        new_optimizer = Adam(new_model.parameters(), lr=0.001)

        ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
        ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)

        # check for model and optimizer state dict recursively
        recursive_check(model.state_dict(), new_model.state_dict())
        recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
    finally:
        # previously neither resource was cleaned up
        optimizer_ckpt_tempfile.close()
        model_ckpt_dir.cleanup()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['bert'])
@parameterize('use_safetensors', [True, False])
def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool):
    """Save a Gemini (ZeroDDP) BERT model as a sharded checkpoint and reload it via HuggingFace.

    Builds a bert model under ColoInitContext, wraps it in ZeroDDP with a
    Gemini chunk manager, saves a sharded checkpoint from the master rank,
    then loads that directory with `BertForSequenceClassification.from_pretrained`
    and compares state dicts.
    """
    # import locally so non-distributed collection does not require transformers
    from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification

    model_ckpt_dir = tempfile.TemporaryDirectory()
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, *_ = get_components_func()

    with ColoInitContext(device=get_current_device()):
        bert_model = model_builder()
    # HF from_pretrained needs config.json alongside the weight shards
    bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name)
    config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    bert_model = ZeroDDP(bert_model, gemini_manager)
    bert_model.train()

    # only the master rank writes and verifies the checkpoint
    ckpt_io = GeminiCheckpointIO()
    if ckpt_io.coordinator.is_master():
        # model size in MB; max shard size is a third of it, forcing multiple shards
        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
        ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors)
        new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
        # NOTE(review): dtype=torch.float32 presumably up-casts gathered fp16 params
        # so they compare equal to the HF-loaded fp32 weights — confirm against
        # ZeroDDP.state_dict's signature
        recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict())

    model_ckpt_dir.cleanup()
        


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
@parameterize('use_safetensors', [True, False])
def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
    """Round-trip a Gemini (ZeroDDP) sharded checkpoint between two ZeroDDP models.

    Builds two identical models, wraps each in its own ZeroDDP/Gemini stack,
    saves a sharded checkpoint of the first and loads it into the second,
    then compares their state dicts on the master rank.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, *_ = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
        new_model = model_builder()

    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    model.train()

    # the target model needs its own chunk/gemini managers — they cannot be shared
    new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
    new_chunk_manager = ChunkManager(new_config_dict)
    new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
    new_model = ZeroDDP(new_model, new_gemini_manager)

    model_ckpt_dir = tempfile.TemporaryDirectory()

    ckpt_io = GeminiCheckpointIO()
    # model size in MB; max shard size is a third of it, forcing multiple shards
    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors)

    # load model and verify on the master rank only
    if ckpt_io.coordinator.is_master():
        ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True)
        model_dict = model.state_dict(only_rank_0=True)
        new_model_dict = new_model.state_dict(only_rank_0=True)
        recursive_check(model_dict, new_model_dict)

    model_ckpt_dir.cleanup()


def run_dist(rank, world_size, port):
    """Per-process entry point: initialize the distributed runtime, then run the Gemini checkpoint IO tests."""
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()
    hf_load_colossalai_checkpoint()


@pytest.mark.dist
# fix: the parameter list was [4, 4], which ran the identical case twice
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
    """Spawn `world_size` processes and run the distributed Gemini checkpoint IO tests."""
    spawn(run_dist, world_size)

def recursive_check(d1, d2):
    """Recursively assert that every entry of ``d1`` equals the matching entry of ``d2``.

    - dict values are compared recursively
    - list values are compared element-by-element
    - torch.Tensor values are compared with ``torch.equal`` on CPU copies
    - anything else is compared with ``==``

    Keys present only in ``d2`` are intentionally ignored, matching the
    original behavior.

    Fix: the original implementation wrote the CPU copies back into ``d1``
    and ``d2`` (``v[i] = v[i].to("cpu")`` etc.), silently moving the
    callers' state dicts to CPU as a side effect; comparison now uses
    temporary CPU copies and leaves both inputs untouched.

    Raises:
        AssertionError: on the first mismatching value.
        KeyError: if a key of ``d1`` is missing from ``d2``.
    """
    for k, v in d1.items():
        if isinstance(v, dict):
            recursive_check(v, d2[k])
        elif isinstance(v, list):
            for i in range(len(v)):
                if isinstance(v[i], torch.Tensor):
                    # compare on CPU so tensors on different devices still match
                    assert torch.equal(v[i].to("cpu"), d2[k][i].to("cpu"))
                else:
                    assert v[i] == d2[k][i]
        elif isinstance(v, torch.Tensor):
            assert torch.equal(v.to("cpu"), d2[k].to("cpu"))
        else:
            assert v == d2[k]