"docs/README-zh-Hans.md" did not exist on "fb3546059539774b54442226ea864bf5d84e9532"
test_gemini_checkpoint_io.py 6.91 KB
Newer Older
import os

import pytest
import torch
import torch.distributed as dist
from transformers import LlamaForCausalLM
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
    check_state_dict_equal,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    skip_if_not_enough_gpus,
    spawn,
)
from tests.kit.model_zoo import model_zoo

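# Static placement configs exercised below: `shard_param_frac` selects how much of the
# parameters Gemini shards across ranks (0.0 behaves like zero2, 1.0 like zero3), and
# `offload_optim_frac` selects how much optimizer state is offloaded to CPU.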
MODEL_PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
]

OPTIM_PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0},  # zero2-offload
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5},  # zero2-offload-half
]


@clear_cache_before_run()
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
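    """Save a Gemini-boosted BERT as a sharded checkpoint and verify that a plain
    HuggingFace `BertForSequenceClassification.from_pretrained` reloads an identical
    state dict."""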
    from transformers import BertForSequenceClassification

    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    bert_model = model_fn()
    enable_all_optimization = True if tp_size > 1 else False

    with shared_tempdir() as tempdir:
        pretrained_path = os.path.join(tempdir, "pretrained")
        bert_model.config.save_pretrained(save_directory=pretrained_path)

        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
        plugin = GeminiPlugin(
            **placement_config,
            tp_size=tp_size,
            enable_all_optimization=enable_all_optimization,
            extra_dp_size=extra_dp_size,
        )
        booster = Booster(plugin=plugin)
        bert_model, _, _, _, _ = booster.boost(bert_model)
        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2

        # Save a sharded checkpoint whose shards are roughly a third of the model size,
        # so the directory can be reloaded with HuggingFace `from_pretrained` below.
        booster.save_model(
            bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors
        )
        dist.barrier()

        new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
        check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)


@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_llama_for_casual_lm"])
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
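    """Round-trip model and optimizer states through booster.save_*/load_* and verify that
    the reloaded copies match the originals and can still run a training step."""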
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = lambda x: x.mean()
    enable_all_optimization = True if tp_size > 1 else False
    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
    plugin = GeminiPlugin(
        **placement_config,
        precision="fp16",
        initial_scale=(2**14),
        tp_size=tp_size,
        extra_dp_size=extra_dp_size,
        enable_all_optimization=enable_all_optimization,
    )
    booster = Booster(plugin=plugin)

    model = model_fn()
    new_model = model_fn()
    optimizer = HybridAdam(model.parameters(), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    # Run one forward/backward/step so the optimizer has non-trivial state to checkpoint.
    data = data_gen_fn()
    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
    output = model(**data)
    output = output_transform_fn(output)
    output_key = list(output.keys())[0]
    loss = criterion(output[output_key])

    booster.backward(loss, optimizer)
    optimizer.step()

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optimizer_ckpt_path = f"{tempdir}/optimizer"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)

        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        dist.barrier()

        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(
            model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
        )

        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(
            optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
        )

        # Check the new model/optimizer can successfully run.
        data = data_gen_fn()
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }
        output = new_model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])
        booster.backward(loss, new_optimizer)
        new_optimizer.step()
        booster.save_model(new_model, model_ckpt_path, shard=shard)
        booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)


def exam_lazy_from_pretrained():
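    """Check that a lazily initialized LLaMA (materialized by the Gemini plugin) saves the
    same state dict as an eagerly loaded fp16 reference. Requires the LLAMA_PATH environment
    variable to point at a local LLaMA checkpoint."""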
    llama_path = os.environ["LLAMA_PATH"]
    plugin = GeminiPlugin()
    booster = Booster(plugin=plugin)
    orig_model = LlamaForCausalLM.from_pretrained(llama_path)
    orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
    with LazyInitContext():
        model = LlamaForCausalLM.from_pretrained(llama_path)
    model, *_ = booster.boost(model)
    with shared_tempdir() as tempdir:
        save_path = os.path.join(tempdir, "model.pt")
        booster.save_model(model, save_path, shard=False)
        dist.barrier()
        state_dict = torch.load(save_path, map_location="cpu")
        check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)


def run_dist(rank, world_size, port):
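    """Per-process entry point: set up the ColossalAI process group, then run all checks."""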
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_state_dict()
    exam_state_dict_with_origin()
    exam_lazy_from_pretrained()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_gemini_ckpIO():
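    # 4 GPUs covers tp_size in {1, 2} with zero_size=2; leftover ranks form the extra dp group.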
    spawn(run_dist, 4)

@pytest.mark.largedist
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_gemini_ckpIO_3d():
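    # 8 GPUs adds an extra data-parallel dimension on top of tp and zero (the "3d" layout).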
    spawn(run_dist, 8)