Commit 1d38624f authored by alexeib, committed by Myle Ott

parameters to separate input/inner/out dims

parent e4f51e18
@@ -145,6 +145,10 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='dropout probability after ReLU in FFN')
         parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension')
+        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
+                            help='decoder output dimension')
+        parser.add_argument('--decoder-input-dim', type=int, metavar='N',
+                            help='decoder input dimension')
         parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension for FFN')
         parser.add_argument('--decoder-layers', type=int, metavar='N',
@@ -191,9 +195,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
                 args.char_embedder_highway_layers,
             )
         else:
-            embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
-
-        print(args)
+            embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())

         decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
         return TransformerLanguageModel(decoder)
@@ -291,12 +293,19 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed

-        embed_dim = embed_tokens.embedding_dim
+        input_embed_dim = embed_tokens.embedding_dim
+        embed_dim = args.decoder_embed_dim
+        output_embed_dim = args.decoder_output_dim
+
         padding_idx = embed_tokens.padding_idx
         self.max_target_positions = args.max_target_positions

         self.embed_tokens = embed_tokens
-        self.embed_scale = math.sqrt(embed_dim)
+        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim
+
+        self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False,
+                                     uniform=False) if embed_dim != input_embed_dim else None
+
         self.embed_positions = PositionalEmbedding(
             args.max_target_positions, embed_dim, padding_idx,
             left_pad=left_pad,
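The constructor now tracks three separate widths: input_embed_dim comes from the token embedding table, embed_dim is the inner Transformer width (args.decoder_embed_dim), and output_embed_dim is the width the decoder emits. When the first two differ, a bias-free Xavier-normal Linear bridges them. A minimal sketch of the input-side bridging, using plain nn.Linear and made-up dimensions (the 128/512 split and vocabulary size are illustrative, not from the commit):

import torch
import torch.nn as nn

# Hypothetical sizes for illustration: small embeddings feeding a wider decoder.
input_embed_dim, embed_dim, vocab = 128, 512, 1000

embed_tokens = nn.Embedding(vocab, input_embed_dim, padding_idx=1)
# Created only when the widths actually differ, mirroring the diff's condition.
project_in_dim = (nn.Linear(input_embed_dim, embed_dim, bias=False)
                  if embed_dim != input_embed_dim else None)

tokens = torch.randint(2, vocab, (4, 7))   # batch of 4 sequences of length 7
x = embed_tokens(tokens)                   # -> (4, 7, 128)
if project_in_dim is not None:
    x = project_in_dim(x)                  # -> (4, 7, 512)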
@@ -311,15 +320,18 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.adaptive_softmax = None

+        self.project_out_dim = Linear(embed_dim, output_embed_dim,
+                                      bias=False, uniform=False) if embed_dim != output_embed_dim else None
+
         if args.adaptive_softmax_cutoff is not None:
             self.adaptive_softmax = AdaptiveSoftmax(
-                len(dictionary), args.decoder_embed_dim,
+                len(dictionary), output_embed_dim,
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                 dropout=args.adaptive_softmax_dropout,
             )
         elif not self.share_input_output_embed:
-            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
-            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
+            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
+            nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
         self.register_buffer('version', torch.Tensor([2]))
         self.normalize = args.decoder_normalize_before
         if self.normalize:
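Symmetrically, project_out_dim maps the inner width back down before the vocabulary projection, and both output heads (AdaptiveSoftmax and the standalone embed_out matrix) are now sized by output_embed_dim rather than embed_dim. A sketch of the output side under the same assumed 512/128 split:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed sizes, continuing the example above.
embed_dim, output_embed_dim, vocab = 512, 128, 1000

project_out_dim = (nn.Linear(embed_dim, output_embed_dim, bias=False)
                   if embed_dim != output_embed_dim else None)
# The output embedding matrix is allocated at the output width.
embed_out = nn.Parameter(torch.Tensor(vocab, output_embed_dim))
nn.init.normal_(embed_out, mean=0, std=output_embed_dim ** -0.5)

h = torch.randn(4, 7, embed_dim)           # decoder states at the inner width
if project_out_dim is not None:
    h = project_out_dim(h)                 # -> (4, 7, 128)
logits = F.linear(h, embed_out)            # -> (4, 7, 1000)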
@@ -339,6 +351,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         # embed tokens and positions
         x = self.embed_scale * self.embed_tokens(prev_output_tokens)

+        if self.project_in_dim is not None:
+            x = self.project_in_dim(x)
+
         if positions is not None:
             x += positions
         x = F.dropout(x, p=self.dropout, training=self.training)
@@ -362,6 +378,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)

+        if self.project_out_dim is not None:
+            x = self.project_out_dim(x)
+
         if self.adaptive_softmax is None:
             # project back to size of vocabulary
             if self.share_input_output_embed:
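End to end, the forward pass projects up immediately after embedding and back down immediately after the final transpose, so only the layer stack in between runs at embed_dim. A toy trace of the shape flow (nn.TransformerEncoderLayer stands in for fairseq's real decoder stack; all names and sizes here are illustrative):

import torch
import torch.nn as nn

emb_in, d_model, emb_out, vocab = 128, 512, 128, 1000
embed = nn.Embedding(vocab, emb_in)
proj_in = nn.Linear(emb_in, d_model, bias=False)
layer = nn.TransformerEncoderLayer(d_model, nhead=8)  # stand-in, not fairseq's decoder
proj_out = nn.Linear(d_model, emb_out, bias=False)
vocab_proj = nn.Linear(emb_out, vocab, bias=False)

tokens = torch.randint(0, vocab, (7, 4))   # T x B, matching the diff's layout
x = proj_in(embed(tokens))                 # T x B x 512: up-project after embedding
x = layer(x)                               # T x B x 512: the stack runs at embed_dim
x = x.transpose(0, 1)                      # B x T x 512
logits = vocab_proj(proj_out(x))           # B x T x 1000: down-project, then vocab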
@@ -555,10 +574,14 @@ def LayerNorm(embedding_dim):
     return m


-def Linear(in_features, out_features, bias=True):
+def Linear(in_features, out_features, bias=True, uniform=True):
     m = nn.Linear(in_features, out_features, bias)
-    nn.init.xavier_uniform_(m.weight)
-    nn.init.constant_(m.bias, 0.)
+    if uniform:
+        nn.init.xavier_uniform_(m.weight)
+    else:
+        nn.init.xavier_normal_(m.weight)
+    if bias:
+        nn.init.constant_(m.bias, 0.)
     return m
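The Linear helper gains a uniform switch (Xavier-uniform by default, Xavier-normal for the new projections) and now guards the bias init; previously nn.init.constant_ ran unconditionally, which would raise on m.bias being None whenever bias=False was passed. Example calls against the helper as defined above:

import torch.nn as nn

def Linear(in_features, out_features, bias=True, uniform=True):
    # Same helper as in the diff, reproduced so the example is self-contained.
    m = nn.Linear(in_features, out_features, bias)
    if uniform:
        nn.init.xavier_uniform_(m.weight)
    else:
        nn.init.xavier_normal_(m.weight)
    if bias:  # m.bias is None when bias=False; the old code crashed here
        nn.init.constant_(m.bias, 0.)
    return m

proj = Linear(128, 512, bias=False, uniform=False)  # Xavier-normal weights, no bias
dense = Linear(512, 512)                            # default: Xavier-uniform, zero bias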
@@ -584,6 +607,9 @@ def base_lm_architecture(args):
     args.character_embeddings = getattr(args, 'character_embeddings', False)

+    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
+    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
+
     # The model training is not stable without this
     args.decoder_normalize_before = True
@@ -635,6 +661,9 @@ def base_architecture(args):
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
     args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)

+    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
+    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
+

 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
 def transformer_iwslt_de_en(args):
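Both base_lm_architecture and base_architecture backfill the new attributes, so configs that never set them behave exactly as before: each dimension falls back to decoder_embed_dim, the equality checks in the constructor hold, and both projections stay None. A small check of that fallback (Namespace here is just a stand-in for fairseq's parsed args):

from argparse import Namespace

def apply_dim_defaults(args):
    # Mirrors the getattr pattern in the architecture functions above.
    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
    return args

old_style = apply_dim_defaults(Namespace(decoder_embed_dim=512))
assert old_style.decoder_input_dim == old_style.decoder_output_dim == 512
# With all three widths equal, project_in_dim and project_out_dim are both None.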
...