Unverified Commit 44b9bcd8 authored by anj-s, committed by GitHub

[refactor] Enable benchmarks/pipe.py and merge real and synthetic input pipeline. (#286)



* [refactor] Remove unused variables and refactor common configurations

* move helper function to call site

* fixed lint errors

* fix lint errors

* fix lint errors

* fix lint errors

* fix import order

* format files

* remove unused imports

* fix lint errors

* fix lint errors

* refactor common utilities

* address PR comments

* sorted imports

* add space

* modify comment

* added doc strings and addressed PR comments.

* addressed PR comments

* added another comment to clarify.

* fixing lint errors

* addressed PR comments

* addressed PR comments

* fixed typos

* initialize var

* rename seq_pred to lm

* fix lint errors

* move datasets and models into separate folders

* add the folders created

* fix lint errors

* create golden config to stats mapping

* add common batching for both synthetic and real data

* fixed lint errors

* enable real pipe benchmarks with new golden data

* reduce seq len to avoid OOM

* updated golden data

* add logging

* add golden data

* add golden data

* fix lint errors

* add doc string

* remove commented out line

* address comments

* rename imports

* refactor common logic in dataloaders

* add golden configs

* lint changes
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent 8d710c82
import torch
from torch.utils.data import Dataset


def collate_sentences_lm(samples):
    if len(samples) == 0:
        return {}

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = torch.stack([s["source"] for s in samples], 0)
    tgt_tokens = torch.stack([s["target"] for s in samples], 0)
    ntokens = len(samples) * len(samples[0]["target"])
    src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples))

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "input": src_tokens,
        "target": tgt_tokens,
    }
    return batch


class BenchmarkLMDataset(Dataset):
    """
    Dataset to benchmark a translation like seq2seq task.

    Args:
        vocab_size (int, optional): size of the vocabulary (default 10000).
        max_source_positions (int, optional): max number of tokens in the
            source sentence (default: 1024).
        total_samples (int, optional): the total number of rows in the
            dataset (default: 10000).
    """

    def __init__(
        self, vocab_size=10000, max_source_positions=1024, total_samples=10000,
    ):
        self.vocab_size = vocab_size
        self.max_source_positions = max_source_positions
        self.total_samples = total_samples
        self.sizes = [self.max_source_positions] * self.total_samples

    def __getitem__(self, index):
        length = self.sizes[index]
        source = torch.randint(1, self.vocab_size, (length,))
        target = source.clone()
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return self.total_samples
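The snippet below is not part of the commit; it is a minimal sketch of how BenchmarkLMDataset and collate_sentences_lm are meant to be combined (the old pipe.py wired them up the same way), with batch_size=8 chosen only for illustration.

from torch.utils.data import DataLoader

# Hypothetical usage: wrap the synthetic dataset in a DataLoader and let
# collate_sentences_lm stack the per-sample dicts into one batch dict.
dataset = BenchmarkLMDataset(vocab_size=10000, max_source_positions=1024, total_samples=10000)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_sentences_lm)

batch = next(iter(loader))
print(batch["input"].shape)  # torch.Size([8, 1024]) -> (batch, seq_len)
print(batch["ntokens"])  # 8 * 1024 tokens in this batch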
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
import torchtext
from torchtext.data.utils import get_tokenizer


def get_wikitext2_data(device):
    """Return batched data from wikitext2 dataset for training, validation and testing."""
    with warnings.catch_warnings(record=True) as _:
        text_field = torchtext.data.Field(
            tokenize=get_tokenizer("basic_english"), init_token="<sos>", eos_token="<eos>", lower=True
        )
        train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(text_field)
        text_field.build_vocab(train_txt)
        ntokens = len(text_field.vocab.stoi)

        batch_size = 20
        eval_batch_size = 10
        train_data = batchify(train_txt, batch_size, text_field, device)
        val_data = batchify(val_txt, eval_batch_size, text_field, device)
        test_data = batchify(test_txt, eval_batch_size, text_field, device)

        return ntokens, train_data, val_data, test_data


def batchify(data, bsz, text_field, device):
    """Return batched data that is placed on the specified device."""
    data = text_field.numericalize([data.examples[0].text])
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)
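As a side note (not from the commit), the reshape in batchify is easier to see with made-up numbers: the stream is trimmed to a multiple of bsz and transposed so that each column is one independent batch stream.

import torch

stream = torch.arange(26)  # a tiny fake token stream
bsz = 4
nbatch = stream.size(0) // bsz  # 6 full rows; the 2 leftover tokens are dropped
trimmed = stream.narrow(0, 0, nbatch * bsz)
batched = trimmed.view(bsz, -1).t().contiguous()
print(batched.shape)  # torch.Size([6, 4]) -> (sequence, batch)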
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import io
import torch
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive
from torchtext.vocab import build_vocab_from_iterator


def _batchify(data, batch_size):
    data = torch.tensor(data)
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // batch_size
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * batch_size)
    # Evenly divide the data across the bsz batches.
    data = data.view(batch_size, -1).t().contiguous()
    return data


def get_real_dataloaders(args):
    """Return real dataloaders for training, testing and validation."""
    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
    test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root="/tmp"))
    tokenizer = get_tokenizer("basic_english")

    def data_process(raw_text_iter):
        data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    vocab = build_vocab_from_iterator(map(tokenizer, iter(io.open(train_filepath, encoding="utf8"))))
    train_dataset = data_process(iter(io.open(train_filepath, encoding="utf8")))
    valid_dataset = data_process(iter(io.open(valid_filepath, encoding="utf8")))
    test_dataset = data_process(iter(io.open(test_filepath, encoding="utf8")))

    def batchify(data):
        batch_size = args.batch_size
        return _batchify(data, batch_size)

    # TODO(anj-s): Both seq_len and batch size should be part of the golden config.
    seq_len = 32
    total_batch_size = seq_len * args.batch_size
    train_dataloader = DataLoader(train_dataset, batch_size=total_batch_size, collate_fn=batchify)
    valid_dataloader = DataLoader(valid_dataset, batch_size=total_batch_size, collate_fn=batchify)
    test_dataloader = DataLoader(test_dataset, batch_size=total_batch_size, collate_fn=batchify)
    return len(vocab.stoi), train_dataloader, valid_dataloader, test_dataloader


def get_synthetic_dataloaders(args):
    """Return synthetic dataloaders for training, testing and validation."""

    def batchify(data):
        batch_size = args.batch_size
        return _batchify(data, batch_size)

    # TODO(anj-s): Both seq_len and batch size should be part of the golden config.
    seq_len = 32
    total_batch_size = seq_len * args.batch_size
    # vocab_size is 10000 and length of the real data is 2049990.
    lm_dataset = torch.randint(1, 10000, (2049990,))
    lm_dataloader = DataLoader(
        lm_dataset, batch_size=total_batch_size, shuffle=True, num_workers=0, collate_fn=batchify
    )
    return lm_dataloader, lm_dataloader, lm_dataloader
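A quick usage sketch, not part of the commit: the Namespace below stands in for the benchmark's parsed arguments, and batch_size=8 is only an illustrative value.

from argparse import Namespace

args = Namespace(batch_size=8)  # hypothetical stand-in for the argparse args
train_loader, valid_loader, test_loader = get_synthetic_dataloaders(args)

batch = next(iter(train_loader))
# Each DataLoader batch carries seq_len * batch_size = 256 tokens, which
# _batchify reshapes into a (seq_len, batch_size) = (32, 8) tensor.
print(batch.shape)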
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch.nn as nn
from fairscale.optim import GradScaler


def get_benchmark_config():
    return {
        "epochs": 1,
        "vocab_size": 10000,
        "ninp": 2048,  # embedding dimension
        "nhid": 2048,  # the dimension of the feedforward network model in nn.TransformerEncoder
        "nhead": 32,  # the number of heads in the multiheadattention models
        "dropout": 0,
        "initrange": 0.1,
        "criterion": nn.CrossEntropyLoss(),
        "lr": 0.001,  # learning rate
        "scaler": GradScaler(),
        "clip_value": 0.05,
        "batch_size": 8,
    }


def get_golden_real_stats():
    return {
        "avg_wps": 703.778,
        "std_dev_wps": 5.732,
        "peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
    }


def get_golden_synthetic_stats():
    # TODO(anj-s): Add support for synthetic regression benchmarks
    raise NotImplementedError("Synthetic data benchmarks are not supported.")
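For orientation only (pipe.py's verify_lm_run performs the equivalent check), here is a sketch of how these golden stats gate a run; measured_wps is a made-up value.

golden = get_golden_real_stats()
measured_wps = 698.4  # made-up measurement
# A run passes if throughput stays within 3 standard deviations of the golden average.
threshold = golden["avg_wps"] - 3 * golden["std_dev_wps"]
if not measured_wps > threshold:
    raise RuntimeError("wps {:.2f} is below the golden threshold of {:.2f}".format(measured_wps, threshold))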
@@ -72,8 +72,8 @@ class LinearLayer(nn.Linear):
         self.weight.data.uniform_(-initrange, initrange)


-class TransformerLMSequntial(nn.Sequential):
-    """A GPT-2 based nn.Sequeitnal language model."""
+class TransformerLM(nn.Sequential):
+    """A GPT-2 based nn.Sequential language model."""

     def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder):
         layers = [
@@ -84,4 +84,4 @@ class TransformerLMSequntial(nn.Sequential):
             layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout))
         layers.append(LinearLayer(ninp, ntokens, initrange))

-        super(TransformerLMSequntial, self).__init__(*layers)
+        super(TransformerLM, self).__init__(*layers)
@@ -11,34 +11,24 @@ import os
 import pprint
 import time

-from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm
-import datasets
-import models
+from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
+from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
+from golden_configs import lm_wikitext2
+from models import transformer_lm
 import numpy as np
 import torch
 from torch.distributed import rpc
 import torch.multiprocessing as mp
-import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data import DataLoader
+from torch.optim import Adam

 from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
 from fairscale.nn.pipe import LazyModule, pipe
-from fairscale.optim import GradScaler
 from fairscale.optim.oss import OSS
 from fairscale.utils.testing import dist_init, get_worker_map

-try:
-    from fairscale.optim import Adam  # type: ignore
-
-    can_benchmark = True
-except ImportError:
-    from torch.optim import Adam  # type: ignore
-
-    can_benchmark = False


 def init_random_seed(seed: int):
@@ -78,16 +68,16 @@ def get_lm_model(args, device, config):
     if args.lazy_construction:
         layers = [
-            LazyModule(lambda: models.EmbeddingLayer(vocab_size, ninp, initrange)),
-            LazyModule(lambda: models.PositionalEncodingLayer(ninp, dropout)),
+            LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
+            LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
         ]
         for _ in range(ndecoder):
-            layers.append(LazyModule(lambda: models.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))
-        layers.append(LazyModule(lambda: models.LinearLayer(ninp, vocab_size, initrange)))
+            layers.append(LazyModule(lambda: transformer_lm.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))
+        layers.append(LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)))
         model = layers
     else:
-        model = models.TransformerLMSequntial(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
+        model = transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

     return model
@@ -208,17 +198,16 @@ def get_fake_dataloader(lm_dataloader_len):
     return FakeDataset()


-def train(data_config, model, benchmark_config, args):
-    lm_dataloader = data_config["data"]
+def train(model_config, model, benchmark_config, args):
+    lm_dataloader, _, _ = model_config["data"]
     criterion = benchmark_config["criterion"]
     vocab_size = benchmark_config["vocab_size"]
-    optimizer = data_config["optimizer"]
+    optimizer = model_config["optimizer"]

     model.train()
     log_number_of_parameters(model)

     total_loss = 0.0
-    start_time = time.time()
     word_counter = 0

     optimizer = optimizer(model.parameters())
@@ -239,23 +228,39 @@ def train(data_config, model, benchmark_config, args):
     total_tokens = 0
     total_tokens_per_log_interval = 0
+    bptt = 2
+    start_time = time.time()
+    epoch_start_time = 0.0
+
+    def get_batch(source):
+        seq_len = len(source) - 1
+        data = source[0:seq_len]
+        target = source[1 : 1 + seq_len]
+        return data, target

     for i, batch in enumerate(lm_dataloader):
+        if i == 1:
+            epoch_start_time = time.time()
+
+        source, target = get_batch(batch)
+
         if args.max_batch and i > args.max_batch:
             break
-        total_tokens += batch["input"].numel()
+
+        if i > 0:
+            total_tokens += source.numel()
+
         optimizer.zero_grad()
         try:
             if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
-                tmp = batch["input"].to(get_device(model, 0))
+                tmp = source.to(get_device(model, 0))
                 output = model(tmp)
             else:
-                output = model(batch["input"])
+                output = model(source)
         except Exception as e:
             raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

         if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
-            target = batch["target"].to(get_device(model, -1))
+            target = target.to(get_device(model, -1))
             output = output.to(target.device)

             loss = criterion(output.view(-1, vocab_size), target.view(-1))
@@ -279,7 +284,7 @@ def train(data_config, model, benchmark_config, args):
         if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
             total_loss += loss.item()
             log_interval = 1
-            total_tokens_per_log_interval += batch["input"].numel()
+            total_tokens_per_log_interval += source.numel()
             if i % log_interval == 0 and i > 0:
                 cur_loss = total_loss / log_interval
                 elapsed = time.time() - start_time
@@ -292,7 +297,14 @@ def train(data_config, model, benchmark_config, args):
                 total_loss = 0
                 start_time = time.time()

-    return total_tokens, loss.item()
+    if epoch_start_time != 0:
+        wps = total_tokens / (time.time() - epoch_start_time)
+    else:
+        raise RuntimeError(
+            "Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark."
+        )
+    return wps, loss.item()


 # TODO(anj-s): Add an option for users to be able to benchmark evaluate.
@@ -322,43 +334,52 @@ def get_number_of_words(data):
     return data.size()[0] * data.size()[1]


-def verify_lm_run(wps):
+def verify_lm_run(wps, golden_config):
     """Verify that words per second for a given benchmark run matches the golden data."""
     # Assert that words per second is within 3 standard deviations of the average
-    # of six golden runs
-    assert wps > 36954.4 - (3 * 116.825)
+    # of five golden runs
+    print("Throughput(wps) is {:.2f}.".format(wps))
+    if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
+        raise RuntimeError(
+            "Throughput(wps):{:.2f} is below the golden threshold of an "
+            "average value of {:.2f} and standard dev of {:.2f}.".format(
+                wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
+            )
+        )

     for i in range(4):
         print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(i)["allocated_bytes.all.peak"]))

     # Assert that memory usage on each GPU is within 10% of golden run
     # Right-hand-side is golden run bytes * 110%
-    for i, golden_ref in zip(range(4), [4061909504, 4050944, 10427392, 2031824896]):
-        assert torch.cuda.memory_stats(i)["allocated_bytes.all.peak"] < golden_ref * 1.1
+    for i, golden_ref in zip(range(4), golden_config["peak_mem_usage"]):
+        current_device_usage = torch.cuda.memory_stats(i)["allocated_bytes.all.peak"]
+        if not current_device_usage < golden_ref * 1.1:
+            raise RuntimeError(
+                "Peak memory usage for cuda device {:d} is {:d} which"
+                "is less than golden reference value of {:d}".format(i, current_device_usage, golden_ref)
+            )


 def benchmark_language_model(model_config, model, benchmark_config, args):
-    ntokens, train_data, val_data, test_data = model_config["data"]
-    optimizer = model_config["optimizer"]
-    criterion = benchmark_config["criterion"]
-    epoch = 1
+    golden_config = get_golden_config(args.model_name)
+    epoch = benchmark_config["epochs"]
     print("-" * 110)
     print("| start of epoch {:1d}".format(epoch))
     print("-" * 110)
     start_time = time.time()
-    n_words, loss = train(data_config, model, benchmark_config, args)
+    wps, loss = train(model_config, model, benchmark_config, args)
     elapsed_time = time.time() - start_time
-    wps = nwords / elapsed_time
     print("-" * 110)
     print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
     print("-" * 110)
-    if can_benchmark and len(model.balance) == 4:
+    print("wps ", wps)
+
+    if len(model.balance) == 4:
         if args.model_name == "lm":
-            verify_lm_run(wps)
+            verify_lm_run(wps, golden_config)
         else:
             raise RuntimeError("Unrecognized args.model_name " % args.model_name)
@@ -392,23 +413,19 @@ def get_synthetic_dataloader(args):
     """Returns dataloader for synthetic data."""
     if args.model_name == "lm":
-        lm_dataset = BenchmarkLMDataset()
-        lm_dataloader = DataLoader(
-            lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
-        )
-        return lm_dataloader
+        return get_synthetic_wikitext2_dataloaders(args)
     else:
         raise RuntimeError("Unrecognized args.model_mame " % args.model_name)


-def get_real_dataloaders(device, config):
+def get_real_dataloaders(args, device, config):
     """Returns dataloaders for real data."""
     if args.model_name == "lm":
-        data = datasets.get_wikitext2_data(device)
-        ntokens, _, _, _ = data
+        data = get_real_wikitext2_dataloaders(args)
+        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
         config["vocab_size"] = ntokens
-        return data
+        return train_dataloader, valid_dataloader, test_dataloader
     else:
         raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
@@ -419,10 +436,10 @@ def create_model_config(args, config=None):
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

     if args.use_synthetic_data:
         model, optimizer = get_model_and_optimizer(args, device, config)
-        dataloader = get_synthetic_dataloader(args)
-        return {"model": model, "optimizer": optimizer, "data": dataloader}
+        data = get_synthetic_dataloader(args)
+        return {"model": model, "optimizer": optimizer, "data": data}
     else:
-        data = get_real_dataloaders(device, config)
+        data = get_real_dataloaders(args, device, config)
         model, optimizer = get_model_and_optimizer(args, device, config)
         return {
             "model": model,
@@ -435,18 +452,16 @@ def create_benchmark_config(model_name):
     """Return a dict with configurations required for benchmarking `model_name` model."""
     if model_name == "lm":
-        return {
-            "vocab_size": 10000,
-            "ninp": 2048,  # embedding dimension
-            "nhid": 2048,  # the dimension of the feedforward network model in nn.TransformerEncoder
-            "nhead": 32,  # the number of heads in the multiheadattention models
-            "dropout": 0,
-            "initrange": 0.1,
-            "criterion": nn.CrossEntropyLoss(),
-            "lr": 0.01,  # learning rate
-            "scaler": GradScaler(),
-            "clip_value": 0.05,
-        }
+        return lm_wikitext2.get_benchmark_config()
+    else:
+        raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
+
+
+def get_golden_config(model_name):
+    """Return a dict with the golden data for throughput and memory usage."""
+    if model_name == "lm":
+        return lm_wikitext2.get_golden_real_stats()
     else:
         raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
@@ -469,7 +484,7 @@ def benchmark_single_process(args):
         del model
         del model_config["model"]

-    if args.use_synthetic_data:
+    if args.dry_run:
         train(model_config, pipe_model, benchmark_config, args)
     else:
         benchmark_language_model(model_config, pipe_model, benchmark_config, args)
@@ -605,7 +620,8 @@ parser.add_argument(
 parser.add_argument(
     "--no-pipelined-backward", dest="pipelined_backward", action="store_false", help="Pipelined backward pass"
 )
-parser.add_argument("--use_synthetic_data", default=True, help="Uses synthetic data for a sample training run.")
+parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
+parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
 parser.add_argument(
     # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
     "--model_name",
...