"src/vscode:/vscode.git/clone" did not exist on "2fd46405cd4e845e65b102acc8849667ab508790"
Commit 6e3685ad authored by Alexei Baevski, committed by Myle Ott

make adaptive softmax dropout an optional arg

parent 19c25f47
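Summary: this change adds an --adaptive-softmax-dropout option so the dropout applied inside the adaptive softmax tail projections can be set independently of the model-wide --dropout. Both FConvDecoder and TransformerDecoder now pass it to AdaptiveSoftmax in place of the general dropout rate, and the base architectures default it to 0. As a rough illustration of what that rate controls, here is a minimal, hypothetical tail-projection module (not fairseq's AdaptiveSoftmax code): an adaptive softmax scores rare words through a small down-projection, and the dropout sits on that projection.

    # Minimal sketch of the pattern (hypothetical module, not fairseq's AdaptiveSoftmax):
    # the tail clusters of an adaptive softmax use their own dropout rate, which this
    # commit exposes as --adaptive-softmax-dropout instead of reusing the decoder's dropout.
    import torch
    import torch.nn as nn

    class TailProjection(nn.Module):
        """Projects hidden states down, applies dropout, then scores one tail vocab slice."""

        def __init__(self, input_dim, proj_dim, slice_size, dropout=0.0):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(input_dim, proj_dim, bias=False),
                nn.Dropout(dropout),  # previously tied to the model-wide dropout rate
                nn.Linear(proj_dim, slice_size, bias=False),
            )

        def forward(self, hidden):
            # hidden: (batch, input_dim) -> logits over this tail's vocabulary slice
            return self.layers(hidden)

    # Example: a decoder trained with dropout=0.1 can now keep its tail projections at 0.
    tail = TailProjection(input_dim=512, proj_dim=128, slice_size=10000, dropout=0.0)
    logits = tail(torch.randn(8, 512))  # -> shape (8, 10000)

In fairseq's actual module the cutoff points split the vocabulary into a head and several such tails; the single dropout= keyword in the diff below sets one rate for all of them.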
@@ -115,6 +115,8 @@ class FConvLanguageModel(FairseqLanguageModel):
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion')
+        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
+                            help='sets adaptive softmax dropout for the tail projections')
         parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
                             help='decoder attention [True, ...]')
         parser.add_argument('--normalization-constant', type=float, metavar='D',
@@ -143,6 +145,7 @@ class FConvLanguageModel(FairseqLanguageModel):
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                 if args.criterion == 'adaptive_loss' else None
             ),
+            adaptive_softmax_dropout=args.adaptive_softmax_dropout,
             normalization_constant=args.normalization_constant,
         )
         return FConvLanguageModel(decoder)
@@ -344,7 +347,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
         max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
         dropout=0.1, share_embed=False, positional_embeddings=True,
-        adaptive_softmax_cutoff=None, normalization_constant=0.5,
+        adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, normalization_constant=0.5,
         left_pad=False,
     ):
         super().__init__(dictionary)
@@ -406,7 +409,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         if adaptive_softmax_cutoff is not None:
             assert not share_embed
             self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, in_channels, adaptive_softmax_cutoff,
-                                                    dropout=dropout)
+                                                    dropout=adaptive_softmax_dropout)
         else:
             self.fc2 = Linear(in_channels, out_embed_dim)
             if share_embed:
@@ -612,6 +615,7 @@ def base_lm_architecture(args):
     args.decoder_layers = getattr(args, 'decoder_layers', '[(1268, 4)] * 13')
     args.decoder_attention = getattr(args, 'decoder_attention', 'False')
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
+    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
     args.normalization_constant = getattr(args, 'normalization_constant', 0.5)
...
@@ -75,6 +75,8 @@ class TransformerModel(FairseqModel):
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion'),
+        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
+                            help='sets adaptive softmax dropout for the tail projections')
 
     @classmethod
     def build_model(cls, args, task):
@@ -154,6 +156,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion')
+        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
+                            help='sets adaptive softmax dropout for the tail projections')
         parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                             help='if set, disables positional embeddings (outside self attention)')
         parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
@@ -309,7 +313,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             self.adaptive_softmax = AdaptiveSoftmax(
                 len(dictionary), args.decoder_embed_dim,
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
-                dropout=args.dropout,
+                dropout=args.adaptive_softmax_dropout,
             )
         elif not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
@@ -573,6 +577,7 @@ def base_lm_architecture(args):
     args.decoder_layers = getattr(args, 'decoder_layers', 6)
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
+    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
     args.character_embeddings = getattr(args, 'character_embeddings', False)
@@ -623,6 +628,7 @@ def base_architecture(args):
     args.relu_dropout = getattr(args, 'relu_dropout', 0.)
     args.dropout = getattr(args, 'dropout', 0.1)
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
+    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
     args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
     args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
...
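Usage note: with the flag in place, an adaptive-softmax language-model run would pass something like --criterion adaptive_loss --adaptive-softmax-cutoff 10000,20000,200000 --adaptive-softmax-dropout 0.2 (the cutoff and dropout values here are illustrative, not taken from this commit). Configurations that do not set the new flag fall back to 0 via the getattr defaults added above, which means the TransformerDecoder's adaptive softmax no longer inherits args.dropout. The short, hypothetical check below mirrors that defaulting behaviour.

    # Hypothetical check of the new default (names follow the diff; the Namespace
    # setup is illustrative, not fairseq code).
    from argparse import Namespace

    args = Namespace(dropout=0.1, criterion='adaptive_loss')  # an older config without the new flag

    # Same getattr defaulting as in base_lm_architecture() / base_architecture() above.
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)

    # Before this commit the TransformerDecoder reused args.dropout (0.1) here;
    # now the tail projections get 0 unless --adaptive-softmax-dropout is given.
    assert args.adaptive_softmax_dropout == 0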