Commit 3afcba6e authored by Jared Casper, committed by Deepak Narayanan

Work batch-size name changes into task code

parent 5c45db4a
@@ -134,9 +134,12 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.lr_decay_samples is None, \
             'expected iteration-based learning rate decay'
         assert args.lr_warmup_samples == 0, \
-            'expected iteration-based learnig rate warmup'
+            'expected iteration-based learning rate warmup'
         assert args.rampup_batch_size is None, \
             'expected no batch-size rampup for iteration-based training'
+        if args.lr_warmup_percent is not None:
+            assert args.lr_warmup_iters == 0, \
+                'can only specify one of lr-warmup-percent and lr-warmup-iters'
     # Sample-based training.
     if args.train_samples:
@@ -148,11 +151,14 @@ def parse_args(extra_args_provider=None, defaults={},
             'expected sample-based learning rate decay'
         assert args.lr_warmup_iters == 0, \
             'expected sample-based learnig rate warmup'
+        if args.lr_warmup_percent is not None:
+            assert args.lr_warmup_samples == 0, \
+                'can only specify one of lr-warmup-percent and lr-warmup-samples'
     # Check required arguments.
     required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                      'max_position_embeddings']
     for req_arg in required_args:
         _check_arg_is_not_none(args, req_arg)
     # Checks.
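A minimal standalone sketch of the mutual-exclusion rule added above (the helper and the example values are hypothetical, not code from this repository): `--lr-warmup-percent` may not be combined with an explicit `--lr-warmup-iters` or `--lr-warmup-samples`.

```python
# Hypothetical, simplified check mirroring the asserts added in parse_args.
from types import SimpleNamespace


def check_warmup_args(args):
    if args.train_iters:
        if args.lr_warmup_percent is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-percent and lr-warmup-iters'
    elif args.train_samples:
        if args.lr_warmup_percent is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-percent and lr-warmup-samples'


# OK: iteration-based run that relies only on the percent-based warmup.
check_warmup_args(SimpleNamespace(train_iters=500000, train_samples=None,
                                  lr_warmup_percent=0.01,
                                  lr_warmup_iters=0, lr_warmup_samples=0))

# AssertionError: both --lr-warmup-percent and --lr-warmup-iters were given.
# check_warmup_args(SimpleNamespace(train_iters=500000, train_samples=None,
#                                   lr_warmup_percent=0.01,
#                                   lr_warmup_iters=2000, lr_warmup_samples=0))
```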
@@ -353,6 +359,9 @@ def _add_learning_rate_args(parser):
     group.add_argument('--lr-decay-samples', type=int, default=None,
                        help='number of samples to decay learning rate over,'
                        ' If None defaults to `--train-samples`')
+    group.add_argument('--lr-warmup-percent', type=float, default=None,
+                       help='percentage of lr-warmup-(iters/samples) to use '
+                       'for warmup')
     group.add_argument('--lr-warmup-iters', type=int, default=0,
                        help='number of iterations to linearly warmup '
                        'learning rate over.')
@@ -568,4 +577,3 @@ def _add_realm_args(parser):
     group.add_argument('--indexer-log-interval', type=int, default=1000,
                        help='After how many batches should the indexer report progress')
     return parser
@@ -36,7 +36,7 @@ def get_batch(context_tokens):
     tokenizer = get_tokenizer()
     # Move to GPU.
-    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
+    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
     # Get the attention mask and postition ids.
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
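A toy illustration of the reshape in `get_batch` above (shapes are made up; the real code takes `micro_batch_size` from the parsed args and then moves the tensor to the GPU with `.cuda()`):

```python
import torch

# Flat buffer of padded prompt tokens, viewed as [micro_batch_size, seq_len].
micro_batch_size, seq_len = 4, 8
context_tokens = torch.arange(micro_batch_size * seq_len)
tokens = context_tokens.view(micro_batch_size, -1).contiguous()
print(tokens.shape)  # torch.Size([4, 8])
```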
@@ -294,7 +294,7 @@ def generate_samples_unconditional(model):
     num_samples = args.num_samples
     context_tokens = [[tokenizer.eod]
-                      for _ in range(args.batch_size)]
+                      for _ in range(args.micro_batch_size)]
     ctr = 0
     while True:
         start_time = time.time()
@@ -310,7 +310,7 @@ def generate_samples_unconditional(model):
             length = len(token_stream)
             token_batch = token_stream[0].cpu().numpy().tolist()
             length_batch = token_stream[1].cpu().numpy().tolist()
-            assert len(length_batch) == args.batch_size
+            assert len(length_batch) == args.micro_batch_size
             for tokens, length in zip(token_batch, length_batch):
                 tokens = tokens[1:length - 1]
                 text = tokenizer.detokenize(tokens)
@@ -321,7 +321,7 @@ def generate_samples_unconditional(model):
                 if ctr >= num_samples:
                     break
         else:
-            for _ in range(args.batch_size):
+            for _ in range(args.micro_batch_size):
                 yield None
                 ctr += 1
                 if ctr >= num_samples:
@@ -223,18 +223,24 @@ def get_learning_rate_scheduler(optimizer):
     if args.train_iters:
         if args.lr_decay_iters is None:
             args.lr_decay_iters = args.train_iters
-        warmup_steps = args.lr_warmup_iters * args.global_batch_size
         decay_steps = args.lr_decay_iters * args.global_batch_size
+        if args.lr_warmup_percent is not None:
+            warmup_steps = args.lr_warmup_percent * decay_steps
+        else:
+            warmup_steps = args.lr_warmup_iters * args.global_batch_size
     # Sample-based training.
     elif args.train_samples:
         # We need to set training iters for later use. Technically
         # we need to adjust the training samples too (due to last
         # batch being incomplete) but we leave it as is for now.
         update_train_iters(args)
         if args.lr_decay_samples is None:
             args.lr_decay_samples = args.train_samples
-        warmup_steps = args.lr_warmup_samples
         decay_steps = args.lr_decay_samples
+        if args.lr_warmup_percent is not None:
+            warmup_steps = args.lr_warmup_percent * decay_steps
+        else:
+            warmup_steps = args.lr_warmup_samples
     else:
         raise Exception(
             'either train-iters or train-samples should be provided.')
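A worked example of the warmup computation added above (numbers are illustrative only). Note that despite its name, `--lr-warmup-percent` multiplies `decay_steps` directly, so a value of 0.01 means 1% of the decay span:

```python
# Illustrative values; the formulas are the ones introduced in this hunk.
global_batch_size = 1024
lr_decay_iters = 300000

# Iteration-based run: the scheduler counts steps in samples
# (iterations scaled by the global batch size).
decay_steps = lr_decay_iters * global_batch_size    # 307,200,000

# New path: warmup as a fraction of the decay span.
lr_warmup_percent = 0.01
warmup_steps = lr_warmup_percent * decay_steps      # 3,072,000.0

# Old path, still used when --lr-warmup-percent is not given.
lr_warmup_iters = 2000
warmup_steps = lr_warmup_iters * global_batch_size  # 2,048,000
```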
...@@ -80,7 +80,7 @@ def calculate_correct_answers(name, model, dataloader, ...@@ -80,7 +80,7 @@ def calculate_correct_answers(name, model, dataloader,
args = get_args() args = get_args()
start_time = time.time() start_time = time.time()
model.eval() model.eval()
saved_batch_size = args.batch_size saved_batch_size = args.micro_batch_size
with torch.no_grad(): with torch.no_grad():
# For all the batches in the dataset. # For all the batches in the dataset.
total = 0 total = 0
@@ -103,7 +103,7 @@ def calculate_correct_answers(name, model, dataloader,
             ds = dataloader.dataset
             if hasattr(ds, 'sample_multiplier'):
                 actual_batch_size *= ds.sample_multiplier
-            args.batch_size = actual_batch_size
+            args.micro_batch_size = actual_batch_size
             if not mpu.is_pipeline_first_stage():
                 input_tensor, _ = communicate(
@@ -145,7 +145,7 @@ def calculate_correct_answers(name, model, dataloader,
                     recv_backward=False)
     model.train()
-    args.batch_size = saved_batch_size
+    args.micro_batch_size = saved_batch_size
     # Reduce.
     if mpu.is_pipeline_last_stage():
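The two eval hunks above save `args.micro_batch_size`, override it with the batch size actually produced by the dataloader (scaled by `sample_multiplier` when present), and restore it afterwards. A hypothetical context-manager sketch of that same save/override/restore pattern (not part of the repository):

```python
from contextlib import contextmanager


@contextmanager
def overridden_micro_batch_size(args, actual_batch_size):
    # Temporarily expose the dataloader's true batch size through args.
    saved = args.micro_batch_size
    args.micro_batch_size = actual_batch_size
    try:
        yield
    finally:
        # Restore the original value even if evaluation raises.
        args.micro_batch_size = saved
```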
@@ -129,10 +129,10 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
                                              args.num_workers, not args.keep_last)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
-    # Now that we've built the data loaders, set args.batch_size to
+    # Now that we've built the data loaders, set args.micro_batch_size to
     # the actual batch size the model will see for this dataset
     if hasattr(train_dataset, 'sample_multiplier'):
-        args.batch_size *= train_dataset.sample_multiplier
+        args.micro_batch_size *= train_dataset.sample_multiplier
     return train_dataloader, valid_dataloader
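The dataloader hunk applies the same idea permanently at build time: when each dataset example expands into `sample_multiplier` model inputs (for example, one row per candidate answer in a multiple-choice task), the batch the model actually sees is scaled accordingly. Illustrative arithmetic only:

```python
micro_batch_size = 8       # per-GPU value from the command line (illustrative)
sample_multiplier = 4      # e.g. four candidate answers per example (illustrative)
print(micro_batch_size * sample_multiplier)  # 32 rows reach the model per step
```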