Unverified Commit 656fc319 authored by anj-s, committed by GitHub

[refactor] Modify train and benchmark functions to account for multiple models and datasets. (#260)



* [refactor]Remove unused variables and refactor common configurations

* move helper function to call site

* fixed lint errors

* fix lint errors

* fix lint errors

* fix lint errors

* fix import order

* format files

* remove unused imports

* fix lint errors

* fix lint errors

* refactor common utilities

* address PR comments

* sorted imports

* add space

* modify comment

* added doc strings and addressed PR comments.

* addressed PR comments

* added another comment to clarify.

* fixing lint errors

* addressed PR comments

* addressed PR comments

* fixed typos

* initialize var

* rename seq_pred to lm

* fix lint errors
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent a21f50f9
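At a high level, the refactor routes all setup through per-model config helpers. A minimal sketch of how the new entry points compose, using only names that appear in the diff below (args, num_devices, and pipe come from the surrounding benchmark script; argument lists are abbreviated, so this is not the verbatim code):

benchmark_config = create_benchmark_config(args.model_name)        # per-model hyperparameters
model_config = create_model_config(args, config=benchmark_config)  # returns {"model", "optimizer", "data"}
model = model_config["model"]
balance = generate_balance(min(num_devices, 4), len(model))        # num_devices: GPUs visible to the script
pipe_model = pipe.Pipe(model, balance, chunks=args.chunks)
if args.use_synthetic_data:
    train(model_config, pipe_model, benchmark_config, args)
else:
    benchmark_language_model(model_config, pipe_model, benchmark_config, args)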
@@ -14,7 +14,7 @@ import time
from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm
import datasets
import models
import numpy
import numpy as np
import torch
from torch.distributed import rpc
import torch.multiprocessing as mp
@@ -44,17 +44,36 @@ def init_random_seed(seed: int):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
numpy.random.seed(seed)
np.random.seed(seed)
def make_model(args, device, config):
def get_model_and_optimizer(args, device, config):
"""Return instantiated model and optimizer function."""
if args.model_name == "lm":
model = get_lm_model(args, device, config)
lr = config["lr"]
def make_adam(params):
if args.ddp_zero:
return OSS(params=params, optim=Adam, group=get_data_parallel_group(), lr=lr)
else:
return Adam(params, lr=lr)
optimizer = make_adam
return model, optimizer
def get_lm_model(args, device, config):
"""Get language model(based on GPT-2) used for sequence prediction."""
ninp = config["ninp"]
nhead = config["nhead"]
initrange = config["initrange"]
dropout = config["dropout"]
vocab_size = config["vocab_size"]
nhid = config["nhid"]
lr = config["lr"]
ndecoder = args.num_decoder_layers
if args.lazy_construction:
@@ -70,14 +89,7 @@ def make_model(args, device, config):
else:
model = models.TransformerLMSequntial(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
def make_adam(params):
if args.ddp_zero:
return OSS(params=params, optim=Adam, group=get_data_parallel_group(), lr=lr)
else:
return Adam(params, lr=lr)
optimizer = make_adam
return model, optimizer
return model
def get_tensors_by_size_bucket():
@@ -225,9 +237,13 @@ def train(data_config, model, benchmark_config, args):
if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
lm_dataloader = get_fake_dataloader(len(lm_dataloader))
total_tokens = 0
total_tokens_per_log_interval = 0
for i, batch in enumerate(lm_dataloader):
if args.max_batch and i > args.max_batch:
break
total_tokens += batch["input"].numel()
optimizer.zero_grad()
try:
if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
@@ -263,23 +279,28 @@ def train(data_config, model, benchmark_config, args):
if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
total_loss += loss.item()
log_interval = 1
word_counter += batch["ntokens"]
total_tokens_per_log_interval += batch["input"].numel()
if i % log_interval == 0 and i > 0:
cur_loss = total_loss / log_interval
elapsed = time.time() - start_time
print(
"| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
i, word_counter / elapsed, cur_loss, math.exp(cur_loss)
i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
)
)
word_counter = 0
total_tokens_per_log_interval = 0
total_loss = 0
start_time = time.time()
return total_tokens, loss.item()
def evaluate(eval_model, data_source, criterion, bptt, ntokens):
# TODO(anj-s): Add an option for users to be able to benchmark evaluation.
def evaluate(eval_model, data_source, criterion, ntokens):
eval_model.eval()
total_loss = 0.0
# TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
bptt = 35
def get_batch(source, i, bptt):
seq_len = min(bptt, len(source) - 1 - i)
@@ -301,58 +322,45 @@ def get_number_of_words(data):
return data.size()[0] * data.size()[1]
def verify_lm_run(wps):
"""Verify that words per second for a given benchmark run matches the golden data."""
# Assert that words per second is within 3 standard deviations of the average
# of six golden runs
assert wps > 36954.4 - (3 * 116.825)
for i in range(4):
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(i)["allocated_bytes.all.peak"]))
# Assert that memory usage on each GPU is within 10% of golden run
# Right-hand-side is golden run bytes * 110%
for i, golden_ref in zip(range(4), [4061909504, 4050944, 10427392, 2031824896]):
assert torch.cuda.memory_stats(i)["allocated_bytes.all.peak"] < golden_ref * 1.1
def benchmark_language_model(model_config, model, benchmark_config, args):
ntokens, train_data, val_data, test_data = model_config["data"]
optimizer = model_config["optimizer"]
criterion = benchmark_config["criterion"]
epoch = 1
bptt = 35
start_time = time.time()
print("-" * 110)
print("| start of epoch {:1d}".format(epoch))
print("-" * 110)
epoch_start_time = time.time()
train(train_data, model, criterion, optimizer, bptt, ntokens, args)
val_loss = 1 # evaluate(model, val_data, criterion, bptt, ntokens)
print("-" * 89)
print(
"| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
epoch, (time.time() - epoch_start_time), val_loss
)
)
print("-" * 110)
start_time = time.time()
n_words, loss = train(data_config, model, benchmark_config, args)
elapsed_time = time.time() - start_time
nwords = get_number_of_words(train_data) + get_number_of_words(val_data)
wps = nwords / elapsed_time
test_loss = 1 # evaluate(model, test_data, criterion, bptt, ntokens)
print("=" * 89)
print(
"| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
test_loss, elapsed_time, nwords, wps
)
)
print("=" * 110)
print("-" * 110)
print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
print("-" * 110)
if can_benchmark and len(model.balance) == 4:
# Assert that words per second is within 3 standard deviations of the average
# of six golden runs
assert wps > 36954.4 - (3 * 116.825)
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
print("Peak allocated bytes on cuda:1: {:1d}".format(torch.cuda.memory_stats(1)["allocated_bytes.all.peak"]))
print("Peak allocated bytes on cuda:2: {:1d}".format(torch.cuda.memory_stats(2)["allocated_bytes.all.peak"]))
print("Peak allocated bytes on cuda:3: {:1d}".format(torch.cuda.memory_stats(3)["allocated_bytes.all.peak"]))
# Assert that memory usage on each GPU is within 10% of golden run
# Right-hand-side is golden run bytes * 110%
assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 4061909504 * 1.1
assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 4050944 * 1.1
assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 10427392 * 1.1
assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 2031824896 * 1.1
print("No regression detected")
if args.model_name == "lm":
verify_lm_run(wps)
else:
raise RuntimeError("Unrecognized args.model_name " % args.model_name)
def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
@@ -380,22 +388,42 @@ def generate_balance(num_devices, num_layers):
return balance
def make_model_and_data(args, config=None):
"""Return a dict with the given model, dataset and optimizer."""
def get_synthetic_dataloader(args):
"""Returns dataloader for synthetic data."""
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if args.use_synthetic_data:
model, optimizer = make_model(args, device, config)
if args.model_name == "lm":
lm_dataset = BenchmarkLMDataset()
lm_dataloader = DataLoader(
lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
)
return {"model": model, "optimizer": optimizer, "data": lm_dataloader}
return lm_dataloader
else:
raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
def get_real_dataloaders(device, config):
"""Returns dataloaders for real data."""
if args.model_name == "lm":
data = datasets.get_wikitext2_data(device)
ntokens, _, _, _ = data
config["vocab_size"] = ntokens
model, optimizer = make_model(args, device, ntokens)
return data
else:
raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
def create_model_config(args, config=None):
"""Return a dict with the given model, dataset and optimizer."""
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if args.use_synthetic_data:
model, optimizer = get_model_and_optimizer(args, device, config)
dataloader = get_synthetic_dataloader(args)
return {"model": model, "optimizer": optimizer, "data": dataloader}
else:
data = get_real_dataloaders(device, config)
model, optimizer = get_model_and_optimizer(args, device, config)
return {
"model": model,
"optimizer": optimizer,
@@ -406,7 +434,7 @@ def make_model_and_data(args, config=None):
def create_benchmark_config(model_name):
"""Return a dict with configurations required for benchmarking `model_name` model."""
if model_name == "seq_pred":
if model_name == "lm":
return {
"vocab_size": 10000,
"ninp": 2048, # embedding dimension
@@ -419,6 +447,8 @@ def create_benchmark_config(model_name):
"scaler": GradScaler(),
"clip_value": 0.05,
}
else:
raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
def benchmark_single_process(args):
@@ -429,30 +459,30 @@ def benchmark_single_process(args):
init_random_seed(0)
benchmark_config = create_benchmark_config(args.model_name)
model_config = make_model_and_data(args, config=benchmark_config)
model_config = create_model_config(args, config=benchmark_config)
model = model_config["model"]
balance = generate_balance(min(num_devices, 4), len(model))
p = pipe.Pipe(
pipe_model = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
del model
del model_config["model"]
if args.use_synthetic_data:
train(model_config, p, benchmark_config, args)
train(model_config, pipe_model, benchmark_config, args)
else:
benchmark_language_model(model_config, p, benchmark_config, args)
benchmark_language_model(model_config, pipe_model, benchmark_config, args)
def run_mp_worker(args, available_workers):
benchmark_config = create_benchmark_config(args.model_name)
model_config = make_model_and_data(args, config=benchmark_config)
model_config = create_model_config(args, config=benchmark_config)
model = model_config["model"]
balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
p = pipe.Pipe(
pipe_model = pipe.Pipe(
model,
balance,
style=Pipe.AsyncSchedule,
@@ -464,15 +494,15 @@ def run_mp_worker(args, available_workers):
# TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
)
if torch.cuda.is_available():
p = p.cuda()
if args.all_at_once and p.pipeline:
pipe_model = pipe_model.cuda()
if args.all_at_once and pipe_model.pipeline:
print(f"running all at once")
p.pipeline.all_at_once = True
pipe_model.pipeline.all_at_once = True
if args.use_synthetic_data:
train(model_config, p, benchmark_config, args)
train(model_config, pipe_model, benchmark_config, args)
else:
benchmark_language_model(model_config, p, benchmark_config, args)
benchmark_language_model(model_config, pipe_model, benchmark_config, args)
def run_worker(rank, world_size, args):
@@ -577,7 +607,10 @@ parser.add_argument(
)
parser.add_argument("--use_synthetic_data", default=True, help="Uses synthetic data for a sample training run.")
parser.add_argument(
"--model_name", default="seq_pred", choices=["seq_pred", "transformer"], help="Model used to benchmark pipe."
# TODO(anj-s): More models are being added, hence the need for this flag.
"--model_name",
default="lm",
help="Language Model(LM) used to benchmark nn.pipe.",
)
parser.set_defaults(pipelined_backward=True)
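For reference, a hypothetical single-process invocation after this change; the script name is assumed, and only the --model_name flag shown above is passed, so the default use_synthetic_data=True path exercises train() rather than benchmark_language_model():

python pipe.py --model_name lm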