Commit 752eeae3 authored by Mohammad's avatar Mohammad
Browse files

code runs

parent a6ba254f
......@@ -319,7 +319,7 @@ def get_token_stream(model, context_tokens):
group=mpu.get_model_parallel_group())
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
......@@ -469,7 +469,7 @@ def main():
# Generate samples.
if args.num_samples == 0:
assert args.batch_size == 1
args.batch_size = 1
if args.sample_input_file != "":
generate_samples_input_from_file(model)
else:
......
......@@ -69,8 +69,10 @@ def parse_args(extra_args_provider=None, defaults={}):
# Checks.
assert args.hidden_size % args.num_attention_heads == 0
assert args.max_position_embeddings >= args.seq_length
assert args.min_lr <= args.lr
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
if args.lr is not None:
assert args.min_lr <= args.lr
if args.save is not None:
assert args.save_interval is not None
......@@ -134,7 +136,7 @@ def _add_regularization_args(parser):
def _add_training_args(parser):
group = parser.add_argument_group(title='training')
group.add_argument('--batch-size', type=int, required=True,
group.add_argument('--batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size.')
......@@ -301,7 +303,7 @@ def _add_data_args(parser):
help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--seq-length', type=int, required=True,
group.add_argument('--seq-length', type=int, default=None,
help="Maximum sequence length to process.")
group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.')
......@@ -356,10 +358,6 @@ def _add_gpt2_args(parser):
def add_data_args_(parser):
"""Train/valid/test data arguments."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment