Unverified Commit cd186441 authored by anj-s, committed by GitHub

[refactor] Refactor and enable multiprocess nn.Pipe benchmarks. (#319)



* mp cleanup

* round of multiprocess refactoring

* test golden run

* print cuda stats

* fix lint errors

* enable multiprocess pipe benchmarks

* set world size to be available gpus

* more changes

* use synthetic loaders for intermediate pipeline stages

* merged master

* fix for the devices property

* dataloader fix

* modify rank check

* print wps stats

* enable verification

* fix logging

* fix flag name

* fix flag name

* check for rank

* fix indent

* pass args

* pass args

* modify golden data

* remove unused print message

* fix lint errors

* add comments

* fix benchmarks
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent a2408eb8
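For reference, the CI job added below runs the new multiprocess mode alongside the existing single-process benchmark; the equivalent local invocations (both assumed to be run from the repository root, with the multiprocess variant also assuming at least one CUDA device) are:

python benchmarks/pipe.py
python benchmarks/pipe.py --multiprocess

The new --dry_run flag runs a sample training loop without the golden-data regression checks.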
@@ -168,6 +168,12 @@ run_pipe_benchmark: &run_pipe_benchmark
command: |
python benchmarks/pipe.py
run_mp_pipe_benchmark: &run_mp_pipe_benchmark
- run:
name: Run Multiprocess Pipe Benchmark
command: |
python benchmarks/pipe.py --multiprocess
run_oss_benchmark: &run_oss_benchmark
- run:
name: Run OSS Benchmark
@@ -444,6 +450,8 @@ jobs:
- <<: *run_pipe_benchmark
- <<: *run_mp_pipe_benchmark
- <<: *run_oss_benchmark
- <<: *run_oss_gloo
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import io
import tempfile
import torch
from torch.utils.data import DataLoader
@@ -28,7 +29,8 @@ def get_real_dataloaders(args, benchmark_config):
"""Return real dataloaders for training, testing and validation."""
url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root="/tmp"))
tmpdir = tempfile.TemporaryDirectory()
test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=tmpdir.name))
tokenizer = get_tokenizer("basic_english")
def data_process(raw_text_iter):
......
@@ -20,17 +20,24 @@ def get_benchmark_config():
"scaler": GradScaler(),
"clip_value": 0.05,
"batch_size": 8,
"num_decoder_layers": 10,
"seq_len": 32,
}
def get_golden_real_stats():
return {
"avg_wps": 703.778,
"std_dev_wps": 5.732,
"peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
}
def get_golden_real_stats(multiprocess=False):
if not multiprocess:
return {
"avg_wps": 703.778,
"std_dev_wps": 5.732,
"peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
}
else:
return {
"avg_wps": 647.404,
"std_dev_wps": 14.51,
"peak_mem_usage": [3305007616, 2578692608, 3304524288, 2578692608],
}
def get_golden_synthetic_stats():
......
@@ -7,7 +7,6 @@ import gc
import logging
import math
import operator
import os
import pprint
import time
@@ -17,6 +16,7 @@ from golden_configs import lm_wikitext2
from models import transformer_lm
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -29,6 +29,9 @@ from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.optim.oss import OSS
from fairscale.utils.testing import dist_init, get_worker_map
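# Fixed localhost TCP ports: MPI_PORT is used for the torch.distributed process group
# init_method and RPC_PORT for the torch.distributed.rpc init_method.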
MPI_PORT = 29500
RPC_PORT = 29501
def init_random_seed(seed: int):
@@ -64,7 +67,7 @@ def get_lm_model(args, device, config):
dropout = config["dropout"]
vocab_size = config["vocab_size"]
nhid = config["nhid"]
ndecoder = args.num_decoder_layers
ndecoder = config["num_decoder_layers"]
if args.lazy_construction:
layers = [
@@ -179,13 +182,13 @@ def get_device(model, index):
if not torch.cuda.is_available():
return torch.device("cpu")
if model.devices:
if hasattr(model, "devices"):
return model.devices[index]
else:
return torch.cuda.current_device()
def get_fake_dataloader(lm_dataloader_len):
def get_fake_dataloader(lm_dataloader_len, args):
fake_input = {"input": torch.zeros(args.batch_size)}
class FakeDataset:
@@ -224,7 +227,7 @@ def train(model_config, model, benchmark_config, args):
# TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
lm_dataloader = get_fake_dataloader(len(lm_dataloader))
lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config)
total_tokens = 0
total_tokens_per_log_interval = 0
@@ -288,11 +291,12 @@ def train(model_config, model, benchmark_config, args):
if i % log_interval == 0 and i > 0:
cur_loss = total_loss / log_interval
elapsed = time.time() - start_time
print(
"| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
print(
"| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
)
)
)
total_tokens_per_log_interval = 0
total_loss = 0
start_time = time.time()
@@ -303,8 +307,10 @@ def train(model_config, model, benchmark_config, args):
raise RuntimeError(
"Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark."
)
return wps, loss.item()
if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
return wps, loss.item()
else:
return 0.0, 0.0
# TODO(anj-s): Add an option for users to be able to benchmark evaluate.
@@ -334,52 +340,64 @@ def get_number_of_words(data):
return data.size()[0] * data.size()[1]
def verify_lm_run(wps, golden_config):
"""Verify that words per second for a given benchmark run matches the golden data."""
# Assert that words per second is within 3 standard deviations of the average
# of five golden runs
print("Throughput(wps) is {:.2f}.".format(wps))
if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
def verify_peak_memory(rank, golden_config, std_dev):
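# Note: std_dev is used as a tolerance multiplier rather than a standard deviation; the check
# below requires peak usage < golden_ref * std_dev (1.1 for single-process, 1.5 for multiprocess runs).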
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]))
current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
golden_ref = golden_config["peak_mem_usage"][rank]
if not current_device_usage < golden_ref * std_dev:
raise RuntimeError(
"Throughput(wps):{:.2f} is below the golden threshold of an "
"average value of {:.2f} and standard dev of {:.2f}.".format(
wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
)
"Peak memory usage for cuda device {:d} is {:d} which"
"is less than golden reference value of {:d}".format(rank, current_device_usage, golden_ref)
)
for i in range(4):
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(i)["allocated_bytes.all.peak"]))
# Assert that memory usage on each GPU is within 10% of golden run
# Right-hand-side is golden run bytes * 110%
for i, golden_ref in zip(range(4), golden_config["peak_mem_usage"]):
current_device_usage = torch.cuda.memory_stats(i)["allocated_bytes.all.peak"]
if not current_device_usage < golden_ref * 1.1:
def verify_lm_run(wps, golden_config, args):
"""Verify that words per second for a given benchmark run matches the golden data."""
# Verify wps only on the last rank in multiprocess pipe
if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
# Assert that words per second is within 3 standard deviations of the average
# of five golden runs
print("Throughput(wps) is {:.2f}.".format(wps))
if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
raise RuntimeError(
"Peak memory usage for cuda device {:d} is {:d} which"
"is less than golden reference value of {:d}".format(i, current_device_usage, golden_ref)
"Throughput(wps):{:.2f} is below the golden threshold of an "
"average value of {:.2f} and standard dev of {:.2f}.".format(
wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
)
)
if args.multiprocess:
verify_peak_memory(dist.get_rank(), golden_config, 1.5)
else:
for i in range(4):
verify_peak_memory(i, golden_config, 1.1)
def benchmark_language_model(model_config, model, benchmark_config, args):
golden_config = get_golden_config(args.model_name)
golden_config = get_golden_config(args.model_name, args)
epoch = benchmark_config["epochs"]
print("-" * 110)
print("| start of epoch {:1d}".format(epoch))
print("-" * 110)
start_time = time.time()
if dist.get_rank() == dist.get_world_size() - 1:
print("-" * 110)
print("| start of epoch {:1d}".format(epoch))
print("-" * 110)
wps, loss = train(model_config, model, benchmark_config, args)
elapsed_time = time.time() - start_time
print("-" * 110)
print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
print("-" * 110)
if dist.get_rank() == dist.get_world_size() - 1:
print("-" * 110)
print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
print("-" * 110)
print("Throughput(wps) is {:.2f}.".format(wps))
print(
"Peak allocated bytes on cuda:{}: {:1d}".format(
dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
)
)
print("wps ", wps)
if len(model.balance) == 4:
if args.model_name == "lm":
verify_lm_run(wps, golden_config)
verify_lm_run(wps, golden_config, args)
else:
raise RuntimeError("Unrecognized args.model_name " % args.model_name)
@@ -458,11 +476,11 @@ def create_benchmark_config(model_name):
raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
def get_golden_config(model_name):
def get_golden_config(model_name, args):
"""Return a dict with the golden data for throughput and memory usage."""
if model_name == "lm":
return lm_wikitext2.get_golden_real_stats()
return lm_wikitext2.get_golden_real_stats(args.multiprocess)
else:
raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
@@ -470,6 +488,9 @@ def get_golden_config(model_name):
def benchmark_single_process(args):
"""Benchmark a given model using a single process and multiple devices."""
init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup)
num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
assert num_devices > 0
init_random_seed(0)
@@ -492,10 +513,10 @@ def benchmark_single_process(args):
def run_mp_worker(args, available_workers):
benchmark_config = create_benchmark_config(args.model_name)
model_config = create_model_config(args, config=benchmark_config)
model_config = create_model_config(args, benchmark_config=benchmark_config)
model = model_config["model"]
balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
pipe_model = MultiProcessPipe(
model,
balance,
@@ -512,7 +533,7 @@ def run_mp_worker(args, available_workers):
print(f"running all at once")
pipe_model.pipeline.all_at_once = True
if args.use_synthetic_data:
if args.dry_run:
train(model_config, pipe_model, benchmark_config, args)
else:
benchmark_language_model(model_config, pipe_model, benchmark_config, args)
@@ -530,63 +551,27 @@ def run_worker(rank, world_size, args):
torch.distributed.destroy_process_group()
def bench_multi_process(args, all_at_once=False):
if args.local_world_size != 0:
world_size = args.local_world_size
else:
world_size = min(torch.cuda.device_count(), 2)
mp.spawn(run_worker, args=(world_size, args), nprocs=world_size, join=True)
best_device_map = {
0: "mlx5_0:1",
1: "mlx5_0:1",
2: "mlx5_1:1",
3: "mlx5_1:1",
4: "mlx5_2:1",
5: "mlx5_2:1",
6: "mlx5_3:1",
7: "mlx5_3:1",
}
def bench_mpi(args):
guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank]
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10638"
if args.socket_name:
os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
os.environ["TP_SOCKET_IFNAME"] = args.socket_name
torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size)
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(local_rank % torch.cuda.device_count())
def benchmark_multiprocess(rank, world_size, args):
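# Spawned once per CUDA device by mp.spawn below; rank selects the device, and only the
# last rank reports and verifies throughput.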
init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
# TODO(anj-s): Add regression benchmarks for nccl as well.
torch.distributed.init_process_group(
backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup
)
torch.cuda.set_device(rank % torch.cuda.device_count())
# TODO(anj-s): Move to TensorPipeRpcBackendOptions.
rpc.init_rpc(
f"Test{rank}",
rank=rank,
world_size=world_size,
backend=rpc.BackendType.PROCESS_GROUP,
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT)
),
)
backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"}
if args.ddp_zero:
initialize_model_parallel(1, 4, **backends)
else:
initialize_model_parallel(1, world_size, **backends)
initialize_model_parallel(1, world_size)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
@@ -594,17 +579,12 @@ def bench_mpi(args):
parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--local-world-size", "-l", type=int, default=0, help="local world size")
parser.add_argument("--world-size", "-w", type=int, default=0, help="world size")
parser.add_argument("--rank-base", "-r", type=int, help="rank base", default=0)
parser.add_argument("--multiprocess", action="store_true", help="Runs single process benchmarks.")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--no-mpi", action="store_true", default=False, help="disable mpi")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
@@ -612,12 +592,7 @@ parser.add_argument(
parser.add_argument(
"--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument(
"--pipelined-backward", dest="pipelined_backward", action="store_true", help="Pipelined backward pass"
)
parser.add_argument(
"--no-pipelined-backward", dest="pipelined_backward", action="store_false", help="Pipelined backward pass"
)
parser.add_argument("--pipelined-backward", action="store_true", help="Pipelined backward pass")
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
@@ -626,15 +601,16 @@ parser.add_argument(
default="lm",
help="Language Model(LM) used to benchmark nn.pipe.",
)
parser.set_defaults(pipelined_backward=True)
if __name__ == "__main__":
args = parser.parse_args()
# TODO(anj-s): Add support for multiprocess benchmarking.
if args.no_mpi or "OMPI_COMM_WORLD_RANK" not in os.environ:
print(f"Running benchmark with args: {args}")
# TODO(anj-s): Remove print statements and introduce logging levels.
if not args.multiprocess:
print(f"Running single process benchmark with args: {args}")
benchmark_single_process(args)
else:
if os.environ["OMPI_COMM_WORLD_RANK"] == "0":
print(f"Running benchmark with args: {args}")
bench_mpi(args)
world_size = max(torch.cuda.device_count(), 1)
print(f"Running multiprocess benchmark with args: {args}")
mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True)