Unverified Commit 81841734 authored by anj-s's avatar anj-s Committed by GitHub
Browse files

[refactor] Add batch size to the golden benchmark configs. (#313)



* [refactor]Remove unused variables and refactor common configurations

* move helper function to call site

* fixed lint errors

* fix lint errors

* fix lint errors

* fix lint errors

* fix import order

* format files

* remove unused imports

* fix lint errors

* fix lint errors

* refactor common utilities

* address PR comments

* sorted imports

* add space

* modify comment

* added doc strings and addressed PR comments.

* addressed PR comments

* added another comment to clarify.

* fixing lint errors

* addressed PR comments

* addressed PR comments

* fixed typos

* initialize var

* rename seq_pred to lm

* fix lint errors

* move datasets and models into separate folders

* add the folders created

* fix lint errors

* create golden config to stats mapping

* add common batching for both synthetic and real data

* fixed lint errors

* enable real pipe benchmarks with new golden data

* reduce seq len to avoid OOM

* updated golden data

* add logging

* add golden data

* add golden data

* fix lint errors

* add doc string

* remove unused class

* add seq len and batch size to the config

* remove commented out line

* address comments

* rename imports

* refactor common logic in dataloaders

* add golden configs

* lint changes

* merge latest changes

* lint errors

* address PR comments
Co-authored-by: default avatarAnjali Sridhar <anj@devfair0443.h2.fair>
parent b52041d9
......@@ -20,7 +20,11 @@ def _batchify(data, batch_size):
return data
def get_real_dataloaders(args):
def _get_total_batch_size(benchmark_config):
return benchmark_config["seq_len"] * benchmark_config["batch_size"]
def get_real_dataloaders(args, benchmark_config):
"""Return real dataloaders for training, testing and validation."""
url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
......@@ -41,25 +45,21 @@ def get_real_dataloaders(args):
batch_size = args.batch_size
return _batchify(data, batch_size)
# TODO(anj-s): Both seq_len and batch size should be part of the golden config.
seq_len = 32
total_batch_size = seq_len * args.batch_size
total_batch_size = _get_total_batch_size(benchmark_config)
train_dataloader = DataLoader(train_dataset, batch_size=total_batch_size, collate_fn=batchify)
valid_dataloader = DataLoader(valid_dataset, batch_size=total_batch_size, collate_fn=batchify)
test_dataloader = DataLoader(test_dataset, batch_size=total_batch_size, collate_fn=batchify)
return len(vocab.stoi), train_dataloader, valid_dataloader, test_dataloader
def get_synthetic_dataloaders(args):
def get_synthetic_dataloaders(args, benchmark_config):
"""Return synthetic dataloaders for training, testing and validation."""
def batchify(data):
batch_size = args.batch_size
return _batchify(data, batch_size)
# TODO(anj-s): Both seq_len and batch size should be part of the golden config.
seq_len = 32
total_batch_size = seq_len * args.batch_size
total_batch_size = total_batch_size = _get_total_batch_size(benchmark_config)
# vocab_size is 10000 and length of the real data is 2049990.
lm_dataset = torch.randint(1, 10000, (2049990,))
......
......@@ -20,6 +20,7 @@ def get_benchmark_config():
"scaler": GradScaler(),
"clip_value": 0.05,
"batch_size": 8,
"seq_len": 32,
}
......
......@@ -409,43 +409,44 @@ def generate_balance(num_devices, num_layers):
return balance
def get_synthetic_dataloaders(args, benchmark_config):
    """Return synthetic dataloaders for training, validation and testing.

    Args:
        args: parsed command-line args; `args.model_name` selects the dataset.
        benchmark_config: golden benchmark config (seq_len, batch_size, ...).

    Raises:
        RuntimeError: if `args.model_name` is not a recognized model name.
    """
    if args.model_name == "lm":
        return get_synthetic_wikitext2_dataloaders(args, benchmark_config)
    else:
        # Bug fix: the original used `"Unrecognized args.model_mame " % args.model_name`,
        # a %-format with no placeholder, which raises TypeError instead of the
        # intended RuntimeError (and misspelled "model_name").
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)
def get_real_dataloaders(args, device, benchmark_config):
    """Return real-data dataloaders for training, validation and testing.

    Side effect: stores the vocabulary size in
    `benchmark_config["vocab_size"]` so the model can be built afterwards.

    Args:
        args: parsed command-line args; `args.model_name` selects the dataset.
        device: target device (kept for interface compatibility; the wikitext2
            loader does not need it).
        benchmark_config: golden benchmark config; mutated to record vocab_size.

    Raises:
        RuntimeError: if `args.model_name` is not a recognized model name.
    """
    if args.model_name == "lm":
        data = get_real_wikitext2_dataloaders(args, benchmark_config)
        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
        benchmark_config["vocab_size"] = ntokens
        return train_dataloader, valid_dataloader, test_dataloader
    else:
        # Bug fix: %-formatting without a %s placeholder raised TypeError
        # instead of RuntimeError; also fixes the "model_mame" typo.
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)
def create_model_config(args, benchmark_config=None):
    """Return a dict with the model, optimizer and dataloaders to benchmark.

    Args:
        args: parsed command-line args; `args.use_synthetic_data` selects
            between synthetic and real dataloaders.
        benchmark_config: golden benchmark config passed through to the
            dataloader and model factories (may be mutated, e.g. vocab_size).

    Returns:
        dict with keys "model", "optimizer" and "data".
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if args.use_synthetic_data:
        # Bug fix: `get_synthetic_dataloaders` takes (args, benchmark_config),
        # so the previous shared call `dataloader_fn(args, device,
        # benchmark_config)` passed an extra `device` argument and raised
        # TypeError on the synthetic path.
        data = get_synthetic_dataloaders(args, benchmark_config)
    else:
        data = get_real_dataloaders(args, device, benchmark_config)

    model, optimizer = get_model_and_optimizer(args, device, benchmark_config)
    return {
        "model": model,
        "optimizer": optimizer,
        "data": data,
    }
def create_benchmark_config(model_name):
......@@ -474,7 +475,7 @@ def benchmark_single_process(args):
init_random_seed(0)
benchmark_config = create_benchmark_config(args.model_name)
model_config = create_model_config(args, config=benchmark_config)
model_config = create_model_config(args, benchmark_config=benchmark_config)
model = model_config["model"]
balance = generate_balance(min(num_devices, 4), len(model))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment