Commit 19c25f47 authored by Alexei Baevski, committed by Myle Ott

Always smaller soft

no need to have half-size option as behavior can be reproduced with existing flags
parent d8998173
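
Note on what changes: in the adaptive softmax, each tail cluster i projects the decoder output down to input_dim // 4 ** (i + extra_denom) dimensions before predicting its slice of the vocabulary. With the flag removed, newly built models always use extra_denom = 1 (the constructor now calls self._make_tail(True), see the AdaptiveSoftmax hunks below), i.e. the smaller projection that --adaptive-softmax-half-size used to opt into. A rough illustration of the effect, assuming a hypothetical input_dim of 1024 and a two-point cutoff; these numbers are not from the commit:

    # Illustration only: tail projection widths with and without the extra
    # exponent, for a hypothetical input_dim and a two-point cutoff.
    input_dim = 1024
    cutoff = [10000, 50000]

    for extra_denom, label in [(0, 'old default (flag off)'), (1, 'new behavior (always on)')]:
        dims = [input_dim // 4 ** (i + extra_denom) for i in range(len(cutoff) - 1)]
        print(f'{label}: {dims}')
    # old default (flag off): [1024]    -- first tail cluster kept the full width
    # new behavior (always on): [256]   -- first tail cluster is projected 1024 -> 256
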
@@ -115,8 +115,6 @@ class FConvLanguageModel(FairseqLanguageModel):
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion')
-        parser.add_argument('--adaptive-softmax-half-size', action='store_true',
-                            help='if set, halves the dimensionality of adaptive softmax (as in original impl)')
         parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
                             help='decoder attention [True, ...]')
         parser.add_argument('--normalization-constant', type=float, metavar='D',
@@ -145,7 +143,6 @@ class FConvLanguageModel(FairseqLanguageModel):
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                 if args.criterion == 'adaptive_loss' else None
             ),
-            adaptive_softmax_half_size=args.adaptive_softmax_half_size,
             normalization_constant=args.normalization_constant,
         )
         return FConvLanguageModel(decoder)
@@ -347,7 +344,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
         max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
         dropout=0.1, share_embed=False, positional_embeddings=True,
-        adaptive_softmax_cutoff=None, adaptive_softmax_half_size=False, normalization_constant=0.5,
+        adaptive_softmax_cutoff=None, normalization_constant=0.5,
         left_pad=False,
     ):
         super().__init__(dictionary)
@@ -409,7 +406,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         if adaptive_softmax_cutoff is not None:
             assert not share_embed
             self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, in_channels, adaptive_softmax_cutoff,
-                                                    dropout=dropout, half_size=adaptive_softmax_half_size)
+                                                    dropout=dropout)
         else:
             self.fc2 = Linear(in_channels, out_embed_dim)
             if share_embed:
@@ -616,7 +613,6 @@ def base_lm_architecture(args):
     args.decoder_attention = getattr(args, 'decoder_attention', 'False')
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
     args.normalization_constant = getattr(args, 'normalization_constant', 0.5)
-    args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)
 
 
 @register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103')
......
@@ -75,8 +75,6 @@ class TransformerModel(FairseqModel):
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion'),
-        parser.add_argument('--adaptive-softmax-half-size', action='store_true',
-                            help='if set, halves the dimensionality of adaptive softmax (as in original impl)')
 
     @classmethod
     def build_model(cls, args, task):
@@ -156,8 +154,6 @@ class TransformerLanguageModel(FairseqLanguageModel):
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion')
-        parser.add_argument('--adaptive-softmax-half-size', action='store_true',
-                            help='if set, halves the dimensionality of adaptive softmax (as in original impl)')
         parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                             help='if set, disables positional embeddings (outside self attention)')
         parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
@@ -314,7 +310,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 len(dictionary), args.decoder_embed_dim,
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                 dropout=args.dropout,
-                half_size=args.adaptive_softmax_half_size,
             )
         elif not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
@@ -579,7 +574,8 @@ def base_lm_architecture(args):
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
-    args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)
     args.character_embeddings = getattr(args, 'character_embeddings', False)
     # The model training is not stable without this
     args.decoder_normalize_before = True
@@ -627,7 +623,6 @@ def base_architecture(args):
     args.relu_dropout = getattr(args, 'relu_dropout', 0.)
     args.dropout = getattr(args, 'dropout', 0.1)
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
-    args.adaptive_softmax_half_size = getattr(args, 'adaptive_softmax_half_size', False)
     args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
     args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
......
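
Both decoder call sites above now build the adaptive softmax from just the vocabulary size, the decoder dimension, the parsed cutoff, and dropout. A minimal sketch of that pattern; the sizes and the cutoff string are hypothetical, not taken from this commit, and the import paths assume the fairseq layout of this era:

    from fairseq import options
    from fairseq.modules import AdaptiveSoftmax

    # Stand-ins for len(dictionary), args.decoder_embed_dim,
    # args.adaptive_softmax_cutoff and args.dropout.
    vocab_size = 50000
    decoder_embed_dim = 512
    cutoff = options.eval_str_list('10000,50000', type=int)  # -> [10000, 50000]

    # Same shape as the FConvDecoder / TransformerDecoder calls after this change:
    # no half_size keyword, the smaller tail projections are always used.
    adaptive_softmax = AdaptiveSoftmax(vocab_size, decoder_embed_dim, cutoff, dropout=0.1)
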
@@ -18,7 +18,7 @@ class AdaptiveSoftmax(nn.Module):
     approximation for GPUs" (http://arxiv.org/abs/1609.04309).
     """
 
-    def __init__(self, vocab_size, input_dim, cutoff, dropout, half_size=False):
+    def __init__(self, vocab_size, input_dim, cutoff, dropout):
         super().__init__()
 
         if vocab_size > cutoff[-1]:
@@ -32,21 +32,11 @@ class AdaptiveSoftmax(nn.Module):
         self.vocab_size = vocab_size
         self.cutoff = cutoff
         self.dropout = dropout
+        self.input_dim = input_dim
 
         self.lsm = nn.LogSoftmax(dim=1)
         self.head = nn.Linear(input_dim, output_dim, bias=False)
-        self.tail = nn.ModuleList()
-
-        extra_denom = 1 if half_size else 0
-
-        for i in range(len(cutoff) - 1):
-            self.tail.append(
-                nn.Sequential(
-                    nn.Linear(input_dim, input_dim // 4 ** (i + extra_denom), bias=False),
-                    nn.Dropout(dropout),
-                    nn.Linear(input_dim // 4 ** (i + extra_denom), cutoff[i + 1] - cutoff[i], bias=False)
-                )
-            )
+        self._make_tail(True)
 
         def init_weights(m):
             if hasattr(m, 'weight'):
@@ -58,10 +48,24 @@ class AdaptiveSoftmax(nn.Module):
         # versions prior to 1 had a bug that offset indices on the head by 1
         self.buggy_offset = 0
 
+    def _make_tail(self, fix_exponent):
+        extra_denom = 1 if fix_exponent else 0
+
+        self.tail = nn.ModuleList()
+        for i in range(len(self.cutoff) - 1):
+            self.tail.append(
+                nn.Sequential(
+                    nn.Linear(self.input_dim, self.input_dim // 4 ** (i + extra_denom), bias=False),
+                    nn.Dropout(self.dropout),
+                    nn.Linear(self.input_dim // 4 ** (i + extra_denom), self.cutoff[i + 1] - self.cutoff[i], bias=False)
+                )
+            )
+
     def upgrade_state_dict_named(self, state_dict, name):
         version_name = name + '.version'
         if version_name not in state_dict:
             self.buggy_offset = 1
+            self._make_tail(False)
             state_dict[version_name] = torch.LongTensor([1])
 
     def adapt_target(self, target):
......
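
To make the new AdaptiveSoftmax construction concrete: __init__ now always calls _make_tail(True), while upgrade_state_dict_named rebuilds the tail with _make_tail(False) (and the buggy head offset) when it loads a checkpoint that predates the '.version' key. A small sketch of the resulting tail shapes, with hypothetical sizes that are not from the commit:

    import torch.nn as nn
    from fairseq.modules import AdaptiveSoftmax

    # Hypothetical sizes for illustration only.
    softmax = AdaptiveSoftmax(vocab_size=50000, input_dim=512, cutoff=[10000, 50000], dropout=0.1)

    # _make_tail(True) gives each tail cluster i a bottleneck of
    # input_dim // 4 ** (i + 1) units before predicting its word range.
    for i, cluster in enumerate(softmax.tail):
        shapes = [tuple(m.weight.shape) for m in cluster if isinstance(m, nn.Linear)]
        print(i, shapes)
    # 0 [(128, 512), (40000, 128)]   -- 512 // 4**1 = 128, and 50000 - 10000 = 40000 outputs
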