"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b83bdce42abb60e72314bb8507c710e6f649dfb1"
Commit a4e34985 authored by Haoran Li, committed by Facebook Github Bot
Browse files

make dictionary optional

Reviewed By: jingfeidu

Differential Revision: D13104360

fbshipit-source-id: 9636f5ee2721818f98b33af559fa24292534a72f
parent 161d1e06
...@@ -51,7 +51,11 @@ class CharacterTokenEmbedder(torch.nn.Module): ...@@ -51,7 +51,11 @@ class CharacterTokenEmbedder(torch.nn.Module):
self.projection = nn.Linear(last_dim, word_embed_dim) self.projection = nn.Linear(last_dim, word_embed_dim)
self.set_vocab(vocab, max_char_len) assert vocab is not None or char_inputs, "vocab must be set if not using char inputs"
self.vocab = None
if vocab is not None:
self.set_vocab(vocab, max_char_len)
self.reset_parameters() self.reset_parameters()
def set_vocab(self, vocab, max_char_len): def set_vocab(self, vocab, max_char_len):
...@@ -78,7 +82,7 @@ class CharacterTokenEmbedder(torch.nn.Module): ...@@ -78,7 +82,7 @@ class CharacterTokenEmbedder(torch.nn.Module):
@property @property
def padding_idx(self): def padding_idx(self):
return self.vocab.pad() return Dictionary().pad() if self.vocab is None else self.vocab.pad()
def reset_parameters(self): def reset_parameters(self):
nn.init.xavier_normal_(self.char_embeddings.weight) nn.init.xavier_normal_(self.char_embeddings.weight)
......
...@@ -87,11 +87,14 @@ class LanguageModelingTask(FairseqTask): ...@@ -87,11 +87,14 @@ class LanguageModelingTask(FairseqTask):
Args: Args:
args (argparse.Namespace): parsed command-line arguments args (argparse.Namespace): parsed command-line arguments
""" """
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt')) dictionary = None
print('| dictionary: {} types'.format(len(dictionary))) output_dictionary = None
output_dictionary = dictionary if args.data:
if args.output_dictionary_size >= 0: dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size) print('| dictionary: {} types'.format(len(dictionary)))
output_dictionary = dictionary
if args.output_dictionary_size >= 0:
output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)
# upgrade old checkpoints # upgrade old checkpoints
if hasattr(args, 'exclude_self_target'): if hasattr(args, 'exclude_self_target'):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment