Commit 1d38624f authored by alexeib, committed by Myle Ott

parameters to separate input/inner/out dims

parent e4f51e18
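For readers of this diff, the following is a minimal sketch (illustration only, not the fairseq code itself) of the pattern the commit introduces: token embeddings of size --decoder-input-dim are projected up to the decoder's inner dimension --decoder-embed-dim before the layers run, and decoder states are projected back down to --decoder-output-dim before the output embedding/softmax; each projection is created only when the two dimensions actually differ. The class name ProjectedDims and the inner_layers argument are invented for this example.

import math

import torch
import torch.nn as nn


class ProjectedDims(nn.Module):
    """Sketch of separating input/inner/output dims with optional projections."""

    def __init__(self, vocab_size, input_dim, embed_dim, output_dim, padding_idx=1):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, input_dim, padding_idx=padding_idx)
        self.embed_scale = math.sqrt(embed_dim)
        # Mirror the `... if embed_dim != input_embed_dim else None` pattern in the diff:
        # only build a projection when the dimensions differ.
        self.project_in_dim = (nn.Linear(input_dim, embed_dim, bias=False)
                               if embed_dim != input_dim else None)
        self.project_out_dim = (nn.Linear(embed_dim, output_dim, bias=False)
                                if embed_dim != output_dim else None)

    def forward(self, tokens, inner_layers):
        x = self.embed_scale * self.embed_tokens(tokens)  # B x T x input_dim
        if self.project_in_dim is not None:
            x = self.project_in_dim(x)                    # -> B x T x embed_dim
        x = inner_layers(x)                               # decoder layers run at embed_dim
        if self.project_out_dim is not None:
            x = self.project_out_dim(x)                   # -> B x T x output_dim
        return x


# Example: 128-dim input/output embeddings around 512-dim decoder layers.
model = ProjectedDims(vocab_size=1000, input_dim=128, embed_dim=512, output_dim=128)
out = model(torch.randint(2, 1000, (2, 7)), inner_layers=nn.Identity())
print(out.shape)  # torch.Size([2, 7, 128])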
@@ -145,6 +145,10 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='dropout probability after ReLU in FFN')
         parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension')
+        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
+                            help='decoder output dimension')
+        parser.add_argument('--decoder-input-dim', type=int, metavar='N',
+                            help='decoder input dimension')
         parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension for FFN')
         parser.add_argument('--decoder-layers', type=int, metavar='N',
@@ -191,9 +195,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
                 args.char_embedder_highway_layers,
             )
         else:
-            embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
-
-        print(args)
+            embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
 
         decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
         return TransformerLanguageModel(decoder)
@@ -291,12 +293,19 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
 
-        embed_dim = embed_tokens.embedding_dim
+        input_embed_dim = embed_tokens.embedding_dim
+        embed_dim = args.decoder_embed_dim
+        output_embed_dim = args.decoder_output_dim
+
         padding_idx = embed_tokens.padding_idx
         self.max_target_positions = args.max_target_positions
 
         self.embed_tokens = embed_tokens
-        self.embed_scale = math.sqrt(embed_dim)
+        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim
+
+        self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False,
+                                     uniform=False) if embed_dim != input_embed_dim else None
+
         self.embed_positions = PositionalEmbedding(
             args.max_target_positions, embed_dim, padding_idx,
             left_pad=left_pad,
@@ -311,15 +320,18 @@ class TransformerDecoder(FairseqIncrementalDecoder):
 
         self.adaptive_softmax = None
 
+        self.project_out_dim = Linear(embed_dim, output_embed_dim,
+                                      bias=False, uniform=False) if embed_dim != output_embed_dim else None
+
         if args.adaptive_softmax_cutoff is not None:
             self.adaptive_softmax = AdaptiveSoftmax(
-                len(dictionary), args.decoder_embed_dim,
+                len(dictionary), output_embed_dim,
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                 dropout=args.adaptive_softmax_dropout,
             )
         elif not self.share_input_output_embed:
-            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
-            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
+            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
+            nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
         self.register_buffer('version', torch.Tensor([2]))
         self.normalize = args.decoder_normalize_before
         if self.normalize:
@@ -339,6 +351,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
 
         # embed tokens and positions
         x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+
+        if self.project_in_dim is not None:
+            x = self.project_in_dim(x)
+
         if positions is not None:
             x += positions
         x = F.dropout(x, p=self.dropout, training=self.training)
@@ -362,6 +378,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
 
+        if self.project_out_dim is not None:
+            x = self.project_out_dim(x)
+
         if self.adaptive_softmax is None:
             # project back to size of vocabulary
             if self.share_input_output_embed:
@@ -555,10 +574,14 @@ def LayerNorm(embedding_dim):
     return m
 
 
-def Linear(in_features, out_features, bias=True):
+def Linear(in_features, out_features, bias=True, uniform=True):
     m = nn.Linear(in_features, out_features, bias)
-    nn.init.xavier_uniform_(m.weight)
-    nn.init.constant_(m.bias, 0.)
+    if uniform:
+        nn.init.xavier_uniform_(m.weight)
+    else:
+        nn.init.xavier_normal_(m.weight)
+    if bias:
+        nn.init.constant_(m.bias, 0.)
     return m
@@ -584,6 +607,9 @@ def base_lm_architecture(args):
     args.character_embeddings = getattr(args, 'character_embeddings', False)
 
+    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
+    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
+
     # The model training is not stable without this
     args.decoder_normalize_before = True
@@ -635,6 +661,9 @@ def base_architecture(args):
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
     args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
+    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
+    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
 
 
 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
 def transformer_iwslt_de_en(args):
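Usage note: because the new defaults above fall back to decoder_embed_dim via getattr, existing architectures behave exactly as before; only configurations that set the new options get the extra projections. The architecture name and dimension values below are hypothetical, shown only to illustrate how the new parameters could be overridden alongside the existing registrations in fairseq's transformer.py; they are not part of this commit.

# Hypothetical example (not part of this commit): a language-model variant that
# keeps 1024-dim input/output embeddings while running the decoder layers at 512,
# so both project_in_dim and project_out_dim get instantiated.
@register_model_architecture('transformer_lm', 'transformer_lm_narrow_inner')
def transformer_lm_narrow_inner(args):
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)     # inner (layer) dim
    args.decoder_input_dim = getattr(args, 'decoder_input_dim', 1024)    # input embedding dim
    args.decoder_output_dim = getattr(args, 'decoder_output_dim', 1024)  # output/softmax dim
    base_lm_architecture(args)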