Commit 8c03ff2d authored by Myle Ott, committed by Facebook Github Bot
Browse files

Fix positions for LM

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/622

Differential Revision: D15572555

Pulled By: myleott

fbshipit-source-id: 2b81f22207b4c894ffe645af0b45c70ac0a80612
parent 8ca05802
......@@ -27,6 +27,9 @@ from fairseq.modules import (
SinusoidalPositionalEmbedding,
)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model('transformer')
class TransformerModel(FairseqEncoderDecoderModel):
......@@ -112,9 +115,9 @@ class TransformerModel(FairseqEncoderDecoderModel):
base_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = 1024
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = 1024
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
......
......@@ -20,7 +20,6 @@ from fairseq.modules import (
CharacterTokenEmbedder,
)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
......@@ -103,9 +102,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
# backward compatibility
args.tie_adaptive_proj = True
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_SOURCE_POSITIONS)
if not hasattr(args, 'max_target_positions'):
if getattr(args, 'max_target_positions', None) is None:
args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)
if args.character_embeddings:
......
......@@ -81,7 +81,7 @@ class LanguageModelingTask(FairseqTask):
help='include past target')
parser.add_argument('--add-bos-token', action='store_true',
help='prepend beginning of sentence token (<s>)')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
parser.add_argument('--max-target-positions', type=int, metavar='N',
help='max number of tokens in the target sequence')
# fmt: on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment