"docs/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "9a37498cfc34181de969697960b1092f22ae721c"
Commit 07e34244 authored by Liezl Puzon, committed by Facebook Github Bot

Decoder embedding sharing in PyTorch Translate for denoising autoencoder (#386)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/386

Pull Request resolved: https://github.com/pytorch/translate/pull/266

This allows decoder embedding sharing for denoising autoencoder modules that use different decoders (one for src decoding and one for tgt decoding).

Reviewed By: dpacgopinath

Differential Revision: D13133015

fbshipit-source-id: 3c98be639d705744ccf5ba3a8fd7d10ddc7aef4a
parent a5e2d786
@@ -4,12 +4,14 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+from typing import Dict, List, Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from . import FairseqDecoder, FairseqEncoder
+from fairseq.data import Dictionary
 
 
 class BaseFairseqModel(nn.Module):
@@ -191,6 +193,38 @@ class FairseqMultiModel(BaseFairseqModel):
             for key in self.keys
         })
 
+    @staticmethod
+    def build_shared_embeddings(
+        dicts: Dict[str, Dictionary],
+        langs: List[str],
+        embed_dim: int,
+        build_embedding: callable,
+        pretrained_embed_path: Optional[str] = None,
+    ):
+        """
+        Helper function to build shared embeddings for a set of languages after
+        checking that all dicts corresponding to those languages are equivalent.
+
+        Args:
+            dicts: Dict of lang_id to its corresponding Dictionary
+            langs: languages that we want to share embeddings for
+            embed_dim: embedding dimension
+            build_embedding: callable function to actually build the embedding
+            pretrained_embed_path: Optional path to load pretrained embeddings
+        """
+        shared_dict = dicts[langs[0]]
+        if any(dicts[lang] != shared_dict for lang in langs):
+            raise ValueError(
+                '--share-*-embeddings requires a joined dictionary: '
+                '--share-encoder-embeddings requires a joined source '
+                'dictionary, --share-decoder-embeddings requires a joined '
+                'target dictionary, and --share-all-embeddings requires a '
+                'joint source + target dictionary.'
+            )
+        return build_embedding(
+            shared_dict, embed_dim, pretrained_embed_path
+        )
+
     def forward(self, src_tokens, src_lengths, prev_output_tokens):
         decoder_outs = {}
         for key in self.keys:
...
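For orientation, here is a minimal usage sketch of the new FairseqMultiModel.build_shared_embeddings helper; it is not part of the commit. It assumes fairseq is installed and that FairseqMultiModel and Dictionary are importable as shown; the two-language 'src'/'tgt' setup and the simple build_embedding callable are hypothetical, chosen only to match the (dictionary, embed_dim, path) contract passed through in the diff above.

# Minimal sketch (not part of the commit): build one embedding table shared by
# a 'src' decoder and a 'tgt' decoder, as in the denoising autoencoder setup.
# Import paths and the simple build_embedding below are assumptions; only the
# build_shared_embeddings call itself follows the diff above.
import torch.nn as nn

from fairseq.data import Dictionary
from fairseq.models import FairseqMultiModel


def build_embedding(dictionary, embed_dim, path=None):
    # Same (dictionary, embed_dim, path) contract as in the diff; loading
    # pretrained weights from `path` is omitted in this sketch.
    return nn.Embedding(len(dictionary), embed_dim, padding_idx=dictionary.pad())


joined_dict = Dictionary()  # both languages must use an equivalent dictionary
shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
    dicts={'src': joined_dict, 'tgt': joined_dict},
    langs=['src', 'tgt'],
    embed_dim=512,
    build_embedding=build_embedding,
    pretrained_embed_path=None,
)

Because both keys point at equivalent dictionaries, the check passes and a single embedding module is returned, which can then be handed to both decoders so their input embeddings are shared.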
@@ -88,34 +88,41 @@ class MultilingualTransformerModel(FairseqMultiModel):
         # build shared embeddings (if applicable)
         shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
         if args.share_all_embeddings:
-            shared_dict = task.dicts[task.langs[0]]
-            if any(dict != shared_dict for dict in task.dicts.values()):
-                raise ValueError('--share-all-embeddings requires a joined dictionary')
             if args.encoder_embed_dim != args.decoder_embed_dim:
                 raise ValueError(
                     '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
             if args.decoder_embed_path and (
                     args.decoder_embed_path != args.encoder_embed_path):
                 raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
-            shared_encoder_embed_tokens = build_embedding(
-                shared_dict, args.encoder_embed_dim, args.encoder_embed_path
+            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
+                dicts=task.dicts,
+                langs=task.langs,
+                embed_dim=args.encoder_embed_dim,
+                build_embedding=build_embedding,
+                pretrained_embed_path=args.encoder_embed_path,
             )
             shared_decoder_embed_tokens = shared_encoder_embed_tokens
             args.share_decoder_input_output_embed = True
         else:
             if args.share_encoder_embeddings:
-                shared_dict = task.dicts[src_langs[0]]
-                if any(task.dicts[src_lang] != shared_dict for src_lang in src_langs):
-                    raise ValueError('--share-encoder-embeddings requires a joined source dictionary')
-                shared_encoder_embed_tokens = build_embedding(
-                    shared_dict, args.encoder_embed_dim, args.encoder_embed_path
+                shared_encoder_embed_tokens = (
+                    FairseqMultiModel.build_shared_embeddings(
+                        dicts=task.dicts,
+                        langs=src_langs,
+                        embed_dim=args.encoder_embed_dim,
+                        build_embedding=build_embedding,
+                        pretrained_embed_path=args.encoder_embed_path,
+                    )
                 )
             if args.share_decoder_embeddings:
-                shared_dict = task.dicts[tgt_langs[0]]
-                if any(task.dicts[tgt_lang] != shared_dict for tgt_lang in src_langs):
-                    raise ValueError('--share-decoder-embeddings requires a joined target dictionary')
-                shared_decoder_embed_tokens = build_embedding(
-                    shared_dict, args.decoder_embed_dim, args.decoder_embed_path
+                shared_decoder_embed_tokens = (
+                    FairseqMultiModel.build_shared_embeddings(
+                        dicts=task.dicts,
+                        langs=tgt_langs,
+                        embed_dim=args.decoder_embed_dim,
+                        build_embedding=build_embedding,
+                        pretrained_embed_path=args.decoder_embed_path,
+                    )
                 )
 
         # encoders/decoders for each language
...
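The shared error message introduced above encodes the joined-dictionary requirement behind all three sharing flags. The hedged sketch below is not part of the commit; the import paths and the 'en'/'fr' symbols are assumptions for illustration only. It shows the helper rejecting two languages whose dictionaries differ.

# Sketch (not part of the commit): mismatched per-language dictionaries are
# rejected by build_shared_embeddings with the ValueError added in this diff.
from fairseq.data import Dictionary
from fairseq.models import FairseqMultiModel

dict_en = Dictionary()
dict_en.add_symbol('hello')
dict_fr = Dictionary()
dict_fr.add_symbol('bonjour')

try:
    FairseqMultiModel.build_shared_embeddings(
        dicts={'en': dict_en, 'fr': dict_fr},
        langs=['en', 'fr'],
        embed_dim=512,
        build_embedding=lambda d, dim, path=None: None,  # never reached here
        pretrained_embed_path=None,
    )
except ValueError as err:
    print(err)  # '--share-*-embeddings requires a joined dictionary: ...'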