"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "7639b5e70cb599845bf30fa0944d5b46f99ef9ed"
Commit 0693c351 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add check that --encoder-layers matches --decoder-layers for LSTM (fixes #394)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/398

Differential Revision: D13358876

Pulled By: myleott

fbshipit-source-id: 57673f2643aac01492cb8f5728bb9f1a34ba6aa7
parent 50591a29
@@ -70,6 +70,12 @@ class BaseFairseqModel(nn.Module):
        self.upgrade_state_dict_named(state_dict, '')

    def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade old state dicts to work with newer code.
Args:
state_dict (dict): state dictionary to upgrade, in place
name (str): the state dict key corresponding to the current module
"""
        assert state_dict is not None

        def do_upgrade(m, prefix):
...
@@ -75,6 +75,9 @@ class LSTMModel(FairseqModel):
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)
if args.encoder_layers != args.decoder_layers:
raise ValueError('--encoder-layers must match --decoder-layers')
        def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment