# bench_utils.py
1
import time
from copy import deepcopy
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace


14
15
16
def bench(
    gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5
) -> Tuple[float, float]:
    """Benchmark the peak memory and step time of a given graph module.

    The module is moved to CUDA, run for up to ``num_steps`` forward/backward
    passes, and moved back to CPU afterwards. The fastest step is reported.
    A failing step (for example a CUDA out-of-memory error) ends the benchmark
    early without raising; the statistics gathered so far are returned.

    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function, called as ``criterion(output, label)``.
        data_gen (Callable): Data generator returning ``(args, label)``; ``args`` is
            unpacked into the module call.
        num_steps (int, optional): Number of test steps. Defaults to 5.

    Returns:
        Tuple[float, float]: peak memory in MB (relative to the allocation level
        measured right before the first step) and best step time in MS. The step
        time is ``inf`` when no step completed.
    """
    gm.train()
    gm.cuda()
    step_time = float("inf")
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Baseline: memory already allocated (e.g. parameters) before any step runs.
    cached = torch.cuda.max_memory_allocated(device="cuda")
    args = label = output = loss = None
    try:
        for _ in range(num_steps):
            args, label = data_gen()

            # Synchronize around the timed region so queued kernels are included.
            torch.cuda.synchronize(device="cuda")
            start = time.perf_counter()
            output = gm(*args)
            loss = criterion(output, label)
            loss.backward()
            torch.cuda.synchronize(device="cuda")
            step_time = min(step_time, time.perf_counter() - start)

            # Drop gradients of every parameter, including ones owned directly by gm.
            for param in gm.parameters():
                param.grad = None
            args = label = output = loss = None
    except Exception:
        # Best effort: a failed step is reported through the returned statistics
        # (step_time stays inf if nothing completed) instead of an exception.
        pass
    finally:
        # Rebinding (rather than `del`) is safe even when a name was never assigned
        # in the failing step, and releases the tensors before the cache is emptied.
        args = label = output = loss = None
        for param in gm.parameters():
            param.grad = None
        gm.to("cpu")
        torch.cuda.empty_cache()
    peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
    return peak_mem, step_time * 1.0e3
56
57


58
59
60
61
62
63
64
65
66
def bench_rotor(
    gm: torch.fx.GraphModule,
    criterion: torch.nn.Module,
    data_gen: Callable,
    num_steps: int = 5,
    sample_points: int = 20,
    free_memory: Optional[int] = None,
    start_factor: int = 4,
) -> Tuple[np.ndarray, list, list]:
    """Auto Checkpoint Rotor Algorithm benchmarking

    Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data.
    For every sampled memory budget the module is traced with ``metainfo_trace``,
    a ``CheckpointSolverRotor`` is asked for a graph under that budget, the result
    is benchmarked with ``bench`` and the original graph is restored afterwards.

    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function.
        data_gen (Callable): Data generator.
        num_steps (int, optional): Number of test steps. Defaults to 5.
        sample_points (int, optional): Number of sample points. Defaults to 20.
        free_memory (int, optional): Max memory budget in Byte. Defaults to None, in which
            case ``torch.cuda.mem_get_info()[0]`` is queried when the function is called.
        start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget
            will be free_memory / start_factor. Defaults to 4.

    Returns:
        Tuple[np.ndarray, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS).
        A budget whose solve or benchmark fails is recorded with the budget itself as
        peak memory and ``inf`` as step time.
    """
    if free_memory is None:
        # Resolve at call time: a default evaluated at import would initialize CUDA
        # on import and freeze a stale value.
        free_memory = torch.cuda.mem_get_info()[0]

    budgets = np.linspace(free_memory // start_factor, free_memory, sample_points)
    peak_hist, step_hist = [], []
    raw_graph = deepcopy(gm.graph)
    for budget in budgets:
        gm = metainfo_trace(gm, *data_gen()[0])
        solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
        try:
            gm.graph = solver.solve(verbose=False)
            peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
        except Exception:
            # Keep the sweep going; mark this budget as failed.
            peak_memory, step_time = budget / 1024**2, float("inf")
        peak_hist.append(peak_memory)
        step_hist.append(step_time)
        # Start the next budget from the untouched graph.
        gm.graph = deepcopy(raw_graph)
    return budgets / 1024**2, peak_hist, step_hist


class GPTLMModel(nn.Module):
    """
    GPT-2 language model wrapper whose forward pass yields only the LM logits.
    """

    def __init__(
        self,
        hidden_size=768,
        num_layers=12,
        num_attention_heads=12,
        max_seq_len=1024,
        vocab_size=50257,
        checkpoint=False,
    ):
        super().__init__()
        self.checkpoint = checkpoint

        # Build the configuration first, then hand it to the HF model.
        config = GPT2Config(
            n_embd=hidden_size,
            n_layer=num_layers,
            n_head=num_attention_heads,
            n_positions=max_seq_len,
            n_ctx=max_seq_len,
            vocab_size=vocab_size,
        )
        self.model = GPT2LMHeadModel(config)

        if self.checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # The KV cache is switched off whenever gradient checkpointing is on.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=not self.checkpoint,
        )
        # First element of the model output is the lm_logits tensor.
        return outputs[0]


