Commit ed592ab5 authored by Spencer Poff, committed by Facebook Github Bot

making it easier to use the transformer_lm model with new tasks

Summary:
There were two non-obvious errors I ran into while creating a new language modeling task:
- `transformer_lm` implicitly required the `tokens_per_sample` arg
- `transformer_lm` assumed the task had `dictionary` and `output_dictionary` properties, neither of which is specified in the FairseqTask interface (see the sketch after this list for what a new task needs to expose)
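
For reference, a minimal sketch of a custom language modeling task that works with `transformer_lm` after this change. The task name `my_lm_task`, the class `MyLMTask`, the `--tokens-per-sample` default, and the `dict.txt` path are placeholders, not part of this diff; the sketch assumes fairseq's `FairseqTask`, `register_task`, and `Dictionary` APIs at this revision.

```python
# Hypothetical minimal task; 'my_lm_task', MyLMTask, and the dict.txt layout
# are placeholders chosen for illustration, not part of this commit.
import os

from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task


@register_task('my_lm_task')
class MyLMTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        parser.add_argument('data', help='path to the data directory')
        # With this change, tokens_per_sample is optional for transformer_lm:
        # if the task never sets it, the model falls back to
        # DEFAULT_MAX_SOURCE_POSITIONS / DEFAULT_MAX_TARGET_POSITIONS.
        parser.add_argument('--tokens-per-sample', type=int, default=512)

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        # load a vocabulary (assumed to be a fairseq dict.txt in args.data)
        dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
        return cls(args, dictionary)

    # transformer_lm now reads these FairseqTask properties instead of the
    # ad-hoc task.dictionary / task.output_dictionary attributes.
    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

    # load_dataset(...) and the rest of the task are omitted for brevity.
```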

Reviewed By: myleott

Differential Revision: D15532345

fbshipit-source-id: 200d7d3b542c35f17cc2d6bca4219c4a4d17cb6b
parent 4e9ecb80
@@ -20,6 +20,9 @@ from fairseq.modules import (
     CharacterTokenEmbedder,
 )
 
+DEFAULT_MAX_SOURCE_POSITIONS = 1024
+DEFAULT_MAX_TARGET_POSITIONS = 1024
+
 
 @register_model('transformer_lm')
 class TransformerLanguageModel(FairseqLanguageModel):
@@ -101,24 +104,24 @@ class TransformerLanguageModel(FairseqLanguageModel):
             args.tie_adaptive_proj = True
 
         if not hasattr(args, 'max_source_positions'):
-            args.max_source_positions = args.tokens_per_sample
+            args.max_source_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_SOURCE_POSITIONS)
         if not hasattr(args, 'max_target_positions'):
-            args.max_target_positions = args.tokens_per_sample
+            args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)
 
         if args.character_embeddings:
             embed_tokens = CharacterTokenEmbedder(
-                task.dictionary, eval(args.character_filters),
+                task.source_dictionary, eval(args.character_filters),
                 args.character_embedding_dim, args.decoder_embed_dim,
                 args.char_embedder_highway_layers,
             )
         elif args.adaptive_input:
             embed_tokens = AdaptiveInput(
-                len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
+                len(task.source_dictionary), task.source_dictionary.pad(), args.decoder_input_dim,
                 args.adaptive_input_factor, args.decoder_embed_dim,
                 options.eval_str_list(args.adaptive_input_cutoff, type=int),
             )
         else:
-            embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
+            embed_tokens = Embedding(len(task.source_dictionary), args.decoder_input_dim, task.source_dictionary.pad())
 
         if args.tie_adaptive_weights:
             assert args.adaptive_input
@@ -128,7 +131,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
             assert args.decoder_input_dim == args.decoder_output_dim
 
         decoder = TransformerDecoder(
-            args, task.output_dictionary, embed_tokens, no_encoder_attn=True,
+            args, task.target_dictionary, embed_tokens, no_encoder_attn=True,
             final_norm=args.decoder_final_norm,
         )
         return TransformerLanguageModel(decoder)