Commit a4e34985 authored by Haoran Li, committed by Facebook Github Bot
Browse files

make dictionary optional

Reviewed By: jingfeidu

Differential Revision: D13104360

fbshipit-source-id: 9636f5ee2721818f98b33af559fa24292534a72f
parent 161d1e06
......@@ -51,7 +51,11 @@ class CharacterTokenEmbedder(torch.nn.Module):
self.projection = nn.Linear(last_dim, word_embed_dim)
self.set_vocab(vocab, max_char_len)
assert vocab is not None or char_inputs, "vocab must be set if not using char inputs"
self.vocab = None
if vocab is not None:
self.set_vocab(vocab, max_char_len)
self.reset_parameters()
def set_vocab(self, vocab, max_char_len):
......@@ -78,7 +82,7 @@ class CharacterTokenEmbedder(torch.nn.Module):
@property
def padding_idx(self):
return self.vocab.pad()
return Dictionary().pad() if self.vocab is None else self.vocab.pad()
def reset_parameters(self):
nn.init.xavier_normal_(self.char_embeddings.weight)
......
......@@ -87,11 +87,14 @@ class LanguageModelingTask(FairseqTask):
Args:
args (argparse.Namespace): parsed command-line arguments
"""
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
output_dictionary = dictionary
if args.output_dictionary_size >= 0:
output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)
dictionary = None
output_dictionary = None
if args.data:
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
output_dictionary = dictionary
if args.output_dictionary_size >= 0:
output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)
# upgrade old checkpoints
if hasattr(args, 'exclude_self_target'):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment