Commit 616afddd authored by Alexei Baevski, committed by Myle Ott

option for a smaller adaptive softmax

parent f69206c8
@@ -115,6 +115,8 @@ class FConvLanguageModel(FairseqLanguageModel):
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion')
+        parser.add_argument('--adaptive-softmax-half-size', action='store_true',
+                            help='if set, halves the dimensionality of adaptive softmax (as in original impl)')
         parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
                             help='decoder attention [True, ...]')
         parser.add_argument('--normalization-constant', type=float, metavar='D',
@@ -143,6 +145,7 @@ class FConvLanguageModel(FairseqLanguageModel):
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                 if args.criterion == 'adaptive_loss' else None
             ),
+            adaptive_softmax_half_size=args.adaptive_softmax_half_size,
             normalization_constant=args.normalization_constant,
         )
         return FConvLanguageModel(decoder)
@@ -152,9 +155,9 @@ class FConvEncoder(FairseqEncoder):
     """Convolutional encoder"""

     def __init__(
         self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024,
         convolutions=((512, 3),) * 20, dropout=0.1, normalization_constant=0.5,
         left_pad=True,
     ):
         super().__init__(dictionary)
         self.dropout = dropout
@@ -341,11 +344,11 @@ class FConvDecoder(FairseqIncrementalDecoder):
     """Convolutional decoder"""

     def __init__(
         self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
         max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
         dropout=0.1, share_embed=False, positional_embeddings=True,
-        adaptive_softmax_cutoff=None, normalization_constant=0.5,
+        adaptive_softmax_cutoff=None, adaptive_softmax_half_size=False, normalization_constant=0.5,
         left_pad=False,
     ):
         super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
@@ -406,7 +409,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         if adaptive_softmax_cutoff is not None:
             assert not share_embed
             self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, in_channels, adaptive_softmax_cutoff,
-                                                    dropout=dropout)
+                                                    dropout=dropout, half_size=adaptive_softmax_half_size)
         else:
             self.fc2 = Linear(in_channels, out_embed_dim)
             if share_embed:
@@ -613,6 +616,7 @@ def base_lm_architecture(args):
     args.decoder_attention = getattr(args, 'decoder_attention', 'False')
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
     args.normalization_constant = getattr(args, 'normalization_constant', 0.5)
+    args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)


 @register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103')
...
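The fconv changes above use a pattern that recurs in the transformer changes below: an argparse store_true flag, plus a getattr() fallback in the architecture function so that argument namespaces saved before the option existed still get a default. A minimal standalone sketch of that pattern (plain argparse, not fairseq itself):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--adaptive-softmax-half-size', action='store_true',
                    help='if set, halves the dimensionality of adaptive softmax')

# Parse with the flag omitted, as an old config or checkpoint would be.
args = parser.parse_args([])

# base_lm_architecture-style default: if the attribute is missing,
# fall back to False rather than raising AttributeError.
args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)
print(args.adaptive_softmax_half_size)  # False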
...@@ -73,7 +73,9 @@ class TransformerModel(FairseqModel): ...@@ -73,7 +73,9 @@ class TransformerModel(FairseqModel):
' (requires shared dictionary and embed dim)') ' (requires shared dictionary and embed dim)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. ' help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion') 'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-half-size', action='store_true',
help='if set, halves the dimensionality of adaptive softmax (as in original impl)')
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
...@@ -153,6 +155,8 @@ class TransformerLanguageModel(FairseqLanguageModel): ...@@ -153,6 +155,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. ' help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion') 'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-half-size', action='store_true',
help='if set, halves the dimensionality of adaptive softmax (as in original impl)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)') help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true', parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
@@ -292,7 +296,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             self.adaptive_softmax = AdaptiveSoftmax(
                 len(dictionary), args.decoder_embed_dim,
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
-                dropout=args.dropout
+                dropout=args.dropout,
+                half_size=args.adaptive_softmax_half_size,
             )
         elif not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
@@ -557,6 +562,7 @@ def base_lm_architecture(args):
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
+    args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)

     # The model training is not stable without this
     args.decoder_normalize_before = True
@@ -604,6 +610,7 @@ def base_architecture(args):
     args.relu_dropout = getattr(args, 'relu_dropout', 0.)
     args.dropout = getattr(args, 'dropout', 0.1)
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
+    args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)
     args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
     args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
...
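options.eval_str_list, used when constructing AdaptiveSoftmax above, turns the comma-separated --adaptive-softmax-cutoff string into a list of ints. A rough approximation of its behavior (a sketch, not fairseq's exact implementation):

def eval_str_list(x, type=int):
    # "10000,50000,200000" -> [10000, 50000, 200000]
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)  # a comma separated literal evaluates to a tuple
    try:
        return list(map(type, x))
    except TypeError:
        return [type(x)]  # single value, e.g. "10000"

print(eval_str_list('10000,50000,200000'))  # [10000, 50000, 200000]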
...@@ -18,7 +18,7 @@ class AdaptiveSoftmax(nn.Module): ...@@ -18,7 +18,7 @@ class AdaptiveSoftmax(nn.Module):
approximation for GPUs" (http://arxiv.org/abs/1609.04309). approximation for GPUs" (http://arxiv.org/abs/1609.04309).
""" """
def __init__(self, vocab_size, input_dim, cutoff, dropout): def __init__(self, vocab_size, input_dim, cutoff, dropout, half_size=False):
super().__init__() super().__init__()
if vocab_size > cutoff[-1]: if vocab_size > cutoff[-1]:
@@ -37,12 +37,14 @@ class AdaptiveSoftmax(nn.Module):
         self.head = nn.Linear(input_dim, output_dim, bias=False)
         self.tail = nn.ModuleList()

+        extra_denom = 1 if half_size else 0
+
         for i in range(len(cutoff) - 1):
             self.tail.append(
                 nn.Sequential(
-                    nn.Linear(input_dim, input_dim // 4 ** i, bias=False),
+                    nn.Linear(input_dim, input_dim // 4 ** (i + extra_denom), bias=False),
                     nn.Dropout(dropout),
-                    nn.Linear(input_dim // 4 ** i, cutoff[i + 1] - cutoff[i], bias=False)
+                    nn.Linear(input_dim // 4 ** (i + extra_denom), cutoff[i + 1] - cutoff[i], bias=False)
                 )
             )
...
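The net effect of half_size in the hunk above is to shift the exponent in each tail bottleneck by one, so every tail cluster's projection width is divided by a further factor of 4. A small sketch with illustrative sizes (input_dim and cutoff below are assumptions, not values from this commit):

input_dim = 1024
cutoff = [10000, 50000, 200000]  # hypothetical cutoff points

for half_size in (False, True):
    extra_denom = 1 if half_size else 0
    # Width of the first Linear in each tail cluster, per the code above.
    dims = [input_dim // 4 ** (i + extra_denom) for i in range(len(cutoff) - 1)]
    print('half_size=%s: tail projection dims = %s' % (half_size, dims))

# half_size=False: tail projection dims = [1024, 256]
# half_size=True:  tail projection dims = [256, 64]

Narrower bottlenecks mean proportionally fewer weights in the two tail Linear layers, which is the "smaller adaptive softmax" the commit message refers to.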