Commit 07e34244 authored by Liezl Puzon, committed by Facebook Github Bot

Decoder embedding sharing in PyTorch Translate for denoising autoencoder (#386)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/386

Pull Request resolved: https://github.com/pytorch/translate/pull/266

This enables decoder embedding sharing for denoising autoencoder setups that use two different decoders (one for src decoding and one for tgt decoding).
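For context, a minimal sketch (not part of this diff) of how the new helper ties two decoders to one embedding table. The `build_embedding` function below is a hypothetical stand-in for the model's real embedding builder, and the 512 dimension is arbitrary:

```python
import torch.nn as nn

from fairseq.data import Dictionary
from fairseq.models import FairseqMultiModel


def build_embedding(dictionary, embed_dim, path=None):
    # Hypothetical stand-in builder; a real one would load pretrained
    # weights from `path` instead of ignoring it.
    return nn.Embedding(len(dictionary), embed_dim, padding_idx=dictionary.pad())


# Both decoding directions point at the same joined dictionary, so sharing is legal.
joined_dict = Dictionary()
dicts = {'src': joined_dict, 'tgt': joined_dict}

shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
    dicts=dicts,
    langs=['src', 'tgt'],
    embed_dim=512,
    build_embedding=build_embedding,
)
# Passing the returned module to both decoders ties their embedding weights.
```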

Reviewed By: dpacgopinath

Differential Revision: D13133015

fbshipit-source-id: 3c98be639d705744ccf5ba3a8fd7d10ddc7aef4a
parent a5e2d786
fairseq/models/fairseq_model.py
@@ -4,12 +4,14 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+from typing import Dict, List, Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from . import FairseqDecoder, FairseqEncoder
+from fairseq.data import Dictionary
 
 class BaseFairseqModel(nn.Module):
@@ -191,6 +193,38 @@ class FairseqMultiModel(BaseFairseqModel):
             for key in self.keys
         })
 
+    @staticmethod
+    def build_shared_embeddings(
+        dicts: Dict[str, Dictionary],
+        langs: List[str],
+        embed_dim: int,
+        build_embedding: callable,
+        pretrained_embed_path: Optional[str] = None,
+    ):
+        """
+        Helper function to build shared embeddings for a set of languages after
+        checking that all dicts corresponding to those languages are equivalent.
+
+        Args:
+            dicts: Dict of lang_id to its corresponding Dictionary
+            langs: languages that we want to share embeddings for
+            embed_dim: embedding dimension
+            build_embedding: callable function to actually build the embedding
+            pretrained_embed_path: Optional path to load pretrained embeddings
+        """
+        shared_dict = dicts[langs[0]]
+        if any(dicts[lang] != shared_dict for lang in langs):
+            raise ValueError(
+                '--share-*-embeddings requires a joined dictionary: '
+                '--share-encoder-embeddings requires a joined source '
+                'dictionary, --share-decoder-embeddings requires a joined '
+                'target dictionary, and --share-all-embeddings requires a '
+                'joint source + target dictionary.'
+            )
+        return build_embedding(
+            shared_dict, embed_dim, pretrained_embed_path
+        )
+
     def forward(self, src_tokens, src_lengths, prev_output_tokens):
         decoder_outs = {}
         for key in self.keys:
......
fairseq/models/multilingual_transformer.py
@@ -88,34 +88,41 @@ class MultilingualTransformerModel(FairseqMultiModel):
         # build shared embeddings (if applicable)
         shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
         if args.share_all_embeddings:
-            shared_dict = task.dicts[task.langs[0]]
-            if any(dict != shared_dict for dict in task.dicts.values()):
-                raise ValueError('--share-all-embeddings requires a joined dictionary')
             if args.encoder_embed_dim != args.decoder_embed_dim:
                 raise ValueError(
                     '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
             if args.decoder_embed_path and (
                     args.decoder_embed_path != args.encoder_embed_path):
                 raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
-            shared_encoder_embed_tokens = build_embedding(
-                shared_dict, args.encoder_embed_dim, args.encoder_embed_path
+            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
+                dicts=task.dicts,
+                langs=task.langs,
+                embed_dim=args.encoder_embed_dim,
+                build_embedding=build_embedding,
+                pretrained_embed_path=args.encoder_embed_path,
             )
             shared_decoder_embed_tokens = shared_encoder_embed_tokens
             args.share_decoder_input_output_embed = True
         else:
             if args.share_encoder_embeddings:
-                shared_dict = task.dicts[src_langs[0]]
-                if any(task.dicts[src_lang] != shared_dict for src_lang in src_langs):
-                    raise ValueError('--share-encoder-embeddings requires a joined source dictionary')
-                shared_encoder_embed_tokens = build_embedding(
-                    shared_dict, args.encoder_embed_dim, args.encoder_embed_path
+                shared_encoder_embed_tokens = (
+                    FairseqMultiModel.build_shared_embeddings(
+                        dicts=task.dicts,
+                        langs=src_langs,
+                        embed_dim=args.encoder_embed_dim,
+                        build_embedding=build_embedding,
+                        pretrained_embed_path=args.encoder_embed_path,
+                    )
                 )
             if args.share_decoder_embeddings:
-                shared_dict = task.dicts[tgt_langs[0]]
-                if any(task.dicts[tgt_lang] != shared_dict for tgt_lang in src_langs):
-                    raise ValueError('--share-decoder-embeddings requires a joined target dictionary')
-                shared_decoder_embed_tokens = build_embedding(
-                    shared_dict, args.decoder_embed_dim, args.decoder_embed_path
+                shared_decoder_embed_tokens = (
+                    FairseqMultiModel.build_shared_embeddings(
+                        dicts=task.dicts,
+                        langs=tgt_langs,
+                        embed_dim=args.decoder_embed_dim,
+                        build_embedding=build_embedding,
+                        pretrained_embed_path=args.decoder_embed_path,
+                    )
                 )
 
         # encoders/decoders for each language
......
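For completeness, a small sketch (same hypothetical stand-ins as the earlier example) of the failure mode the consolidated check guards against: languages whose dictionaries differ are rejected with the new combined error message.

```python
import torch.nn as nn

from fairseq.data import Dictionary
from fairseq.models import FairseqMultiModel

# Two dictionaries that diverge by one symbol are no longer "joined".
src_dict, tgt_dict = Dictionary(), Dictionary()
src_dict.add_symbol('hello')

try:
    FairseqMultiModel.build_shared_embeddings(
        dicts={'src': src_dict, 'tgt': tgt_dict},
        langs=['src', 'tgt'],
        embed_dim=512,
        build_embedding=lambda d, dim, path: nn.Embedding(len(d), dim, padding_idx=d.pad()),
    )
except ValueError as err:
    print(err)  # '--share-*-embeddings requires a joined dictionary: ...'
```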