Commit 752eeae3 authored by Mohammad's avatar Mohammad
Browse files

code runs

parent a6ba254f
...@@ -319,7 +319,7 @@ def get_token_stream(model, context_tokens): ...@@ -319,7 +319,7 @@ def get_token_stream(model, context_tokens):
group=mpu.get_model_parallel_group()) group=mpu.get_model_parallel_group())
context_length = context_length_tensor.min().item() context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args) tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor, context_length_tensor,
...@@ -469,7 +469,7 @@ def main(): ...@@ -469,7 +469,7 @@ def main():
# Generate samples. # Generate samples.
if args.num_samples == 0: if args.num_samples == 0:
assert args.batch_size == 1 args.batch_size = 1
if args.sample_input_file != "": if args.sample_input_file != "":
generate_samples_input_from_file(model) generate_samples_input_from_file(model)
else: else:
......
...@@ -69,8 +69,10 @@ def parse_args(extra_args_provider=None, defaults={}): ...@@ -69,8 +69,10 @@ def parse_args(extra_args_provider=None, defaults={}):
# Checks. # Checks.
assert args.hidden_size % args.num_attention_heads == 0 assert args.hidden_size % args.num_attention_heads == 0
assert args.max_position_embeddings >= args.seq_length if args.seq_length is not None:
assert args.min_lr <= args.lr assert args.max_position_embeddings >= args.seq_length
if args.lr is not None:
assert args.min_lr <= args.lr
if args.save is not None: if args.save is not None:
assert args.save_interval is not None assert args.save_interval is not None
...@@ -134,7 +136,7 @@ def _add_regularization_args(parser): ...@@ -134,7 +136,7 @@ def _add_regularization_args(parser):
def _add_training_args(parser): def _add_training_args(parser):
group = parser.add_argument_group(title='training') group = parser.add_argument_group(title='training')
group.add_argument('--batch-size', type=int, required=True, group.add_argument('--batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). ' help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data ' 'Global batch size is local batch size times data '
'parallel size.') 'parallel size.')
...@@ -301,7 +303,7 @@ def _add_data_args(parser): ...@@ -301,7 +303,7 @@ def _add_data_args(parser):
help='Path to the vocab file.') help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None, group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.') help='Path to the BPE merge file.')
group.add_argument('--seq-length', type=int, required=True, group.add_argument('--seq-length', type=int, default=None,
help="Maximum sequence length to process.") help="Maximum sequence length to process.")
group.add_argument('--mask-prob', type=float, default=0.15, group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.') help='Probability of replacing a token with mask.')
...@@ -356,10 +358,6 @@ def _add_gpt2_args(parser): ...@@ -356,10 +358,6 @@ def _add_gpt2_args(parser):
def add_data_args_(parser): def add_data_args_(parser):
"""Train/valid/test data arguments.""" """Train/valid/test data arguments."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment