Unverified Commit 9e9f6b8a authored by tiedemann's avatar tiedemann Committed by GitHub
Browse files

Update convert_marian_to_pytorch.py (#16124)

Configuration `tied-embeddings-all` implies `tied-embeddings-src`
parent 2de99e6c
...@@ -480,6 +480,8 @@ class OpusState: ...@@ -480,6 +480,8 @@ class OpusState:
if "Wpos" in self.state_dict: if "Wpos" in self.state_dict:
raise ValueError("Wpos key in state dictionary") raise ValueError("Wpos key in state dictionary")
self.state_dict = dict(self.state_dict) self.state_dict = dict(self.state_dict)
if cfg["tied-embeddings-all"]:
cfg["tied-embeddings-src"] = True
self.share_encoder_decoder_embeddings = cfg["tied-embeddings-src"] self.share_encoder_decoder_embeddings = cfg["tied-embeddings-src"]
# create the tokenizer here because we need to know the eos_token_id # create the tokenizer here because we need to know the eos_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment