test_perf.py 5.36 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import time
import pytest
from functools import partial

import torch
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import free_port, get_current_device
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.testing import parameterize

from tests.test_tensor.common_utils import set_seed
from tests.test_auto_parallel.test_offload.model_utils import *


@parameterize('model_name', ['gpt2_'])
@parameterize('memory_budget', [5000])
@parameterize('solver_name', ['asyn'])
def exam_fwd_bwd(
        model_name: str,
        memory_budget: float,
        solver_name: str
):
    """Benchmark fwd/bwd wall time and peak CUDA memory for two setups.

    Runs the same model-builder twice and prints timings for each:
      1. Gemini ZeRO stage-3 (``zero_model_wrapper`` + ``zero_optim_wrapper``).
      2. The auto-parallel offload transform (``memory_optimize`` +
         ``AMPOptimizer``).

    Args:
        model_name: key looked up in ``non_distributed_component_funcs``.
        memory_budget: offload solver budget in MB (converted to bytes below).
        solver_name: which offload solver to use (e.g. 'asyn').
    """

    # build model
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, data_gen = get_components_func()
    # Random token labels for the LM loss; shape (64, 8) must match the
    # model output's batch/sequence dims -- presumably fixed by data_gen.
    label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
    criterion = LMLoss()

    set_seed(42)
    start_time = time.time()
    model = model_builder()
    model.train()
    # Bytes -> MB, then halved -- presumably reporting fp16 size of an
    # fp32-initialized model (TODO confirm against parameter_size semantics).
    param_size = parameter_size(model) / 1024 ** 2 / 2
    init_time = time.time() - start_time
    print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")

    # Trace inputs on CPU, cast floating tensors to half so the traced graph
    # matches the precision used during the timed runs below.
    data_args = data_gen(device="cpu")
    wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
    data_args = tree_map(wrap_fn, data_args)
    start_time = time.time()
    # memory_budget is given in MB; memory_optimize expects bytes.
    model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name)
    solver_time = time.time() - start_time
    print(f"solver_time={solver_time:.3f} s")

    # Optimizer for the offload-transformed model (note: the wrapped module
    # lives at model.model after memory_optimize).
    hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
    optim = AMPOptimizer(hybrid_optimizer, model)

    # Build an independent copy of the model under ColoInitContext for Gemini.
    with ColoInitContext(device=torch.device('cpu')):
        gemini_model = model_builder()
    gemini_model.train()

    hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
    gemini_config = dict(strict_ddp_mode=False,
                         device=torch.device('cpu'),
                         placement_policy='cpu',
                         pin_memory=True,
                         hidden_dim=8192,
                         search_range_mb=128)
    # Stage 3 = full parameter/gradient/optimizer-state sharding.
    gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
    optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
    gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)

    # Reset CUDA allocator state so peak-memory stats reflect only the
    # Gemini phase below.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # test gemini
    time_list = []
    set_seed(42)
    data_args = data_gen(device="cuda")
    for step in range(10):
        gemini_optim.zero_grad()
        # Synchronize before/after the timed region so time.time() measures
        # actual GPU work, not just async kernel launches.
        torch.cuda.synchronize()
        start_time = time.time()
        gemini_out = gemini_model(**data_args)
        gemini_loss = criterion(gemini_out, label)
        gemini_optim.backward(gemini_loss)
        torch.cuda.synchronize()
        time_list.append(time.time() - start_time)
        # step() is deliberately outside the timed window: only fwd+bwd is measured.
        gemini_optim.step()

    torch.cuda.synchronize()

    # Average of the 5 fastest iterations, discarding warm-up outliers.
    exec_time = sum(sorted(time_list)[:5]) / 5
    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
    print(f'gemini | model_name: {model_name}')
    print(
        f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
        f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
    )
    print(time_list)

    # Drop every reference to the Gemini objects so their CUDA memory can be
    # released before the second phase is measured.
    del data_args
    del gemini_model
    del gemini_optim
    del gemini_out
    del gemini_loss

    # test asyn offload
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    time_list = []
    set_seed(42)
    data_args = data_gen(device="cuda")
    # Cast floats to half again -- the offload model was traced with half inputs.
    data_args = tree_map(wrap_fn, data_args)
    for step in range(10):
        optim.zero_grad()
        torch.cuda.synchronize()
        start_time = time.time()
        loss = criterion(model(**data_args), label)
        optim.backward(loss)
        torch.cuda.synchronize()
        time_list.append(time.time() - start_time)
        optim.step()

    torch.cuda.synchronize()

    # Same reporting as the Gemini phase, for side-by-side comparison.
    exec_time = sum(sorted(time_list)[:5]) / 5
    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
    print(f'solver_name: {solver_name} | model_name: {model_name}')
    print(
        f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
        f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
    )
    print(time_list)

@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def test_perf(rank=0, world_size=1, port=None):
    """Launch a (single-process) colossalai context and run the perf benchmark.

    Args:
        rank: process rank; supplied positionally by ``mp.spawn``.
        world_size: number of processes in the group.
        port: rendezvous port; picked lazily via ``free_port()`` when omitted.

    All parameters carry defaults so pytest can invoke this function directly:
    without defaults, pytest treats ``rank``/``world_size``/``port`` as missing
    fixtures and errors at collection time. ``mp.spawn`` callers that pass the
    arguments explicitly (see ``__main__`` below) are unaffected.
    """
    if port is None:
        # Evaluated per call (a default of free_port() would bind once at import).
        port = free_port()
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_fwd_bwd()


if __name__ == '__main__':
    # Single-worker spawn: mp.spawn supplies the rank argument positionally.
    mp.spawn(partial(test_perf, world_size=1, port=free_port()), nprocs=1)