Commit 8831c688 authored by thomwolf's avatar thomwolf
Browse files

fixing various parts of model conversion, loading and weights sharing

parent bcd4aa8f
......@@ -42,7 +42,7 @@ parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
# help='location of the data corpus')
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
# choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
help='pretrained model name')
parser.add_argument('--split', type=str, default='test',
choices=['all', 'valid', 'test'],
......
......@@ -116,7 +116,8 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)
corpus_vocab_dict = corpus.vocab.__dict__
torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
corpus_dict_no_vocab = corpus.__dict__
corpus_dict_no_vocab.pop('vocab', None)
......@@ -139,7 +140,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
model = TransfoXLModel(config)
# Build TF to PyTorch weights loading map
tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
......
......@@ -444,6 +444,12 @@ class TransfoXLCorpus(object):
for key, value in corpus_dict.items():
corpus.__dict__[key] = value
corpus.vocab = vocab
if corpus.train is not None:
corpus.train = torch.tensor(corpus.train, dtype=torch.long)
if corpus.valid is not None:
corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
if corpus.test is not None:
corpus.test = torch.tensor(corpus.test, dtype=torch.long)
return corpus
def __init__(self, *args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment