# test_gemini_checkpoint_io.py
import os

import pytest
import torch
import torch.distributed as dist
from transformers import LlamaForCausalLM
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
    check_state_dict_equal,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo

MODEL_PLACEMENT_CONFIGS = [
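    # Shard 0%, 100%, or half of the parameters (ZeRO-2 / ZeRO-3 / mixed).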
    {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
]

OPTIM_PLACEMENT_CONFIGS = [
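    # Parameters stay unsharded (ZeRO-2 style); vary the fraction of optimizer state offloaded to CPU.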
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0},  # zero2-offload
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5},  # zero2-offload-half
]


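# Save a boosted BERT model as a HuggingFace-style checkpoint and verify that a plain
# `from_pretrained` reload reproduces the original weights.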
@clear_cache_before_run()
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
    from transformers import BertForSequenceClassification

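    # Pull the model builder for this architecture out of the shared model zoo.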
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    bert_model = model_fn()

    with shared_tempdir() as tempdir:
        pretrained_path = os.path.join(tempdir, "pretrained")
        bert_model.config.save_pretrained(save_directory=pretrained_path)

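        # Wrap the model with the Gemini plugin so its parameters are managed
        # (and, depending on shard_param_frac, sharded) across ranks.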
        plugin = GeminiPlugin(**placement_config)
        booster = Booster(plugin=plugin)
        bert_model, _, _, _, _ = booster.boost(bert_model)
        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2

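        # Save sharded (shard=True, gather_dtensor=True, empty prefix); capping each shard at a
        # third of the model size forces the checkpoint to be split across multiple files.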
        booster.save_model(
            bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors
        )
        dist.barrier()

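        # Reloading with vanilla HuggingFace must yield the same state dict that Gemini saved;
        # only_rank_0=False gathers the full state dict on every rank for the comparison.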
        new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
        check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)


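# Round-trip both model and optimizer checkpoints (sharded or whole-file) and verify that
# the reloaded copies match and can continue training.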
@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_gpt"])
@parameterize("size_per_shard", [32])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = lambda x: x.mean()
    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14))
    booster = Booster(plugin=plugin)

    model = model_fn()
    new_model = model_fn()
    optimizer = HybridAdam(model.parameters(), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    data = data_gen_fn()
    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
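    # One forward/backward/step so the optimizer holds non-trivial state before checkpointing.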
    output = model(**data)
    output = output_transform_fn(output)
    output_key = list(output.keys())[0]
    loss = criterion(output[output_key])

    booster.backward(loss, optimizer)
    optimizer.step()

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optimizer_ckpt_path = f"{tempdir}/optimizer"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)

        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
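        # Wait until every rank has finished writing before any rank reads the files back.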
        dist.barrier()

        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(
            model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
        )

        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
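        # Optimizer state (step counts and moment buffers) must survive the round trip.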
        check_state_dict_equal(
            optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
        )

        # Check the new model/optimizer can successfully run.
        data = data_gen_fn()
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }
        output = new_model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])
        booster.backward(loss, new_optimizer)
        new_optimizer.step()
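        # Saving again right after resuming must also succeed (overwrites the same paths).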
        booster.save_model(new_model, model_ckpt_path, shard=shard)
        booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)


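# Lazily initialize a Llama model, boost it with Gemini, and check that the saved fp16
# weights match an eagerly loaded copy. Requires LLAMA_PATH to point at a local checkpoint.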
def exam_lazy_from_pretrained():
    llama_path = os.environ["LLAMA_PATH"]
    plugin = GeminiPlugin()
    booster = Booster(plugin=plugin)
    orig_model = LlamaForCausalLM.from_pretrained(llama_path)
    orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
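    # Parameters created under LazyInitContext stay unmaterialized until booster.boost
    # distributes them, so the full model never has to fit in memory at once.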
    with LazyInitContext():
        model = LlamaForCausalLM.from_pretrained(llama_path)
    model, *_ = booster.boost(model)
    with shared_tempdir() as tempdir:
        save_path = os.path.join(tempdir, "model.pt")
        booster.save_model(model, save_path, shard=False)
        dist.barrier()
        state_dict = torch.load(save_path, map_location="cpu")
        check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)


def run_dist(rank, world_size, port):
    config = {}
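    # Set up the per-process distributed environment (NCCL) for this spawned worker.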
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_state_dict()
    exam_state_dict_with_origin()
    exam_lazy_from_pretrained()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
    spawn(run_dist, world_size)