Unverified Commit 81841734 authored by anj-s, committed by GitHub

[refactor] Add batch size to the golden benchmark configs. (#313)



* [refactor] Remove unused variables and refactor common configurations

* move helper function to call site

* fixed lint errors

* fix lint errors

* fix lint errors

* fix lint errors

* fix import order

* format files

* remove unused imports

* fix lint errors

* fix lint errors

* refactor common utilities

* address PR comments

* sorted imports

* add space

* modify comment

* added doc strings and addressed PR comments.

* addressed PR comments

* added another comment to clarify.

* fixing lint errors

* addressed PR comments

* addressed PR comments

* fixed typos

* initialize var

* rename seq_pred to lm

* fix lint errors

* move datasets and models into separate folders

* add the folders created

* fix lint errors

* create golden config to stats mapping

* add common batching for both synthetic and real data

* fixed lint errors

* enable real pipe benchmarks with new golden data

* reduce seq len to avoid OOM

* updated golden data

* add logging

* add golden data

* add golden data

* fix lint errors

* add doc string

* remove unused class

* add seq len and batch size to the config

* remove commented out line

* address comments

* rename imports

* refactor common logic in dataloaders

* add golden configs

* lint changes

* merge latest changes

* lint errors

* address PR comments
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent b52041d9
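
In short: `seq_len` was previously hard-coded to 32 at each dataloader call site; this commit promotes it, alongside `batch_size`, into the golden benchmark config, and the effective DataLoader batch size is derived from both. A minimal sketch of the resulting pattern — only the two config keys and the helper come from the diff below; the surrounding driver code is illustrative:

    # Sketch of the pattern introduced by this commit; only the two config
    # keys and _get_total_batch_size are taken from the diff, the rest is
    # illustrative.
    benchmark_config = {
        "batch_size": 8,  # sequences per step
        "seq_len": 32,    # tokens per sequence
    }

    def _get_total_batch_size(benchmark_config):
        # The dataloaders consume a flat token stream, so one batch holds
        # seq_len * batch_size tokens before being folded into sequences.
        return benchmark_config["seq_len"] * benchmark_config["batch_size"]

    print(_get_total_batch_size(benchmark_config))  # 256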
@@ -20,7 +20,11 @@ def _batchify(data, batch_size):
     return data


-def get_real_dataloaders(args):
+def _get_total_batch_size(benchmark_config):
+    return benchmark_config["seq_len"] * benchmark_config["batch_size"]
+
+
+def get_real_dataloaders(args, benchmark_config):
     """Return real dataloaders for training, testing and validation."""
     url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
@@ -41,25 +45,21 @@ def get_real_dataloaders(args):
         batch_size = args.batch_size
         return _batchify(data, batch_size)

-    # TODO(anj-s): Both seq_len and batch size should be part of the golden config.
-    seq_len = 32
-    total_batch_size = seq_len * args.batch_size
+    total_batch_size = _get_total_batch_size(benchmark_config)
     train_dataloader = DataLoader(train_dataset, batch_size=total_batch_size, collate_fn=batchify)
     valid_dataloader = DataLoader(valid_dataset, batch_size=total_batch_size, collate_fn=batchify)
     test_dataloader = DataLoader(test_dataset, batch_size=total_batch_size, collate_fn=batchify)

     return len(vocab.stoi), train_dataloader, valid_dataloader, test_dataloader


-def get_synthetic_dataloaders(args):
+def get_synthetic_dataloaders(args, benchmark_config):
     """Return synthetic dataloaders for training, testing and validation."""

     def batchify(data):
         batch_size = args.batch_size
         return _batchify(data, batch_size)

-    # TODO(anj-s): Both seq_len and batch size should be part of the golden config.
-    seq_len = 32
-    total_batch_size = seq_len * args.batch_size
+    total_batch_size = _get_total_batch_size(benchmark_config)
     # vocab_size is 10000 and length of the real data is 2049990.
     lm_dataset = torch.randint(1, 10000, (2049990,))
...
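
For context, the collate path works on a flat stream of token ids: the DataLoader hands `batchify` a chunk of `seq_len * batch_size` tokens, which is then folded into `batch_size` columns. The body of `_batchify` is elided in this hunk, so the version below is an assumption modeled on the standard WikiText-2 batching recipe:

    import torch

    def _batchify(data, batch_size):
        # Assumed implementation (the body is elided in the diff): trim the
        # token stream to a multiple of batch_size, then fold it into
        # batch_size columns.
        data = torch.as_tensor(data)
        nbatch = data.size(0) // batch_size
        data = data.narrow(0, 0, nbatch * batch_size)
        return data.view(batch_size, -1).t().contiguous()

    # One DataLoader batch is seq_len * batch_size = 32 * 8 = 256 flat tokens,
    # which batchify folds into a (seq_len, batch_size) block.
    chunk = torch.randint(1, 10000, (256,))
    print(_batchify(chunk, 8).shape)  # torch.Size([32, 8])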
@@ -20,6 +20,7 @@ def get_benchmark_config():
         "scaler": GradScaler(),
         "clip_value": 0.05,
         "batch_size": 8,
+        "seq_len": 32,
     }
...
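
After this change, `get_benchmark_config` owns both knobs, so a benchmark run and its golden (reference) stats are keyed by the same dict. A hedged reconstruction of the function from the visible hunk only — the real function likely carries additional entries not shown here:

    from torch.cuda.amp import GradScaler

    def get_benchmark_config():
        # Reconstructed from the visible hunk only; entries outside the hunk
        # (model dims, learning rate, etc.) are omitted.
        return {
            "scaler": GradScaler(),
            "clip_value": 0.05,
            "batch_size": 8,
            "seq_len": 32,  # newly promoted from a hard-coded local variable
        }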
@@ -409,43 +409,44 @@ def generate_balance(num_devices, num_layers):
     return balance


-def get_synthetic_dataloader(args):
+def get_synthetic_dataloaders(args, benchmark_config):
     """Returns dataloader for synthetic data."""
     if args.model_name == "lm":
-        return get_synthetic_wikitext2_dataloaders(args)
+        return get_synthetic_wikitext2_dataloaders(args, benchmark_config)
     else:
         raise RuntimeError("Unrecognized args.model_mame " % args.model_name)


-def get_real_dataloaders(args, device, config):
+def get_real_dataloaders(args, device, benchmark_config):
     """Returns dataloaders for real data."""
     if args.model_name == "lm":
-        data = get_real_wikitext2_dataloaders(args)
+        data = get_real_wikitext2_dataloaders(args, benchmark_config)
         ntokens, train_dataloader, valid_dataloader, test_dataloader = data
-        config["vocab_size"] = ntokens
+        benchmark_config["vocab_size"] = ntokens
         return train_dataloader, valid_dataloader, test_dataloader
     else:
         raise RuntimeError("Unrecognized args.model_mame " % args.model_name)


-def create_model_config(args, config=None):
+def create_model_config(args, benchmark_config=None):
     """Return a dict with the given model, dataset and optimizer."""
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

     if args.use_synthetic_data:
-        model, optimizer = get_model_and_optimizer(args, device, config)
-        data = get_synthetic_dataloader(args)
-        return {"model": model, "optimizer": optimizer, "data": data}
+        dataloader_fn = get_synthetic_dataloaders
     else:
-        data = get_real_dataloaders(args, device, config)
-        model, optimizer = get_model_and_optimizer(args, device, config)
-        return {
-            "model": model,
-            "optimizer": optimizer,
-            "data": data,
-        }
+        dataloader_fn = get_real_dataloaders
+
+    data = dataloader_fn(args, device, benchmark_config)
+    model, optimizer = get_model_and_optimizer(args, device, benchmark_config)
+    return {
+        "model": model,
+        "optimizer": optimizer,
+        "data": data,
+    }


 def create_benchmark_config(model_name):
@@ -474,7 +475,7 @@ def benchmark_single_process(args):
     init_random_seed(0)

     benchmark_config = create_benchmark_config(args.model_name)
-    model_config = create_model_config(args, config=benchmark_config)
+    model_config = create_model_config(args, benchmark_config=benchmark_config)
     model = model_config["model"]

     balance = generate_balance(min(num_devices, 4), len(model))
...
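
A closing note on the `create_model_config` rewrite: picking `dataloader_fn` up front lets the two branches share a single construction path, which only works because both loader functions accept the same argument list. A toy sketch of the dispatch — names and bodies are illustrative stand-ins, not the repository's actual helpers:

    from types import SimpleNamespace

    def _synthetic_loader(args, device, benchmark_config):
        # Stand-in for get_synthetic_dataloaders.
        return "synthetic batches"

    def _real_loader(args, device, benchmark_config):
        # Stand-in for get_real_dataloaders.
        return "real batches"

    def build_config(args, benchmark_config=None):
        # Pick the loader first; everything after this line is branch-free.
        loader_fn = _synthetic_loader if args.use_synthetic_data else _real_loader
        return {"data": loader_fn(args, None, benchmark_config)}

    print(build_config(SimpleNamespace(use_synthetic_data=True)))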