# bench_utils.py
import time
from functools import partial
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torchvision.models as tm

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace


def bench(gm: torch.fx.GraphModule,
          criterion: torch.nn.Module,
          data_gen: Callable,
          num_steps: int = 5) -> Tuple[float, float]:
    """Run a few training steps of ``gm`` on CUDA and report memory and speed.

    Each step calls ``data_gen()`` for a fresh ``(args, label)`` pair, runs the
    forward pass, the loss and the backward pass, and times the whole step
    between two CUDA synchronizations.

    Args:
        gm: The module to benchmark. It is put in training mode, moved to
            CUDA for the measurement and moved back to CPU before returning.
        criterion: Called as ``criterion(output, label)``; the result must
            support ``.backward()``.
        data_gen: Zero-argument callable returning ``(args, label)`` where
            ``args`` is unpacked into ``gm(*args)``.
        num_steps: Number of steps to run.

    Returns:
        A ``(peak_memory, step_time)`` tuple: the increase of the peak
        allocated CUDA memory over the pre-benchmark baseline, in MiB, and the
        fastest observed step, in milliseconds. ``step_time`` is ``inf`` when
        no step completed (``num_steps == 0`` or the first step failed).

    Note:
        Benchmarking is best-effort: an ``Exception`` raised inside a step
        (e.g. a CUDA out-of-memory error) stops the loop and the statistics
        gathered so far are returned instead of propagating the error.
    """
    gm.train()
    gm.cuda()
    step_time = float('inf')
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    # After the reset the "peak" equals what is currently allocated, so
    # `cached` is the baseline that is subtracted from the final peak.
    torch.cuda.reset_peak_memory_stats()
    cached = torch.cuda.max_memory_allocated(device="cuda")

    # Bind the per-step names up front so the cleanup below never touches an
    # unbound local, whichever statement of the loop raised.
    args = label = output = loss = None
    try:
        for _ in range(num_steps):
            args, label = data_gen()

            torch.cuda.synchronize(device="cuda")
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            output = gm(*args)
            loss = criterion(output, label)
            loss.backward()
            torch.cuda.synchronize(device="cuda")
            step_time = min(step_time, time.perf_counter() - start)

            # Drop gradients of *all* parameters, including those registered
            # directly on `gm` rather than on a child module.
            for param in gm.parameters():
                param.grad = None
            # Release this step's tensors before the next batch is created.
            args = label = output = loss = None
    except Exception:
        # Deliberate best-effort: keep whatever was measured before the
        # failure. KeyboardInterrupt / SystemExit are not intercepted.
        pass
    finally:
        # Drop tensor references and free the GPU even on failure.
        args = label = output = loss = None
        gm.to("cpu")
        torch.cuda.empty_cache()
    return (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2, step_time * 1.0e3


def bench_rotor(gm: torch.fx.GraphModule,
                criterion: torch.nn.Module,
                data_gen: Callable,
                num_steps: int = 5,
                sample_points: int = 20,
                free_memory: Optional[int] = None) -> Tuple[list, list]:
    """Benchmark ``gm`` under a sweep of rotor checkpoint memory budgets.

    For ``sample_points`` budgets evenly spaced between ``free_memory // 5``
    and ``free_memory``, the graph is re-traced with ``metainfo_trace``,
    solved with ``CheckpointSolverRotor`` for that budget, and benchmarked
    with ``bench``.

    Args:
        gm: The graph module to annotate and benchmark.
        criterion: Loss module forwarded to ``bench``.
        data_gen: Callable returning ``(args, label)``. It is called without
            arguments for tracing and with ``batch_size=2048,
            shape=(3, 224, 224)`` for benchmarking, so it must accept those
            keyword arguments.
        num_steps: Number of steps per benchmark run.
        sample_points: Number of memory budgets to sample.
        free_memory: Upper bound of the budget sweep, in bytes. When ``None``
            (the default) the currently free CUDA memory reported by
            ``torch.cuda.mem_get_info()`` is queried at call time.

    Returns:
        ``(peak_hist, step_hist)``: two lists with one entry per budget,
        holding the values returned by ``bench``. When solving or
        benchmarking raises, the entry is the budget itself converted to MiB
        and a step time of ``inf``.
    """
    # Resolve the default lazily: evaluating it in the signature would query
    # CUDA at import time (failing without a GPU) and freeze a stale value.
    if free_memory is None:
        free_memory = torch.cuda.mem_get_info()[0]

    peak_hist, step_hist = [], []
    for budget in np.linspace(free_memory // 5, free_memory, sample_points):
        # NOTE(review): tracing uses data_gen()'s default arguments while the
        # benchmark below uses batch_size=2048 -- confirm this is intended.
        gm = metainfo_trace(gm, *data_gen()[0])
        solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
        try:
            gm.graph = solver.solve()
            peak_memory, step_time = bench(gm,
                                           criterion,
                                           partial(data_gen, batch_size=2048, shape=(3, 224, 224)),
                                           num_steps=num_steps)
        except Exception:
            # The solver (or the run) failed for this budget: record the
            # budget in MiB with an infinite step time as a placeholder.
            peak_memory, step_time = budget / 1024**2, float('inf')
        peak_hist.append(peak_memory)
        step_hist.append(step_time)
    return peak_hist, step_hist