Commit 3afcba6e authored by Jared Casper, committed by Deepak Narayanan

Work batch-size name changes into task code

parent 5c45db4a
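For orientation (my reading of the surrounding code, not something this commit defines): args.batch_size has been renamed to args.micro_batch_size, the number of samples one model replica processes in a single forward/backward pass, while args.global_batch_size is the total number of samples consumed per optimizer step. A rough sketch of the assumed relationship:

# Sketch only: the parameter names data_parallel_size and num_micro_batches
# are illustrative, not taken from this commit.
def global_batch_size(micro_batch_size, data_parallel_size, num_micro_batches):
    return micro_batch_size * data_parallel_size * num_micro_batches

# e.g. 4 samples per pass, 8 data-parallel replicas, 16 micro-batches per step:
assert global_batch_size(4, 8, 16) == 512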
@@ -134,9 +134,12 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.lr_decay_samples is None, \
'expected iteration-based learning rate decay'
assert args.lr_warmup_samples == 0, \
'expected iteration-based learnig rate warmup'
'expected iteration-based learning rate warmup'
assert args.rampup_batch_size is None, \
'expected no batch-size rampup for iteration-based training'
if args.lr_warmup_percent is not None:
assert args.lr_warmup_iters == 0, \
'can only specify one of lr-warmup-percent and lr-warmup-iters'
# Sample-based training.
if args.train_samples:
@@ -148,11 +151,14 @@ def parse_args(extra_args_provider=None, defaults={},
'expected sample-based learning rate decay'
assert args.lr_warmup_iters == 0, \
'expected sample-based learning rate warmup'
if args.lr_warmup_percent is not None:
assert args.lr_warmup_samples == 0, \
'can only specify one of lr-warmup-percent and lr-warmup-samples'
# Check required arguments.
required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
'max_position_embeddings']
for req_arg in required_args:
for req_arg in required_args:
_check_arg_is_not_none(args, req_arg)
# Checks.
@@ -353,6 +359,9 @@ def _add_learning_rate_args(parser):
group.add_argument('--lr-decay-samples', type=int, default=None,
help='number of samples to decay learning rate over,'
' If None defaults to `--train-samples`')
group.add_argument('--lr-warmup-percent', type=float, default=None,
help='percentage of lr-warmup-(iters/samples) to use '
'for warmup')
group.add_argument('--lr-warmup-iters', type=int, default=0,
help='number of iterations to linearly warmup '
'learning rate over.')
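
The three hunks above introduce an optional --lr-warmup-percent flag and assert that it is not combined with the step-count warmup flags. A standalone sketch of that validation (a simplified stand-in, not the actual Megatron parser):

import argparse

# Minimal reproduction of the new flag and the mutual-exclusion check added
# in parse_args; the iteration-based branch is shown, and the sample-based
# branch checks lr_warmup_samples the same way.
parser = argparse.ArgumentParser()
parser.add_argument('--train-iters', type=int, default=None)
parser.add_argument('--lr-warmup-iters', type=int, default=0)
parser.add_argument('--lr-warmup-percent', type=float, default=None)
args = parser.parse_args(['--train-iters', '1000', '--lr-warmup-percent', '0.01'])

if args.train_iters and args.lr_warmup_percent is not None:
    assert args.lr_warmup_iters == 0, \
        'can only specify one of lr-warmup-percent and lr-warmup-iters'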
@@ -568,4 +577,3 @@ def _add_realm_args(parser):
group.add_argument('--indexer-log-interval', type=int, default=1000,
help='After how many batches should the indexer report progress')
return parser
@@ -36,7 +36,7 @@ def get_batch(context_tokens):
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
@@ -294,7 +294,7 @@ def generate_samples_unconditional(model):
num_samples = args.num_samples
context_tokens = [[tokenizer.eod]
for _ in range(args.batch_size)]
for _ in range(args.micro_batch_size)]
ctr = 0
while True:
start_time = time.time()
@@ -310,7 +310,7 @@ def generate_samples_unconditional(model):
length = len(token_stream)
token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist()
assert len(length_batch) == args.batch_size
assert len(length_batch) == args.micro_batch_size
for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length - 1]
text = tokenizer.detokenize(tokens)
@@ -321,7 +321,7 @@ def generate_samples_unconditional(model):
if ctr >= num_samples:
break
else:
for _ in range(args.batch_size):
for _ in range(args.micro_batch_size):
yield None
ctr += 1
if ctr >= num_samples:
......
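
The generation hunks above only swap args.batch_size for args.micro_batch_size in the sampling path: context tokens are still laid out one row per micro-batch slot, and the generator yields exactly micro_batch_size results (or None placeholders) per step. A toy, CPU-only illustration of the reshape (values made up; the real code moves the tensor to GPU):

import torch

micro_batch_size = 4
eod_token = 0  # stand-in for tokenizer.eod
context_tokens = [[eod_token] for _ in range(micro_batch_size)]
tokens = torch.tensor(context_tokens).view(micro_batch_size, -1).contiguous()
assert tokens.shape == (4, 1)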
@@ -223,18 +223,24 @@ def get_learning_rate_scheduler(optimizer):
if args.train_iters:
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
warmup_steps = args.lr_warmup_iters * args.global_batch_size
decay_steps = args.lr_decay_iters * args.global_batch_size
if args.lr_warmup_percent is not None:
warmup_steps = args.lr_warmup_percent * decay_steps
else:
warmup_steps = args.lr_warmup_iters * args.global_batch_size
# Sample-based training.
elif args.train_samples:
# We need to set training iters for later use. Technically
# we need to adjust the training samples too (due to last
# batch being incomplete) but we leave it as is for now.
update_train_iters(args)
update_train_iters(args)
if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples
warmup_steps = args.lr_warmup_samples
decay_steps = args.lr_decay_samples
if args.lr_warmup_percent is not None:
warmup_steps = args.lr_warmup_percent * decay_steps
else:
warmup_steps = args.lr_warmup_samples
else:
raise Exception(
'either train-iters or train-samples should be provided.')
......
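
The scheduler hunk above lets --lr-warmup-percent override the explicit warmup counts in both training modes: warmup becomes a fraction of decay_steps rather than lr_warmup_iters * global_batch_size (iteration-based) or lr_warmup_samples (sample-based). A condensed sketch of that logic, written as a hypothetical standalone helper:

def warmup_and_decay_steps(args):
    if args.train_iters:
        # Iteration-based training: steps are counted in samples, so
        # iteration counts are scaled by the global batch size.
        decay_steps = (args.lr_decay_iters or args.train_iters) * args.global_batch_size
        if args.lr_warmup_percent is not None:
            warmup_steps = args.lr_warmup_percent * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    elif args.train_samples:
        # Sample-based training: counts are already expressed in samples.
        decay_steps = args.lr_decay_samples or args.train_samples
        if args.lr_warmup_percent is not None:
            warmup_steps = args.lr_warmup_percent * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    else:
        raise Exception('either train-iters or train-samples should be provided.')
    return warmup_steps, decay_steps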
@@ -80,7 +80,7 @@ def calculate_correct_answers(name, model, dataloader,
args = get_args()
start_time = time.time()
model.eval()
saved_batch_size = args.batch_size
saved_batch_size = args.micro_batch_size
with torch.no_grad():
# For all the batches in the dataset.
total = 0
@@ -103,7 +103,7 @@ def calculate_correct_answers(name, model, dataloader,
ds = dataloader.dataset
if hasattr(ds, 'sample_multiplier'):
actual_batch_size *= ds.sample_multiplier
args.batch_size = actual_batch_size
args.micro_batch_size = actual_batch_size
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
@@ -145,7 +145,7 @@ def calculate_correct_answers(name, model, dataloader,
recv_backward=False)
model.train()
args.batch_size = saved_batch_size
args.micro_batch_size = saved_batch_size
# Reduce.
if mpu.is_pipeline_last_stage():
......
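
calculate_correct_answers keeps the same save/override/restore pattern as before, now on the renamed attribute: args.micro_batch_size is temporarily set to the batch size the model actually sees (scaled by the dataset's sample_multiplier) so the pipeline communication code stays consistent, and the configured value is restored after evaluation. A simplified sketch of the pattern (the helper name is mine; the real code also derives the actual size from each batch):

def with_actual_micro_batch_size(args, dataloader, run_eval):
    saved = args.micro_batch_size
    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # Present the effective per-step batch size to downstream code.
        args.micro_batch_size = saved * ds.sample_multiplier
    try:
        return run_eval(dataloader)
    finally:
        # Always restore the configured micro batch size.
        args.micro_batch_size = saved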
@@ -129,10 +129,10 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
args.num_workers, not args.keep_last)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set args.batch_size to
# Now that we've built the data loaders, set args.micro_batch_size to
# the actual batch size the model will see for this dataset
if hasattr(train_dataset, 'sample_multiplier'):
args.batch_size *= train_dataset.sample_multiplier
args.micro_batch_size *= train_dataset.sample_multiplier
return train_dataloader, valid_dataloader
......
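
The finetune-utils hunk applies the same idea at data-loader construction time: once the loaders are built, args.micro_batch_size is multiplied by train_dataset.sample_multiplier so it matches what the model will actually receive. A toy example of where such a multiplier might come from (this dataset is invented for illustration):

import torch

class PairedSentenceDataset(torch.utils.data.Dataset):
    # Each item expands into two sequences downstream, so the effective
    # micro batch is twice the configured one.
    sample_multiplier = 2

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx]

With micro_batch_size = 8, such a dataset would put 16 sequences in front of the model each step, which is why the attribute is scaled only after the data loaders have been built.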