Commit 47f095fd authored by Myle Ott
Browse files

Remove --normalization-constant from fconv

parent 8c0ca1a0
......@@ -49,8 +49,6 @@ class FConvModel(FairseqModel):
help='decoder output embedding dimension')
parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
parser.add_argument('--normalization-constant', type=float, metavar='D',
help='multiplies the result of the residual block by sqrt(value)')
parser.add_argument('--share-input-output-embed', action='store_true',
help='share input and output embeddings (requires'
' --decoder-out-embed-dim and --decoder-embed-dim'
......@@ -79,7 +77,6 @@ class FConvModel(FairseqModel):
convolutions=eval(args.encoder_layers),
dropout=args.dropout,
max_positions=args.max_source_positions,
normalization_constant=args.normalization_constant,
)
decoder = FConvDecoder(
dictionary=task.target_dictionary,
......@@ -91,7 +88,6 @@ class FConvModel(FairseqModel):
dropout=args.dropout,
max_positions=args.max_target_positions,
share_embed=args.share_input_output_embed,
normalization_constant=args.normalization_constant,
)
return FConvModel(encoder, decoder)
......@@ -119,8 +115,6 @@ class FConvLanguageModel(FairseqLanguageModel):
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
parser.add_argument('--normalization-constant', type=float, metavar='D',
help='multiplies the result of the residual block by sqrt(value)')
@classmethod
def build_model(cls, args, task):
......@@ -146,7 +140,6 @@ class FConvLanguageModel(FairseqLanguageModel):
if args.criterion == 'adaptive_loss' else None
),
adaptive_softmax_dropout=args.adaptive_softmax_dropout,
normalization_constant=args.normalization_constant,
)
return FConvLanguageModel(decoder)
......@@ -156,12 +149,10 @@ class FConvEncoder(FairseqEncoder):
def __init__(
self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024,
convolutions=((512, 3),) * 20, dropout=0.1, normalization_constant=0.5,
left_pad=True,
convolutions=((512, 3),) * 20, dropout=0.1, left_pad=True,
):
super().__init__(dictionary)
self.dropout = dropout
self.normalization_constant = normalization_constant
self.left_pad = left_pad
self.num_attention_layers = None
......@@ -247,7 +238,7 @@ class FConvEncoder(FairseqEncoder):
x = F.glu(x, dim=2)
if residual is not None:
x = (x + residual) * math.sqrt(self.normalization_constant)
x = (x + residual) * math.sqrt(0.5)
residuals.append(x)
# T x B x C -> B x T x C
......@@ -264,7 +255,7 @@ class FConvEncoder(FairseqEncoder):
x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))
# add output to input embedding for attention
y = (x + input_embedding) * math.sqrt(self.normalization_constant)
y = (x + input_embedding) * math.sqrt(0.5)
return {
'encoder_out': (x, y),
......@@ -288,9 +279,8 @@ class FConvEncoder(FairseqEncoder):
class AttentionLayer(nn.Module):
def __init__(self, conv_channels, embed_dim, normalization_constant=0.5, bmm=None):
def __init__(self, conv_channels, embed_dim, bmm=None):
super().__init__()
self.normalization_constant = normalization_constant
# projects from output of convolution to embedding dimension
self.in_projection = Linear(conv_channels, embed_dim)
# projects from embedding dimension to convolution size
......@@ -302,7 +292,7 @@ class AttentionLayer(nn.Module):
residual = x
# attention
x = (self.in_projection(x) + target_embedding) * math.sqrt(self.normalization_constant)
x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
x = self.bmm(x, encoder_out[0])
# don't attend over padding
......@@ -330,7 +320,7 @@ class AttentionLayer(nn.Module):
x = x * (s * s.rsqrt())
# project back
x = (self.out_projection(x) + residual) * math.sqrt(self.normalization_constant)
x = (self.out_projection(x) + residual) * math.sqrt(0.5)
return x, attn_scores
def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
......@@ -347,13 +337,12 @@ class FConvDecoder(FairseqIncrementalDecoder):
self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
dropout=0.1, share_embed=False, positional_embeddings=True,
adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, normalization_constant=0.5,
adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0,
left_pad=False,
):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([2]))
self.dropout = dropout
self.normalization_constant = normalization_constant
self.left_pad = left_pad
self.need_attn = True
......@@ -397,7 +386,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
padding=(kernel_size - 1), dropout=dropout)
)
self.attention.append(AttentionLayer(out_channels, embed_dim, self.normalization_constant)
self.attention.append(AttentionLayer(out_channels, embed_dim)
if attention[i] else None)
self.residuals.append(residual)
in_channels = out_channels
......@@ -482,7 +471,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
# residual
if residual is not None:
x = (x + residual) * math.sqrt(self.normalization_constant)
x = (x + residual) * math.sqrt(0.5)
residuals.append(x)
# T x B x C -> B x T x C
......@@ -616,7 +605,6 @@ def base_lm_architecture(args):
args.decoder_attention = getattr(args, 'decoder_attention', 'False')
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
args.normalization_constant = getattr(args, 'normalization_constant', 0.5)
@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103')
......@@ -661,7 +649,6 @@ def base_architecture(args):
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
args.decoder_attention = getattr(args, 'decoder_attention', 'True')
args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)
args.normalization_constant = getattr(args, 'normalization_constant', 0.5)
@register_model_architecture('fconv', 'fconv_iwslt_de_en')
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment