# test_perf.py 5.25 KB
# Newer Older
1
2
import time

3
import pytest
4
import torch
5
from torch.utils._pytree import tree_map
6
7
8
9
10

import colossalai
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
11
12
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
13
14
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
15
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
16
from tests.test_auto_parallel.test_offload.model_utils import *
17
from tests.test_tensor.common_utils import set_seed
18
19


20
21
22
def _to_half_if_float(x):
    """Return ``x`` cast to ``torch.half`` if it is a floating-point tensor, else ``x`` unchanged."""
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype=torch.half)
    return x


def _reset_cuda_memory_stats():
    """Release cached CUDA blocks, wait for pending work, and restart peak-memory tracking."""
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()


def _time_fwd_bwd(model, optimizer, criterion, data_args, label, num_steps=10):
    """Run ``num_steps`` training steps and return the wall-clock time of each forward+backward.

    Only the forward pass, loss computation and ``optimizer.backward`` are timed;
    ``zero_grad`` and ``step`` run outside the timed window. The device is
    synchronized on both sides of the window so pending kernels are not
    attributed to the wrong step.
    """
    time_list = []
    for _ in range(num_steps):
        optimizer.zero_grad()
        torch.cuda.synchronize()
        start_time = time.time()
        loss = criterion(model(**data_args), label)
        optimizer.backward(loss)
        torch.cuda.synchronize()
        time_list.append(time.time() - start_time)
        optimizer.step()

    torch.cuda.synchronize()
    return time_list


def _report_run(header, time_list, param_size, num_fastest=5):
    """Print timing and peak CUDA memory figures for one benchmarked configuration.

    ``exec_time`` is the mean of the ``num_fastest`` smallest entries of
    ``time_list``, which drops the slowest (e.g. warm-up) steps. Peak memory is
    read from the CUDA allocator counters and reported in MB.
    """
    fastest = sorted(time_list)[:num_fastest]
    exec_time = sum(fastest) / len(fastest)
    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
    print(header)
    print(
        f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB "
        f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|"
    )
    print(time_list)


@parameterize("model_name", ["gpt2_"])
@parameterize("memory_budget", [5000])
@parameterize("solver_name", ["asyn"])
def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
    """Benchmark a Gemini-wrapped model against the ``memory_optimize`` offload model.

    Both configurations are built from the same ``model_builder`` and run ten
    forward/backward steps; per-step time and peak CUDA memory are printed for
    each. Nothing is asserted -- the function only reports measurements.

    Args:
        model_name: Key looked up in ``non_distributed_component_funcs``.
        memory_budget: Budget handed to ``memory_optimize`` after multiplying by
            1024 * 1024 (presumably MB -> bytes; confirm against the solver).
        solver_name: Solver identifier forwarded to ``memory_optimize``.
    """
    # build model
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, data_gen = get_components_func()
    label = torch.randint(low=0, high=128, size=(64, 8), device=get_current_device())
    criterion = LMLoss()

    set_seed(42)
    start_time = time.time()
    model = model_builder()
    model.train()
    # NOTE(review): the extra "/ 2" presumably accounts for fp16 storage -- confirm.
    param_size = parameter_size(model) / 1024**2 / 2
    init_time = time.time() - start_time
    print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")

    # The solver is given CPU inputs with floating-point tensors cast to half.
    data_args = tree_map(_to_half_if_float, data_gen(device="cpu"))
    start_time = time.time()
    model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name)
    solver_time = time.time() - start_time
    print(f"solver_time={solver_time:.3f} s")

    offload_adam = HybridAdam(model.model.parameters(), lr=1e-3)
    optim = AMPOptimizer(offload_adam, model)

    # Build the Gemini baseline from a fresh copy of the same architecture.
    with ColoInitContext(device=torch.device("cpu")):
        gemini_model = model_builder()
    gemini_model.train()

    # The optimizer is created from the raw parameters before the model is wrapped.
    gemini_adam = HybridAdam(gemini_model.parameters(), lr=1e-3)
    gemini_config = dict(
        strict_ddp_mode=False,
        device=torch.device("cpu"),
        placement_policy="cpu",
        pin_memory=True,
        hidden_dim=8192,
        search_range_m=128,
    )
    gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
    optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
    gemini_optim = zero_optim_wrapper(gemini_model, gemini_adam, optim_config=optim_config)

    # test gemini
    _reset_cuda_memory_stats()
    set_seed(42)
    data_args = data_gen(device="cuda")
    time_list = _time_fwd_bwd(gemini_model, gemini_optim, criterion, data_args, label)
    _report_run(f"gemini | model_name: {model_name}", time_list, param_size)

    # Drop the baseline's references before measuring the offload run.
    del data_args
    del gemini_model
    del gemini_optim

    # test asyn offload
    _reset_cuda_memory_stats()
    set_seed(42)
    data_args = tree_map(_to_half_if_float, data_gen(device="cuda"))
    time_list = _time_fwd_bwd(model, optim, criterion, data_args, label)
    _report_run(f"solver_name: {solver_name} | model_name: {model_name}", time_list, param_size)

142
143

def run_dist(rank, world_size, port):
    """Per-process entry point: initialise ColossalAI, then run the benchmark.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to ``colossalai.launch``.
        port: Port on ``localhost`` used for the process-group rendezvous.
    """
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    # Called without arguments: the ``parameterize`` decorators on
    # ``exam_fwd_bwd`` are expected to supply them.
    exam_fwd_bwd()


149
@pytest.mark.skip(reason="this test failed")
@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed")
@rerun_if_address_is_in_use()
def test_perf():
    """Launch ``run_dist`` through ``spawn`` (second argument ``1``) to run the perf benchmark.

    Currently skipped unconditionally under pytest; also skipped when
    ``NOT_NVML`` is true.
    """
    spawn(run_dist, 1)
154
155


156
if __name__ == "__main__":
    # Allow running the benchmark directly as a script, outside the pytest runner.
    test_perf()