Unverified commit b89365e6 authored by anj-s, committed by GitHub

split benchmark configs (#420)

parent 8876553e
@@ -21,11 +21,11 @@ def _batchify(data, batch_size):
     return data


-def _get_total_batch_size(benchmark_config):
-    return benchmark_config["seq_len"] * benchmark_config["batch_size"]
+def _get_total_batch_size(benchmark_config, model_specs):
+    return model_specs["seq_len"] * benchmark_config["batch_size"]


-def get_real_dataloaders(args, benchmark_config):
+def get_real_dataloaders(args, benchmark_config, model_specs):
     """Return real dataloaders for training, testing and validation."""
     url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
@@ -47,21 +47,21 @@ def get_real_dataloaders(args, benchmark_config):
         batch_size = args.batch_size
         return _batchify(data, batch_size)

-    total_batch_size = _get_total_batch_size(benchmark_config)
+    total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
     train_dataloader = DataLoader(train_dataset, batch_size=total_batch_size, collate_fn=batchify)
     valid_dataloader = DataLoader(valid_dataset, batch_size=total_batch_size, collate_fn=batchify)
     test_dataloader = DataLoader(test_dataset, batch_size=total_batch_size, collate_fn=batchify)
     return len(vocab.stoi), train_dataloader, valid_dataloader, test_dataloader


-def get_synthetic_dataloaders(args, benchmark_config):
+def get_synthetic_dataloaders(args, benchmark_config, model_specs):
     """Return synthetic dataloaders for training, testing and validation."""

     def batchify(data):
         batch_size = args.batch_size
         return _batchify(data, batch_size)
-    total_batch_size = total_batch_size = _get_total_batch_size(benchmark_config)
+    total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
     # vocab_size is 10000 and length of the real data is 2049990.
     lm_dataset = torch.randint(1, 10000, (2049990,))
...
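The point of threading `model_specs` through these helpers: `seq_len` is a property of the model while `batch_size` is a training knob, so the per-batch token count now draws one factor from each dict. A minimal self-contained sketch (dict contents mirror this commit's defaults):

```python
# Self-contained sketch; values mirror the defaults introduced in this commit.
benchmark_config = {"batch_size": 8}  # training knob
model_specs = {"seq_len": 32}         # model property

def _get_total_batch_size(benchmark_config, model_specs):
    return model_specs["seq_len"] * benchmark_config["batch_size"]

# Each DataLoader batch carries seq_len * batch_size = 256 raw tokens, which
# the (elided) _batchify collate step slices into model inputs.
assert _get_total_batch_size(benchmark_config, model_specs) == 256
```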
@@ -5,26 +5,31 @@ import torch.nn as nn
 from fairscale.optim import GradScaler


-def get_benchmark_config():
+def get_model_config():
     return {
-        "epochs": 1,
         "vocab_size": 10000,
         "ninp": 2048,  # embedding dimension
         "nhid": 2048,  # the dimension of the feedforward network model in nn.TransformerEncoder
         "nhead": 32,  # the number of heads in the multiheadattention models
         "dropout": 0,
         "initrange": 0.1,
-        "criterion": nn.CrossEntropyLoss(),
-        "lr": 0.001,  # learning rate
         "scaler": GradScaler(),
         "clip_value": 0.05,
-        "batch_size": 8,
         "num_decoder_layers": 10,
         "seq_len": 32,
     }


+def get_benchmark_config():
+    return {
+        "epochs": 1,
+        "lr": 0.001,  # learning rate
+        "batch_size": 8,
+        "criterion": nn.CrossEntropyLoss(),
+    }


 def get_golden_real_stats(multiprocess=False):
     if not multiprocess:
         return {
...
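After this split, architecture hyperparameters (`get_model_config`) and training hyperparameters (`get_benchmark_config`) travel separately. A hedged caller-side sketch (the embedding layer is illustrative only, not code from this commit):

```python
import torch.nn as nn

model_specs = get_model_config()           # vocab_size, ninp, nhid, nhead, seq_len, ...
benchmark_config = get_benchmark_config()  # epochs, lr, batch_size, criterion

# Architecture values size the model; training values drive the loop.
embedding = nn.Embedding(model_specs["vocab_size"], model_specs["ninp"])
criterion = benchmark_config["criterion"]  # nn.CrossEntropyLoss()
```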
@@ -40,13 +40,13 @@ def init_random_seed(seed: int):
     np.random.seed(seed)


-def get_model_and_optimizer(args, device, config):
+def get_model_and_optimizer(args, device, benchmark_config, model_config):
     """Return instantiated model and optimizer function."""

     if args.model_name == "lm":
-        model = get_lm_model(args, device, config)
+        model = get_lm_model(args, device, model_config)

-    lr = config["lr"]
+    lr = benchmark_config["lr"]

     def make_adam(params):
         if args.ddp_zero:
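Only the `args.ddp_zero` guard of `make_adam` survives the hunk; a plausible completion (an assumption, not the committed body) shows why `lr` now reads from `benchmark_config`:

```python
import torch

# Hypothetical completion of the truncated factory; it closes over `args`
# and the `lr` pulled from benchmark_config inside get_model_and_optimizer.
def make_adam(params):
    if args.ddp_zero:
        raise NotImplementedError("sharded-optimizer branch elided in this diff")
    return torch.optim.Adam(params, lr=lr)
```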
@@ -201,10 +201,10 @@ def get_fake_dataloader(lm_dataloader_len, args):
     return FakeDataset()


-def train(model_config, model, benchmark_config, args):
+def train(model_config, model, benchmark_config, model_specs, args):
     lm_dataloader, _, _ = model_config["data"]
     criterion = benchmark_config["criterion"]
-    vocab_size = benchmark_config["vocab_size"]
+    vocab_size = model_specs["vocab_size"]
     optimizer = model_config["optimizer"]
     model.train()
@@ -227,7 +227,7 @@ def train(model_config, model, benchmark_config, args):
     # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
     if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
-        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config)
+        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config, model_specs)

     total_tokens = 0
     total_tokens_per_log_interval = 0
@@ -281,7 +281,7 @@ def train(model_config, model, benchmark_config, args):
         del output

-        torch.nn.utils.clip_grad_value_(model.parameters(), benchmark_config["clip_value"])
+        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
         optimizer.step()

         if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
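`clip_value` also moves from `benchmark_config` to `model_specs` here. For reference, `torch.nn.utils.clip_grad_value_` clamps every gradient element into `[-clip_value, clip_value]` in place; a tiny standalone demonstration:

```python
import torch
import torch.nn as nn

model_specs = {"clip_value": 0.05}  # as in get_model_config() above

layer = nn.Linear(4, 4)
layer(torch.randn(2, 4)).sum().backward()
torch.nn.utils.clip_grad_value_(layer.parameters(), model_specs["clip_value"])

# Every gradient element is now within [-0.05, 0.05].
assert all(p.grad.abs().max() <= 0.05 for p in layer.parameters())
```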
@@ -374,7 +374,7 @@ def verify_lm_run(wps, golden_config, args):
         verify_peak_memory(i, golden_config, 1.1)


-def benchmark_language_model(model_config, model, benchmark_config, args):
+def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
     golden_config = get_golden_config(args.model_name, args)
     epoch = benchmark_config["epochs"]
     start_time = time.time()
@@ -382,7 +382,7 @@ def benchmark_language_model(model_config, model, benchmark_config, args):
         print("-" * 110)
         print("| start of epoch {:1d}".format(epoch))
         print("-" * 110)
-    wps, loss = train(model_config, model, benchmark_config, args)
+    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
     elapsed_time = time.time() - start_time
     if dist.get_rank() == dist.get_world_size() - 1:
         print("-" * 110)
@@ -427,28 +427,28 @@ def generate_balance(num_devices, num_layers):
     return balance


-def get_synthetic_dataloaders(args, benchmark_config):
+def get_synthetic_dataloaders(args, benchmark_config, model_specs):
     """Returns dataloaders for synthetic data."""
     if args.model_name == "lm":
-        return get_synthetic_wikitext2_dataloaders(args, benchmark_config)
+        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
     else:
         raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)
-def get_real_dataloaders(args, device, benchmark_config):
+def get_real_dataloaders(args, device, benchmark_config, model_specs):
     """Returns dataloaders for real data."""
     if args.model_name == "lm":
-        data = get_real_wikitext2_dataloaders(args, benchmark_config)
+        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
         ntokens, train_dataloader, valid_dataloader, test_dataloader = data
-        benchmark_config["vocab_size"] = ntokens
+        model_specs["vocab_size"] = ntokens
         return train_dataloader, valid_dataloader, test_dataloader
     else:
         raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)

-def create_model_config(args, benchmark_config=None):
+def create_model_config(args, benchmark_config=None, model_specs=None):
     """Return a dict with the given model, dataset and optimizer."""
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
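The hunk above opens with the tail of `generate_balance`, whose body the diff elides. For orientation, a minimal sketch of what such a helper plausibly computes: an even partition of layers across pipeline devices (an assumption, not the committed implementation):

```python
def generate_balance(num_devices, num_layers):
    """Hypothetical even split of num_layers across num_devices pipeline stages."""
    balance = []
    assigned = 0
    for i in range(num_devices):
        stage = (num_layers - assigned) // (num_devices - i)
        balance.append(stage)
        assigned += stage
    return balance

print(generate_balance(4, 10))  # [2, 2, 3, 3]
```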
@@ -458,8 +458,8 @@ def create_model_config(args, benchmark_config=None):
     else:
         dataloader_fn = get_real_dataloaders

-    data = dataloader_fn(args, device, benchmark_config)
-    model, optimizer = get_model_and_optimizer(args, device, benchmark_config)
+    data = dataloader_fn(args, device, benchmark_config, model_specs)
+    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
     return {
         "model": model,
         "optimizer": optimizer,
@@ -476,6 +476,15 @@ def create_benchmark_config(model_name):
         raise RuntimeError("Unrecognized model name %s" % model_name)
+
+def get_model_specs(model_name):
+    """Return a dict with the configuration required to build the `model_name` model."""
+    if model_name == "lm":
+        return lm_wikitext2.get_model_config()
+    else:
+        raise RuntimeError("Unrecognized model name %s" % model_name)


 def get_golden_config(model_name, args):
     """Return a dict with the golden data for throughput and memory usage."""
@@ -496,7 +505,8 @@ def benchmark_single_process(args):
     init_random_seed(0)

     benchmark_config = create_benchmark_config(args.model_name)
-    model_config = create_model_config(args, benchmark_config=benchmark_config)
+    model_specs = get_model_specs(args.model_name)
+    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
     model = model_config["model"]

     balance = generate_balance(min(num_devices, 4), len(model))
@@ -505,15 +515,16 @@ def benchmark_single_process(args):
     del model_config["model"]

     if args.dry_run:
-        train(model_config, pipe_model, benchmark_config, args)
+        train(model_config, pipe_model, benchmark_config, model_specs, args)
     else:
-        benchmark_language_model(model_config, pipe_model, benchmark_config, args)
+        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


 def run_mp_worker(args, available_workers):
     benchmark_config = create_benchmark_config(args.model_name)
-    model_config = create_model_config(args, benchmark_config=benchmark_config)
+    model_specs = get_model_specs(args.model_name)
+    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
     model = model_config["model"]

     balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
@@ -533,9 +544,9 @@ def run_mp_worker(args, available_workers):
         pipe_model.pipeline.all_at_once = True

     if args.dry_run:
-        train(model_config, pipe_model, benchmark_config, args)
+        train(model_config, pipe_model, benchmark_config, model_specs, args)
     else:
-        benchmark_language_model(model_config, pipe_model, benchmark_config, args)
+        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


 def run_worker(rank, world_size, args):
...
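Between `generate_balance` and `del model_config["model"]`, the diff elides how the flat model is wrapped for pipeline execution. A hedged sketch of that step using fairscale's `Pipe` (treating `args.chunks` as an assumed CLI flag):

```python
from fairscale.nn import Pipe

# Cut the flat model into stages according to `balance` and wrap it for
# pipeline parallelism; the wrapped pipe_model is what train() receives.
balance = generate_balance(min(num_devices, 4), len(model))
pipe_model = Pipe(model, balance=balance, chunks=args.chunks)
del model_config["model"]
```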