chenpangpang / transformers · Commits

Commit 8831c688 authored Jan 16, 2019 by thomwolf
Parent: bcd4aa8f

fixing various parts of model conversion, loading and weights sharing
Showing 4 changed files with 243 additions and 283 deletions (+243 −283)
examples/eval_transfo_xl.py                                            +1   −1
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py    +3   −2
pytorch_pretrained_bert/modeling_transfo_xl.py                         +233 −280
pytorch_pretrained_bert/tokenization_transfo_xl.py                     +6   −0
examples/eval_transfo_xl.py
@@ -42,7 +42,7 @@ parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model
 # parser.add_argument('--data', type=str, default='../data/wikitext-103',
 #                     help='location of the data corpus')
 parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
-                    choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
+                    # choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
                     help='pretrained model name')
 parser.add_argument('--split', type=str, default='test',
                     choices=['all', 'valid', 'test'],
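The only change here comments out the choices restriction on --model_name, so argparse no longer rejects values other than 'transfo-xl-wt103'. A minimal sketch of the resulting behaviour (the local path below is hypothetical, not from the diff):

import argparse

# Without the choices=[...] constraint, any string is accepted for --model_name,
# for example a local directory of converted weights.
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                    help='pretrained model name')
args = parser.parse_args(['--model_name', './my-transfo-xl-checkpoint'])
print(args.model_name)  # ./my-transfo-xl-checkpoint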
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
@@ -116,7 +116,8 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
     # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
     pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
     print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
-    torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)
+    corpus_vocab_dict = corpus.vocab.__dict__
+    torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
     corpus_dict_no_vocab = corpus.__dict__
     corpus_dict_no_vocab.pop('vocab', None)
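The new intermediate variable makes it explicit that only the vocabulary's attribute dictionary is serialized with torch.save. A hedged sketch of the round trip this enables; the minimal Vocab class and the file name are stand-ins, not the repository's real classes or paths:

import torch

class Vocab(object):
    """Hypothetical minimal stand-in for the corpus vocabulary object."""
    def __init__(self):
        self.idx2sym = []
        self.sym2idx = {}

# Saving side (as in the converter): persist only the attribute dict.
vocab = Vocab()
vocab.idx2sym = ['<eos>', 'hello']
vocab.sym2idx = {'<eos>': 0, 'hello': 1}
torch.save(vocab.__dict__, 'vocab.bin')

# Loading side: create an empty object and copy the saved attributes back in.
restored = Vocab()
restored.__dict__.update(torch.load('vocab.bin'))
assert restored.sym2idx['hello'] == 1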
@@ -139,7 +140,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
     model = TransfoXLModel(config)
     # Build TF to PyTorch weights loading map
-    tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
+    tf_to_pt_map = build_tf_to_pytorch_map(model, config)
     # Load weights from TF model
     init_vars = tf.train.list_variables(tf_path)
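Passing the full model rather than model.transformer presumably lets build_tf_to_pytorch_map also cover the output head and the weights it shares with the embeddings. A hedged sketch of what such a TF-name to PyTorch-parameter map can look like; the attribute and variable names here are illustrative, not the ones used by the repository:

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Hypothetical stand-in for the full model (embedding plus head parameters)."""
    def __init__(self, vocab_size=10, dim=4):
        super(TinyLM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.out_bias = nn.Parameter(torch.zeros(vocab_size))

def build_name_map(model):
    # Map TensorFlow variable names to the PyTorch tensors they should fill.
    # Head parameters live on the full model, not on a sub-module, which is
    # presumably why the call site above now passes `model`.
    return {
        'transformer/adaptive_embed/lookup_table': model.embedding.weight,
        'transformer/adaptive_softmax/bias': model.out_bias,
    }

for tf_name, tensor in build_name_map(TinyLM()).items():
    print(tf_name, tuple(tensor.shape))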
pytorch_pretrained_bert/modeling_transfo_xl.py
(diff collapsed in this view: +233 −280)
pytorch_pretrained_bert/tokenization_transfo_xl.py
@@ -444,6 +444,12 @@ class TransfoXLCorpus(object):
         for key, value in corpus_dict.items():
             corpus.__dict__[key] = value
         corpus.vocab = vocab
+        if corpus.train is not None:
+            corpus.train = torch.tensor(corpus.train, dtype=torch.long)
+        if corpus.valid is not None:
+            corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
+        if corpus.test is not None:
+            corpus.test = torch.tensor(corpus.test, dtype=torch.long)
         return corpus

     def __init__(self, *args, **kwargs):
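The six added lines convert whichever cached splits exist into torch.long tensors when the corpus is loaded. A hedged, stand-alone sketch of the same pattern; the Corpus class and the sample ids below are hypothetical stand-ins for TransfoXLCorpus and its cached data:

import torch

class Corpus(object):
    def __init__(self):
        self.train = self.valid = self.test = None

corpus = Corpus()
corpus.train = [0, 5, 2, 7]   # e.g. token ids restored from a cache file
corpus.valid = None           # a split may be absent

# Convert only the splits that are present, leaving missing ones as None.
for split in ('train', 'valid', 'test'):
    data = getattr(corpus, split)
    if data is not None:
        setattr(corpus, split, torch.tensor(data, dtype=torch.long))

print(corpus.train.dtype)  # torch.int64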