Unverified commit b89365e6, authored by anj-s, committed by GitHub

split benchmark configs (#420)

parent 8876553e
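
This commit splits the single benchmark config dict in two: get_model_config() carries architecture settings (vocab_size, ninp, nhead, seq_len, ...) while get_benchmark_config() carries training hyperparameters (epochs, lr, batch_size, criterion), and the new model_specs dict is threaded through the dataloader, model, and training helpers. A minimal sketch of the resulting call flow, using the helper names from the diff below (the driver itself is hypothetical):

    # Hypothetical driver; create_benchmark_config, get_model_specs and
    # create_model_config are the helpers introduced or updated in this commit.
    benchmark_config = create_benchmark_config(args.model_name)  # epochs, lr, batch_size, criterion
    model_specs = get_model_specs(args.model_name)               # vocab_size, ninp, nhead, seq_len, ...
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    train(model_config, model_config["model"], benchmark_config, model_specs, args)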
@@ -21,11 +21,11 @@ def _batchify(data, batch_size):
     return data


-def _get_total_batch_size(benchmark_config):
-    return benchmark_config["seq_len"] * benchmark_config["batch_size"]
+def _get_total_batch_size(benchmark_config, model_specs):
+    return model_specs["seq_len"] * benchmark_config["batch_size"]


-def get_real_dataloaders(args, benchmark_config):
+def get_real_dataloaders(args, benchmark_config, model_specs):
     """Return real dataloaders for training, testing and validation."""
     url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
......@@ -47,21 +47,21 @@ def get_real_dataloaders(args, benchmark_config):
batch_size = args.batch_size
return _batchify(data, batch_size)
total_batch_size = _get_total_batch_size(benchmark_config)
total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
train_dataloader = DataLoader(train_dataset, batch_size=total_batch_size, collate_fn=batchify)
valid_dataloader = DataLoader(valid_dataset, batch_size=total_batch_size, collate_fn=batchify)
test_dataloader = DataLoader(test_dataset, batch_size=total_batch_size, collate_fn=batchify)
return len(vocab.stoi), train_dataloader, valid_dataloader, test_dataloader
def get_synthetic_dataloaders(args, benchmark_config):
def get_synthetic_dataloaders(args, benchmark_config, model_specs):
"""Return synthetic dataloaders for training, testing and validation."""
def batchify(data):
batch_size = args.batch_size
return _batchify(data, batch_size)
total_batch_size = total_batch_size = _get_total_batch_size(benchmark_config)
total_batch_size = total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
# vocab_size is 10000 and length of the real data is 2049990.
lm_dataset = torch.randint(1, 10000, (2049990,))
......
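
The dataloader change above makes the tokens-per-batch calculation explicit: _get_total_batch_size now reads seq_len from model_specs and batch_size from benchmark_config, so each DataLoader batch carries seq_len * batch_size token ids that a collate function can fold into (batch_size, seq_len) inputs. A runnable illustration of that convention (the fold helper is hypothetical, not the benchmark's _batchify):

    import torch
    from torch.utils.data import DataLoader

    seq_len, batch_size = 32, 8                   # values from the lm config
    total_batch_size = seq_len * batch_size       # 256 token ids per fetched batch

    tokens = torch.randint(1, 10000, (2049990,))  # synthetic corpus, as in the diff

    def fold(samples):
        flat = torch.stack(samples)               # list of scalar tensors -> 1-D tensor
        usable = (flat.numel() // seq_len) * seq_len
        return flat[:usable].view(-1, seq_len)    # (batch_size, seq_len) when the batch is full

    loader = DataLoader(tokens, batch_size=total_batch_size, collate_fn=fold)
    print(next(iter(loader)).shape)               # torch.Size([8, 32])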
@@ -5,26 +5,31 @@ import torch.nn as nn
 from fairscale.optim import GradScaler


-def get_benchmark_config():
+def get_model_config():
     return {
-        "epochs": 1,
         "vocab_size": 10000,
         "ninp": 2048,  # embedding dimension
         "nhid": 2048,  # the dimension of the feedforward network model in nn.TransformerEncoder
         "nhead": 32,  # the number of heads in the multihead attention models
         "dropout": 0,
         "initrange": 0.1,
-        "criterion": nn.CrossEntropyLoss(),
-        "lr": 0.001,  # learning rate
         "scaler": GradScaler(),
         "clip_value": 0.05,
-        "batch_size": 8,
         "num_decoder_layers": 10,
         "seq_len": 32,
     }
+
+
+def get_benchmark_config():
+    return {
+        "epochs": 1,
+        "lr": 0.001,  # learning rate
+        "batch_size": 8,
+        "criterion": nn.CrossEntropyLoss(),
+    }


 def get_golden_real_stats(multiprocess=False):
     if not multiprocess:
         return {
......
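
After the split, model-shape keys live in get_model_config() while optimizer- and loop-level keys live in get_benchmark_config(). A hedged sketch of how the two dicts divide responsibilities when building a model and optimizer (the builder below is illustrative, not fairscale's get_lm_model):

    import torch
    import torch.nn as nn

    model_cfg = {"vocab_size": 10000, "ninp": 2048, "nhid": 2048, "nhead": 32,
                 "dropout": 0, "num_decoder_layers": 10}          # mirrors get_model_config()
    bench_cfg = {"lr": 0.001, "criterion": nn.CrossEntropyLoss()}  # mirrors get_benchmark_config()

    embed = nn.Embedding(model_cfg["vocab_size"], model_cfg["ninp"])
    layer = nn.TransformerEncoderLayer(d_model=model_cfg["ninp"], nhead=model_cfg["nhead"],
                                       dim_feedforward=model_cfg["nhid"], dropout=model_cfg["dropout"])
    encoder = nn.TransformerEncoder(layer, num_layers=model_cfg["num_decoder_layers"])

    # only the training hyperparameters come from the benchmark dict
    params = list(embed.parameters()) + list(encoder.parameters())
    optimizer = torch.optim.Adam(params, lr=bench_cfg["lr"])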
@@ -40,13 +40,13 @@ def init_random_seed(seed: int):
     np.random.seed(seed)


-def get_model_and_optimizer(args, device, config):
+def get_model_and_optimizer(args, device, benchmark_config, model_config):
     """Return instantiated model and optimizer function."""

     if args.model_name == "lm":
-        model = get_lm_model(args, device, config)
+        model = get_lm_model(args, device, model_config)

-    lr = config["lr"]
+    lr = benchmark_config["lr"]

     def make_adam(params):
         if args.ddp_zero:
@@ -201,10 +201,10 @@ def get_fake_dataloader(lm_dataloader_len, args):
     return FakeDataset()


-def train(model_config, model, benchmark_config, args):
+def train(model_config, model, benchmark_config, model_specs, args):
     lm_dataloader, _, _ = model_config["data"]
     criterion = benchmark_config["criterion"]
-    vocab_size = benchmark_config["vocab_size"]
+    vocab_size = model_specs["vocab_size"]
     optimizer = model_config["optimizer"]

     model.train()
@@ -227,7 +227,7 @@ def train(model_config, model, benchmark_config, args):
     # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
     if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
-        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config)
+        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config, model_specs)

     total_tokens = 0
     total_tokens_per_log_interval = 0
@@ -281,7 +281,7 @@ def train(model_config, model, benchmark_config, args):
         del output

-        torch.nn.utils.clip_grad_value_(model.parameters(), benchmark_config["clip_value"])
+        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
         optimizer.step()

         if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
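
For reference, torch.nn.utils.clip_grad_value_ clamps every gradient element into [-clip_value, clip_value] in place; a tiny illustration with the clip_value of 0.05 that is now read from model_specs:

    import torch

    w = torch.nn.Parameter(torch.zeros(3))
    w.grad = torch.tensor([-1.0, 0.02, 0.5])
    torch.nn.utils.clip_grad_value_([w], 0.05)  # clamp each element to [-0.05, 0.05]
    print(w.grad)  # tensor([-0.0500, 0.0200, 0.0500])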
@@ -374,7 +374,7 @@ def verify_lm_run(wps, golden_config, args):
        verify_peak_memory(i, golden_config, 1.1)


-def benchmark_language_model(model_config, model, benchmark_config, args):
+def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
     golden_config = get_golden_config(args.model_name, args)
     epoch = benchmark_config["epochs"]
     start_time = time.time()
@@ -382,7 +382,7 @@ def benchmark_language_model(model_config, model, benchmark_config, args):
         print("-" * 110)
         print("| start of epoch {:1d}".format(epoch))
         print("-" * 110)
-    wps, loss = train(model_config, model, benchmark_config, args)
+    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
     elapsed_time = time.time() - start_time
     if dist.get_rank() == dist.get_world_size() - 1:
         print("-" * 110)
@@ -427,28 +427,28 @@ def generate_balance(num_devices, num_layers):
     return balance


-def get_synthetic_dataloaders(args, benchmark_config):
+def get_synthetic_dataloaders(args, benchmark_config, model_specs):
     """Returns dataloader for synthetic data."""

     if args.model_name == "lm":
-        return get_synthetic_wikitext2_dataloaders(args, benchmark_config)
+        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
     else:
         raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


-def get_real_dataloaders(args, device, benchmark_config):
+def get_real_dataloaders(args, device, benchmark_config, model_specs):
     """Returns dataloaders for real data."""

     if args.model_name == "lm":
-        data = get_real_wikitext2_dataloaders(args, benchmark_config)
+        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
         ntokens, train_dataloader, valid_dataloader, test_dataloader = data
-        benchmark_config["vocab_size"] = ntokens
+        model_specs["vocab_size"] = ntokens
         return train_dataloader, valid_dataloader, test_dataloader
     else:
         raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


-def create_model_config(args, benchmark_config=None):
+def create_model_config(args, benchmark_config=None, model_specs=None):
     """Return a dict with the given model, dataset and optimizer."""

     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -458,8 +458,8 @@ def create_model_config(args, benchmark_config=None):
     else:
         dataloader_fn = get_real_dataloaders

-    data = dataloader_fn(args, device, benchmark_config)
-    model, optimizer = get_model_and_optimizer(args, device, benchmark_config)
+    data = dataloader_fn(args, device, benchmark_config, model_specs)
+    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
     return {
         "model": model,
         "optimizer": optimizer,
@@ -476,6 +476,15 @@ def create_benchmark_config(model_name):
         raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


+def get_model_specs(model_name):
+    """Return a dict with configurations required for configuring `model_name` model."""
+    if model_name == "lm":
+        return lm_wikitext2.get_model_config()
+    else:
+        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)
+
+
 def get_golden_config(model_name, args):
     """Return a dict with the golden data for throughput and memory usage."""
@@ -496,7 +505,8 @@ def benchmark_single_process(args):
     init_random_seed(0)
     benchmark_config = create_benchmark_config(args.model_name)
-    model_config = create_model_config(args, benchmark_config=benchmark_config)
+    model_specs = get_model_specs(args.model_name)
+    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
     model = model_config["model"]

     balance = generate_balance(min(num_devices, 4), len(model))
@@ -505,15 +515,16 @@ def benchmark_single_process(args):
     del model_config["model"]

     if args.dry_run:
-        train(model_config, pipe_model, benchmark_config, args)
+        train(model_config, pipe_model, benchmark_config, model_specs, args)
     else:
-        benchmark_language_model(model_config, pipe_model, benchmark_config, args)
+        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


 def run_mp_worker(args, available_workers):
     benchmark_config = create_benchmark_config(args.model_name)
-    model_config = create_model_config(args, benchmark_config=benchmark_config)
+    model_specs = get_model_specs(args.model_name)
+    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
     model = model_config["model"]

     balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
@@ -533,9 +544,9 @@ def run_mp_worker(args, available_workers):
     pipe_model.pipeline.all_at_once = True

     if args.dry_run:
-        train(model_config, pipe_model, benchmark_config, args)
+        train(model_config, pipe_model, benchmark_config, model_specs, args)
     else:
-        benchmark_language_model(model_config, pipe_model, benchmark_config, args)
+        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


 def run_worker(rank, world_size, args):
......
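
Both drivers partition the model with generate_balance(num_devices, num_layers) before building the pipelined model. The diff does not show its body; a plausible sketch, assuming it splits layers as evenly as possible with the remainder going to the earliest stages (not necessarily fairscale's exact implementation):

    def generate_balance(num_devices: int, num_layers: int) -> list:
        base, extra = divmod(num_layers, num_devices)
        # the first `extra` stages take one additional layer each
        return [base + 1 if i < extra else base for i in range(num_devices)]

    print(generate_balance(4, 10))  # [3, 3, 2, 2]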