# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
import gc
import logging
import math
import time

import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.nn.parallel import DistributedDataParallel as DDP
import utils

from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
from fair_dev.testing.testing import dist_init
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel

MPI_PORT = 29500
RPC_PORT = 29501


def get_tensors_by_size_bucket():

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    return size_buckets
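# Note: get_tensors_by_size_bucket() is a debugging aid rather than part of the
# benchmark flow; it is not exercised in the code shown here. A hypothetical use
# would be to snapshot live CUDA tensors between iterations when chasing a leak:
#
#     for bucket, count in get_tensors_by_size_bucket().items():
#         logging.debug("live CUDA tensors (shape..., element_size)=%s: %d", bucket, count)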


def get_device(model, index):
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if hasattr(model, "devices"):
        return model.devices[index]
    else:
        return torch.cuda.current_device()
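# Note: for a Pipe-partitioned model, `model.devices[index]` maps a partition index
# to its device, so get_device(model, 0) is the first stage's device and
# get_device(model, -1) the last; other models fall back to the current CUDA
# device, or to CPU when CUDA is unavailable.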


def get_fake_dataloader(lm_dataloader_len, args):
    fake_input = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        def __getitem__(self, index):
            return fake_input

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()
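# Note (assumption): get_fake_dataloader is not exercised in the code shown here; it
# appears intended as a stand-in dataset of zero tensors for ranks that only need to
# iterate the same number of steps as the real dataloader, e.g.
# get_fake_dataloader(len(real_loader), args).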


def train(model_config, model, benchmark_config, model_specs, args):
    lm_dataloader, _, _ = utils.get_data_loader(model_config["dataset_info"], args, benchmark_config, model_specs)
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    utils.log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0

    optimizer = optimizer(model.parameters())

    pipe_group = model.group if hasattr(model, "group") else None

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config, model_specs)

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target
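    # Illustration: with a batch laid out along the first dimension as
    # [t0, t1, t2, t3], get_batch returns data = [t0, t1, t2] and
    # target = [t1, t2, t3], i.e. a one-step shift for next-token prediction.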

    for i, batch in enumerate(lm_dataloader):
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += source.numel()

        optimizer.zero_grad()
        try:
            if pipe_group is None or pipe_group.rank() == 0:
                tmp = source.to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            target = target.to(get_device(model, -1))
            output = output.to(target.device)
            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            loss.backward()
            del target
        else:
            model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += source.numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                if dist.get_rank() == dist.get_world_size() - 1:
                    logging.debug(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size of the dataset and rerun the benchmark."
        )
    if dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0


# TODO(anj-s): Add an option for users to be able to benchmark evaluate.
def evaluate(eval_model, data_source, criterion, ntokens):
    eval_model.eval()
    total_loss = 0.0
    # TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
    bptt = 35

    def get_batch(source, i, bptt):
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i : i + seq_len]
        target = source[i + 1 : i + 1 + seq_len].view(-1)
        return data, target

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
            output = eval_model(data)
            output = output.to(targets.device)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


def get_number_of_words(data):
    return data.size()[0] * data.size()[1]
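# Note: assumes batches shaped [sequence_length, batch_size], so the word count is
# the product of the first two dimensions.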


def verify_peak_memory(rank, golden_config, std_dev):
    logging.debug(
        "Peak allocated bytes on cuda:{:d}: {:1d}".format(
            rank, torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
        )
    )
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d} which "
            "exceeds the golden reference value of {:d}.".format(rank, current_device_usage, golden_ref)
        )
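# Note: despite its name, `std_dev` acts as a multiplicative headroom factor here;
# verify_lm_run below passes 1.1, i.e. peak memory may exceed the golden reference
# by at most 10% before the check fails.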


def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""

    if dist.get_rank() == dist.get_world_size() - 1:
        # Assert that words per second is within 3 standard deviations of the average
        # of five golden runs.
        logging.info("Throughput(wps) is {:.2f}.".format(wps))
        if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )

    for i in range(4):
        verify_peak_memory(i, golden_config, 1.1)


def benchmark_language_model(model_config, model, benchmark_config, model_specs, config_class, args):
    golden_config = get_golden_config(args.model_name, config_class, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| start of epoch {:1d}".format(epoch))
        logging.debug("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        logging.debug("-" * 110)
        logging.debug("Throughput(wps) is {:.2f}.".format(wps))
    logging.debug(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )

    if len(model.balance) == 4:
        if args.model_name == "lm":
            verify_lm_run(wps, golden_config, args)
        else:
            raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def generate_balance(num_devices, num_layers):
    balance = []
    layers_assigned = 0
    for i in range(num_devices):
        x = (num_layers - layers_assigned) / (num_devices - i)
        if x.is_integer():
            balance.append(int(x))
            layers_assigned += x
        else:
            balance.append(math.ceil(x))
            layers_assigned += math.ceil(x)
    return balance
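# Example: generate_balance(4, 14) returns [4, 4, 3, 3], spreading the layers as
# evenly as possible across the pipeline devices while summing to num_layers.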


def get_golden_config(model_name, config_class, args):
    """Return a dict with the golden data for throughput and memory usage."""

    if model_name == "lm":
        return config_class.get_golden_real_stats()
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def benchmark_single_process(config_class, args):
    """Benchmark a given model using a single process and multiple devices."""

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup)

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    utils.init_random_seed(0)

    benchmark_config = utils.create_benchmark_config(args.model_name, config_class)
    model_specs = utils.get_model_specs(args.model_name, config_class)
    model_config = utils.create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    balance = generate_balance(min(num_devices, 4), len(model))
    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
    del model
    del model_config["model"]

    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, config_class, args)


def run_worker(rank, world_size, args):
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    utils.init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
310
    args = utils.init_args()
311
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
312

313
    logging.info(f"Running single process benchmark with args: {args}")
314
    benchmark_single_process(lm_wikitext2, args)