Project: chenpangpang / transformers - Commits

Commit a69ec2c7, authored Jan 15, 2019 by thomwolf
Parent: 7d03c537

    improved corpus and tokenization conversion - added evaluation script

Showing 5 changed files with 349 additions and 124 deletions:

    examples/eval_transfo_xl.py                                            +151   -0
    pytorch_pretrained_bert/__init__.py                                      +2   -0
    pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py   +105  -72
    pytorch_pretrained_bert/modeling_transfo_xl.py                          +23  -17
    pytorch_pretrained_bert/tokenization_transfo_xl.py                      +68  -35
examples/eval_transfo_xl.py (new file, 0 → 100644)
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model evaluation script.
Adapted from https://github.com/kimiyoung/transformer-xl.
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
"""
import os
import sys
import functools
import argparse
import time
import math

import torch

from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus

def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')

def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
#                     help='location of the data corpus')
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                    choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
                    help='pretrained model name')
parser.add_argument('--split', type=str, default='all',
                    choices=['all', 'valid', 'test'],
                    help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')
parser.add_argument('--tgt_len', type=int, default=5,
                    help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous heads')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
                    help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
                    help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
                    help='set same length attention with masking')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'

device = torch.device("cuda" if args.cuda else "cpu")

# Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
                     log_=not args.no_log)

# Load dataset
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
ntokens = len(corpus.vocab)

va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)

# Load the best saved model.
# with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
#     model = torch.load(f)
# model.backward_compatible()
model = TransfoXLModel.from_pretrained(args.model_name)
model = model.to(device)

logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
    args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
    model.clamp_len = args.clamp_len
if args.same_length:
    model.same_length = True

###############################################################################
# Evaluation code
###############################################################################
def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_len, total_loss = 0, 0.
    start_time = time.time()
    with torch.no_grad():
        mems = tuple()
        for idx, (data, target, seq_len) in enumerate(eval_iter):
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            total_loss += seq_len * loss.item()
            total_len += seq_len
        total_time = time.time() - start_time
    logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
            total_time, 1000 * total_time / (idx+1)))
    return total_loss / total_len

# Run on test data.
if args.split == 'all':
    test_loss = evaluate(te_iter)
    valid_loss = evaluate(va_iter)
elif args.split == 'valid':
    valid_loss = evaluate(va_iter)
    test_loss = None
elif args.split == 'test':
    test_loss = evaluate(te_iter)
    valid_loss = None

def format_log(loss, split):
    if args.dataset in ['enwik8', 'text8']:
        log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
            split, loss, loss / math.log(2))
    else:
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
    return log_str

log_str = ''
if valid_loss is not None:
    log_str += format_log(valid_loss, 'valid')
if test_loss is not None:
    log_str += format_log(test_loss, 'test')

logging('=' * 100)
logging(log_str)
logging('=' * 100)
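The script above is driven by argparse, but the flow it implements can be sketched directly against the classes this commit exports. The snippet below is a minimal sketch of that flow, using the script's own default arguments (batch size 10, tgt_len 5, ext_len 0, mem_len 0) and assuming the pretrained 'transfo-xl-wt103' files are reachable for download:

    # Minimal sketch of the evaluation flow above, without the argparse wrapper.
    import math
    import torch
    from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
    model = TransfoXLModel.from_pretrained('transfo-xl-wt103').to(device)
    model.reset_length(5, 0, 0)  # tgt_len, ext_len, mem_len: the script defaults

    va_iter = corpus.get_iterator('valid', 10, 5, device=device, ext_len=0)
    model.eval()
    total_len, total_loss = 0, 0.
    with torch.no_grad():
        mems = tuple()
        for data, target, seq_len in va_iter:
            ret = model(data, target, *mems)   # returns (loss, *new_mems)
            loss, mems = ret[0], ret[1:]
            total_loss += seq_len * loss.mean().item()
            total_len += seq_len
    print('valid ppl {:9.3f}'.format(math.exp(total_loss / total_len)))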
pytorch_pretrained_bert/__init__.py
 __version__ = "0.5.0"
 from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
 from .tokenization_openai import OpenAIGPTTokenizer
+from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
 from .modeling import (BertConfig, BertModel, BertForPreTraining,
                        BertForMaskedLM, BertForNextSentencePrediction,
                        BertForSequenceClassification, BertForMultipleChoice,
                        BertForTokenClassification, BertForQuestionAnswering)
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
+from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
 from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
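With the two added imports, the Transformer-XL classes become available from the package root alongside the BERT and OpenAI GPT ones; a minimal sketch of what client code can now write:

    # Sketch: the Transformer-XL exports added by this commit.
    from pytorch_pretrained_bert import TransfoXLTokenizer, TransfoXLCorpus
    from pytorch_pretrained_bert import TransfoXLConfig, TransfoXLModel

    # 'transfo-xl-wt103' is the shortcut name registered in the archive maps
    # of the modeling and tokenization files below.
    tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
    model = TransfoXLModel.from_pretrained('transfo-xl-wt103')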
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
@@ -12,23 +12,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Convert OpenAI GPT checkpoint."""
+"""Convert Transformer XL checkpoint and datasets."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import os
-import re
+import sys
 import argparse
+import pickle

 import tensorflow as tf
 import torch
 import numpy as np

 from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
+from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
+
+# We do this to be able to load the python 2 datasets pickles
+# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
+import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
+data_utils.Vocab = data_utils.TransfoXLTokenizer
+data_utils.Corpus = data_utils.TransfoXLCorpus
+sys.modules['data_utils'] = data_utils
+sys.modules['vocabulary'] = data_utils

 def build_tf_to_pytorch_map(model, config):
-    """ A map of modules from TF to PyTorch """
+    """ A map of modules from TF to PyTorch.
+        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
+    """
     tf_to_pt_map = {}
     # Embeddings cutoffs
     for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
@@ -95,7 +108,23 @@ def build_tf_to_pytorch_map(model, config):
 def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                              transfo_xl_config_file,
-                                             pytorch_dump_folder_path):
+                                             pytorch_dump_folder_path,
+                                             transfo_xl_dataset_file):
+    if transfo_xl_dataset_file:
+        with open(transfo_xl_dataset_file, "rb") as fp:
+            corpus = pickle.load(fp, encoding="latin1")
+        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
+        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
+        torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)
+        corpus_dict_no_vocab = corpus.__dict__
+        corpus_dict_no_vocab.pop('vocab', None)
+        pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
+        print("Save dataset to {}".format(pytorch_dataset_dump_path))
+        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
+
+    if tf_checkpoint_path:
         config_path = os.path.abspath(transfo_xl_config_file)
         tf_path = os.path.abspath(tf_checkpoint_path)
@@ -161,22 +190,26 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ## Required parameters
-    parser.add_argument("--tf_checkpoint_path",
-                        default = None,
-                        type = str,
-                        required = True,
-                        help = "Path the TensorFlow checkpoint path.")
     parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
                         required = True,
-                        help = "Path to the output PyTorch model.")
+                        help = "Path to the folder to store the PyTorch model or dataset/vocab.")
+    parser.add_argument("--tf_checkpoint_path",
+                        default = "",
+                        type = str,
+                        help = "An optional path to a TensorFlow checkpoint path to be converted.")
     parser.add_argument("--transfo_xl_config_file",
                         default = "",
                         type = str,
-                        help = "The config json file corresponding to the pre-trained BERT model. \n"
+                        help = "An optional config json file corresponding to the pre-trained BERT model. \n"
                             "This specifies the model architecture.")
+    parser.add_argument("--transfo_xl_dataset_file",
+                        default = "",
+                        type = str,
+                        help = "An optional dataset file to be converted in a vocabulary.")
     args = parser.parse_args()
     convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                              args.transfo_xl_config_file,
-                                             args.pytorch_dump_folder_path)
+                                             args.pytorch_dump_folder_path,
+                                             args.transfo_xl_dataset_file)
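Because --tf_checkpoint_path is now optional, the converter can also be used for the dataset/vocabulary path alone. A rough sketch of calling the function directly for that case (the pickle path below is hypothetical; note the module still imports TensorFlow at the top, so TF must be installed even for this branch):

    # Sketch: dataset-only conversion, no TensorFlow checkpoint involved.
    from pytorch_pretrained_bert.convert_transfo_xl_checkpoint_to_pytorch import (
        convert_transfo_xl_checkpoint_to_pytorch)

    convert_transfo_xl_checkpoint_to_pytorch(
        tf_checkpoint_path="",                        # skip the TF checkpoint branch
        transfo_xl_config_file="",
        pytorch_dump_folder_path="converted-wt103",   # receives vocab.bin / corpus.bin
        transfo_xl_dataset_file="cache/corpus.pkl")   # hypothetical python-2 pickle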
pytorch_pretrained_bert/modeling_transfo_xl.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Transformer XL model.
-    Directly adapted from https://github.com/kimiyoung/transformer-xl.
+    Adapted from https://github.com/kimiyoung/transformer-xl.
     In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
 """
@@ -40,7 +40,7 @@ from .file_utils import cached_path
 logger = logging.getLogger(__name__)

 PRETRAINED_MODEL_ARCHIVE_MAP = {
-    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl.tar.gz",
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103.tar.gz",
 }
 CONFIG_NAME = 'transfo_xl_config.json'
 WEIGHTS_NAME = 'pytorch_model.bin'
@@ -59,12 +59,13 @@ class TransfoXLConfig(object):
                  div_val=4,
                  pre_lnorm=False,
                  n_layer=18,
-                 tgt_len=256,
+                 tgt_len=128,
                  ext_len=0,
-                 mem_len=256,
-                 same_length=False,
+                 mem_len=1600,
+                 clamp_len=1000,
+                 same_length=True,
+                 proj_share_all_but_first=True,
                  attn_type=0,
-                 clamp_len=-1,
                  sample_softmax=-1,
                  adaptive=True,
                  tie_weight=True,
@@ -93,6 +94,7 @@ class TransfoXLConfig(object):
             ext_len: length of the extended context
             mem_len: length of the retained previous heads
             same_length: use the same attn length for all tokens
+            proj_share_all_but_first: True to share all but first projs, False not to share.
             attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
             clamp_len: use the same pos embeddings after clamp_len
             sample_softmax: number of samples in sampled softmax
@@ -118,7 +120,10 @@ class TransfoXLConfig(object):
         self.cutoffs = []
         self.cutoffs.extend(cutoffs)
         self.tie_weight = tie_weight
-        self.tie_projs = [False] + [True] * len(self.cutoffs)
+        if proj_share_all_but_first:
+            self.tie_projs = [False] + [True] * len(self.cutoffs)
+        else:
+            self.tie_projs = [False] + [False] * len(self.cutoffs)
         self.d_model = d_model
         self.d_embed = d_embed
         self.d_head = d_head
@@ -915,21 +920,25 @@ class MemTransformerLM(nn.Module):
         return core_out, new_mems

-    def forward(self, data, target, *mems):
+    def forward(self, data, target=None, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
         # So, have to initialize size(0) mems inside the model forward.
         # Moreover, have to return new_mems to allow nn.DataParallel to piece
         # them together.
         if not mems: mems = self.init_mems()

-        tgt_len = target.size(0)
         hidden, new_mems = self._forward(data, mems=mems)
+        if target is None:
+            if new_mems is None:
+                return [hidden]
+            else:
+                return [hidden] + new_mems

+        tgt_len = target.size(0)
         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
             assert self.tie_weight
             logit = sample_logits(self.word_emb,
                                   self.out_layer.bias, target, pred_hid, self.sampler)
             loss = -F.log_softmax(logit, -1)[:, :, 0]
         else:
             loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
@@ -1010,7 +1019,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         pass

     @classmethod
-    def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
+    def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
                         *inputs, **kwargs):
         """
         Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
@@ -1100,7 +1109,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         for name, child in module._modules.items():
             if child is not None:
                 load(child, prefix + name + '.')
-        load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
+        # load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
         if len(missing_keys) > 0:
             logger.info("Weights of {} not initialized from pretrained model: {}".format(
                 model.__class__.__name__, missing_keys))
@@ -1110,9 +1119,6 @@ class TransfoXLPreTrainedModel(nn.Module):
         if len(error_msgs) > 0:
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                 model.__class__.__name__, "\n\t".join(error_msgs)))
-        # Add additional embeddings for special tokens if needed
-        if num_special_tokens != config.n_special:
-            model.set_num_special_tokens(num_special_tokens)
         if tempdir:
             # Clean up temp dir
             shutil.rmtree(tempdir)
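The new target=None default changes what forward returns: with labels the first element is the loss, without labels it is the hidden states, and in both cases the remaining elements are the updated memories. A minimal sketch of the two call modes, assuming TransfoXLModel exposes the same forward signature as the MemTransformerLM change above (the evaluation script already calls it this way for the loss mode); the token ids are random placeholders:

    import torch
    from pytorch_pretrained_bert import TransfoXLModel

    model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
    model.eval()

    # (seq_len, batch) LongTensors of token ids, as in the evaluation script.
    data = torch.randint(0, 1000, (5, 2), dtype=torch.long)
    target = torch.randint(0, 1000, (5, 2), dtype=torch.long)

    with torch.no_grad():
        # With a target: ret[0] is the loss, ret[1:] are the new memories.
        ret = model(data, target)
        loss, mems = ret[0], ret[1:]

        # Without a target (new in this commit): ret[0] is the last hidden states.
        ret = model(data, None, *mems)
        hidden, mems = ret[0], ret[1:]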
pytorch_pretrained_bert/tokenization_transfo_xl.py
@@ -14,15 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Tokenization classes for Transformer XL model.
-    Directly adapted from https://github.com/kimiyoung/transformer-xl.
+    Adapted from https://github.com/kimiyoung/transformer-xl.
 """

 import os
-import re
-import json
-from tqdm import tqdm
+import glob
 import logging
 import pickle
+import torch

 from collections import Counter, OrderedDict

 from .file_utils import cached_path
@@ -30,16 +29,14 @@ from .file_utils import cached_path
 logger = logging.getLogger(__name__)

 PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
 }
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
-}
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
-    'openai-gpt': 512,
-}
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
+VOCAB_NAME = 'vocab.bin'
+PRETRAINED_CORPUS_ARCHIVE_MAP = {
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
+}
+CORPUS_NAME = 'corpus.bin'

 class TransfoXLTokenizer(object):
     """
@@ -49,43 +46,36 @@ class TransfoXLTokenizer(object):
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a TransfoXLTokenizer.
-        Download and cache the vocabulary if needed.
+        The TransfoXLTokenizer.
         """
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
         except FileNotFoundError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                "We assumed '{}' was a path or url but couldn't find files {} "
                 "at this path or url.".format(
                     pretrained_model_name_or_path,
                     ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                     pretrained_model_name_or_path,
-                    vocab_file, merges_file))
+                    vocab_file))
             return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+        if resolved_vocab_file == vocab_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
         else:
             logger.info("loading vocabulary file {} from cache at {}".format(
                 vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        tokenizer = cls(*inputs, **kwargs)
+        vocab_dict = torch.load(resolved_vocab_file)
+        for key, value in vocab_dict.items():
+            tokenizer.__dict__[key] = value
         return tokenizer

     def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
@@ -418,10 +408,53 @@ class LMMultiFileIterator(LMShuffledIterator):
             yield batch


-class Corpus(object):
-    def __init__(self, path, dataset, *args, **kwargs):
+class TransfoXLCorpus(object):
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+        """
+        Instantiate a pre-processed corpus.
+        """
+        vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
+            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
+        else:
+            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
+        # redirect to the cache, if necessary
+        try:
+            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
+        except FileNotFoundError:
+            logger.error(
+                "Model name '{}' was not found in model name list ({}). "
+                "We assumed '{}' was a path or url but couldn't find files {} "
+                "at this path or url.".format(
+                    pretrained_model_name_or_path,
+                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                    pretrained_model_name_or_path,
+                    corpus_file))
+            return None
+        if resolved_corpus_file == corpus_file:
+            logger.info("loading corpus file {}".format(corpus_file))
+        else:
+            logger.info("loading corpus file {} from cache at {}".format(
+                corpus_file, resolved_corpus_file))
+
+        # Instantiate tokenizer.
+        corpus = cls(*inputs, **kwargs)
+        corpus_dict = torch.load(resolved_corpus_file)
+        for key, value in corpus_dict.items():
+            corpus.__dict__[key] = value
+        corpus.vocab = vocab
+        return corpus
+
+    def __init__(self, *args, **kwargs):
+        self.vocab = TransfoXLTokenizer(*args, **kwargs)
+        self.dataset = None
+        self.train = None
+        self.valid = None
+        self.test = None
+
+    def build_corpus(self, path, dataset):
         self.dataset = dataset
-        self.vocab = Vocab(*args, **kwargs)

         if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
             self.vocab.count_file(os.path.join(path, 'train.txt'))
@@ -502,7 +535,7 @@ def get_lm_corpus(datadir, dataset):
     elif dataset in ['enwik8', 'text8']:
         pass

-    corpus = Corpus(datadir, dataset, **kwargs)
+    corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
     torch.save(corpus, fn)

     return corpus
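The loading path added here is what TransfoXLCorpus.from_pretrained in the new evaluation script relies on: the tokenizer's fields are restored from a torch-serialized vocab.bin and the corpus fields from corpus.bin, both fetched through cached_path. A minimal usage sketch:

    import torch
    from pytorch_pretrained_bert import TransfoXLTokenizer, TransfoXLCorpus

    # Both objects are rebuilt from torch-serialized dictionaries
    # ('vocab.bin' and 'corpus.bin'), downloaded and cached on first use.
    tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
    corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')

    print(len(corpus.vocab))  # vocabulary size, as used by the evaluation script

    # Validation split as (data, target, seq_len) segments, mirroring eval_transfo_xl.py.
    va_iter = corpus.get_iterator('valid', 10, 5, device=torch.device('cpu'), ext_len=0)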