class GPTLMLoss(nn.Module):
    """
    Next-token cross-entropy loss for GPT-style language models.
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # Position t predicts token t + 1: drop the last logit and the first label.
        vocab_dim = logits.size(-1)
        predictions = logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous()
        # Collapse batch and sequence dimensions before computing the loss.
        flat_predictions = predictions.view(-1, vocab_dim)
        flat_targets = targets.view(-1)
        return self.loss_fn(flat_predictions, flat_targets)


def gpt2_medium(checkpoint=False):
    """Build a GPTLMModel with hidden size 1024, 24 layers and 16 attention heads."""
    arch = dict(hidden_size=1024, num_layers=24, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **arch)


def gpt2_xl(checkpoint=False):
    """Build a GPTLMModel with hidden size 1600, 48 layers and 32 attention heads."""
    arch = dict(hidden_size=1600, num_layers=48, num_attention_heads=32)
    return GPTLMModel(checkpoint=checkpoint, **arch)


def gpt2_6b(checkpoint=False):
    """Build a GPTLMModel with hidden size 4096, 30 layers and 16 attention heads."""
    arch = dict(hidden_size=4096, num_layers=30, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **arch)
157
158


159
def data_gen_gpt2(batch_size, seq_len, vocab_size, device="cuda:0"):
    """
    Generate random data for gpt2 benchmarking

    Args:
        batch_size (int): Number of sequences in the batch.
        seq_len (int): Number of tokens per sequence.
        vocab_size (int): Token ids are drawn uniformly from ``[0, vocab_size)``.
        device (str, optional): Device the tensors are created on. Defaults to "cuda:0".

    Returns:
        tuple: ``((input_ids, attention_mask), labels)`` where ``attention_mask`` is all
        ones and ``labels`` are the token ids themselves (language-modeling targets).
    """
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids)
    # Labels are the token ids: always valid class indices for any vocab_size,
    # unlike the constant-one attention mask.
    return (input_ids, attention_mask), input_ids


168
def data_gen_resnet(batch_size, shape, device="cuda:0", num_classes=1000):
    """
    Generate random data for resnet benchmarking

    Args:
        batch_size (int): Number of samples in the batch.
        shape (tuple): Shape of a single sample; the batch dimension is prepended.
        device (str, optional): Device the tensors are created on. Defaults to "cuda:0".
        num_classes (int, optional): Labels are drawn uniformly from
            ``[0, num_classes)``. Defaults to 1000.

    Returns:
        tuple: ``((data,), label)`` with ``data`` of shape ``(batch_size, *shape)`` and
        ``label`` an int64 tensor of shape ``(batch_size,)``.
    """
    # torch.rand gives well-defined values; uninitialized memory may hold NaN/Inf.
    data = torch.rand(batch_size, *shape, device=device)
    label = torch.randint(0, num_classes, (batch_size,), dtype=torch.long, device=device)
    return (data,), label