Commit 19c25f47 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

Always smaller soft

no need to have half-size option as behavior can be reproduced with existing flags
parent d8998173
...@@ -115,8 +115,6 @@ class FConvLanguageModel(FairseqLanguageModel): ...@@ -115,8 +115,6 @@ class FConvLanguageModel(FairseqLanguageModel):
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. ' help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion') 'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-half-size', action='store_true',
help='if set, halves the dimensionality of adaptive softmax (as in original impl)')
parser.add_argument('--decoder-attention', type=str, metavar='EXPR', parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]') help='decoder attention [True, ...]')
parser.add_argument('--normalization-constant', type=float, metavar='D', parser.add_argument('--normalization-constant', type=float, metavar='D',
...@@ -145,7 +143,6 @@ class FConvLanguageModel(FairseqLanguageModel): ...@@ -145,7 +143,6 @@ class FConvLanguageModel(FairseqLanguageModel):
options.eval_str_list(args.adaptive_softmax_cutoff, type=int) options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
if args.criterion == 'adaptive_loss' else None if args.criterion == 'adaptive_loss' else None
), ),
adaptive_softmax_half_size=args.adaptive_softmax_half_size,
normalization_constant=args.normalization_constant, normalization_constant=args.normalization_constant,
) )
return FConvLanguageModel(decoder) return FConvLanguageModel(decoder)
...@@ -347,7 +344,7 @@ class FConvDecoder(FairseqIncrementalDecoder): ...@@ -347,7 +344,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256, self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 20, attention=True, max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
dropout=0.1, share_embed=False, positional_embeddings=True, dropout=0.1, share_embed=False, positional_embeddings=True,
adaptive_softmax_cutoff=None, adaptive_softmax_half_size=False, normalization_constant=0.5, adaptive_softmax_cutoff=None, normalization_constant=0.5,
left_pad=False, left_pad=False,
): ):
super().__init__(dictionary) super().__init__(dictionary)
...@@ -409,7 +406,7 @@ class FConvDecoder(FairseqIncrementalDecoder): ...@@ -409,7 +406,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
if adaptive_softmax_cutoff is not None: if adaptive_softmax_cutoff is not None:
assert not share_embed assert not share_embed
self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, in_channels, adaptive_softmax_cutoff, self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, in_channels, adaptive_softmax_cutoff,
dropout=dropout, half_size=adaptive_softmax_half_size) dropout=dropout)
else: else:
self.fc2 = Linear(in_channels, out_embed_dim) self.fc2 = Linear(in_channels, out_embed_dim)
if share_embed: if share_embed:
...@@ -616,7 +613,6 @@ def base_lm_architecture(args): ...@@ -616,7 +613,6 @@ def base_lm_architecture(args):
args.decoder_attention = getattr(args, 'decoder_attention', 'False') args.decoder_attention = getattr(args, 'decoder_attention', 'False')
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
args.normalization_constant = getattr(args, 'normalization_constant', 0.5) args.normalization_constant = getattr(args, 'normalization_constant', 0.5)
args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)
@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103') @register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103')
......
...@@ -75,8 +75,6 @@ class TransformerModel(FairseqModel): ...@@ -75,8 +75,6 @@ class TransformerModel(FairseqModel):
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. ' help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'), 'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-half-size', action='store_true',
help='if set, halves the dimensionality of adaptive softmax (as in original impl)')
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
...@@ -156,8 +154,6 @@ class TransformerLanguageModel(FairseqLanguageModel): ...@@ -156,8 +154,6 @@ class TransformerLanguageModel(FairseqLanguageModel):
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. ' help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion') 'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-half-size', action='store_true',
help='if set, halves the dimensionality of adaptive softmax (as in original impl)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)') help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true', parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
...@@ -314,7 +310,6 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -314,7 +310,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
len(dictionary), args.decoder_embed_dim, len(dictionary), args.decoder_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int), options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.dropout, dropout=args.dropout,
half_size=args.adaptive_softmax_half_size,
) )
elif not self.share_input_output_embed: elif not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim)) self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
...@@ -579,7 +574,8 @@ def base_lm_architecture(args): ...@@ -579,7 +574,8 @@ def base_lm_architecture(args):
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)
args.character_embeddings = getattr(args, 'character_embeddings', False)
# The model training is not stable without this # The model training is not stable without this
args.decoder_normalize_before = True args.decoder_normalize_before = True
...@@ -627,7 +623,6 @@ def base_architecture(args): ...@@ -627,7 +623,6 @@ def base_architecture(args):
args.relu_dropout = getattr(args, 'relu_dropout', 0.) args.relu_dropout = getattr(args, 'relu_dropout', 0.)
args.dropout = getattr(args, 'dropout', 0.1) args.dropout = getattr(args, 'dropout', 0.1)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)
args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False) args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
......
...@@ -18,7 +18,7 @@ class AdaptiveSoftmax(nn.Module): ...@@ -18,7 +18,7 @@ class AdaptiveSoftmax(nn.Module):
approximation for GPUs" (http://arxiv.org/abs/1609.04309). approximation for GPUs" (http://arxiv.org/abs/1609.04309).
""" """
def __init__(self, vocab_size, input_dim, cutoff, dropout, half_size=False): def __init__(self, vocab_size, input_dim, cutoff, dropout):
super().__init__() super().__init__()
if vocab_size > cutoff[-1]: if vocab_size > cutoff[-1]:
...@@ -32,21 +32,11 @@ class AdaptiveSoftmax(nn.Module): ...@@ -32,21 +32,11 @@ class AdaptiveSoftmax(nn.Module):
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.cutoff = cutoff self.cutoff = cutoff
self.dropout = dropout self.dropout = dropout
self.input_dim = input_dim
self.lsm = nn.LogSoftmax(dim=1) self.lsm = nn.LogSoftmax(dim=1)
self.head = nn.Linear(input_dim, output_dim, bias=False) self.head = nn.Linear(input_dim, output_dim, bias=False)
self.tail = nn.ModuleList() self._make_tail(True)
extra_denom = 1 if half_size else 0
for i in range(len(cutoff) - 1):
self.tail.append(
nn.Sequential(
nn.Linear(input_dim, input_dim // 4 ** (i + extra_denom), bias=False),
nn.Dropout(dropout),
nn.Linear(input_dim // 4 ** (i + extra_denom), cutoff[i + 1] - cutoff[i], bias=False)
)
)
def init_weights(m): def init_weights(m):
if hasattr(m, 'weight'): if hasattr(m, 'weight'):
...@@ -58,10 +48,24 @@ class AdaptiveSoftmax(nn.Module): ...@@ -58,10 +48,24 @@ class AdaptiveSoftmax(nn.Module):
# versions prior to 1 had a bug that offset indices on the head by 1 # versions prior to 1 had a bug that offset indices on the head by 1
self.buggy_offset = 0 self.buggy_offset = 0
def _make_tail(self, fix_exponent):
extra_denom = 1 if fix_exponent else 0
self.tail = nn.ModuleList()
for i in range(len(self.cutoff) - 1):
self.tail.append(
nn.Sequential(
nn.Linear(self.input_dim, self.input_dim // 4 ** (i + extra_denom), bias=False),
nn.Dropout(self.dropout),
nn.Linear(self.input_dim // 4 ** (i + extra_denom), self.cutoff[i + 1] - self.cutoff[i], bias=False)
)
)
def upgrade_state_dict_named(self, state_dict, name): def upgrade_state_dict_named(self, state_dict, name):
version_name = name + '.version' version_name = name + '.version'
if version_name not in state_dict: if version_name not in state_dict:
self.buggy_offset = 1 self.buggy_offset = 1
self._make_tail(False)
state_dict[version_name] = torch.LongTensor([1]) state_dict[version_name] = torch.LongTensor([1])
def adapt_target(self, target): def adapt_target(self, target):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment