chenpangpang / transformers / Commits / a69ec2c7

Commit a69ec2c7 · authored Jan 15, 2019 by thomwolf

    improved corpus and tokenization conversion - added evaluation script

Parent: 7d03c537

Showing 5 changed files with 349 additions and 124 deletions (+349 / -124):

- examples/eval_transfo_xl.py (+151 / -0)
- pytorch_pretrained_bert/__init__.py (+2 / -0)
- pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py (+105 / -72)
- pytorch_pretrained_bert/modeling_transfo_xl.py (+23 / -17)
- pytorch_pretrained_bert/tokenization_transfo_xl.py (+68 / -35)
examples/eval_transfo_xl.py (new file, 0 → 100644)
```python
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model evaluation script.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
"""
import os
import sys
import functools
import argparse
import time
import math

import torch

from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus

def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')

def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
#                     help='location of the data corpus')
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                    choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
                    help='pretrained model name')
parser.add_argument('--split', type=str, default='all',
                    choices=['all', 'valid', 'test'],
                    help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')
parser.add_argument('--tgt_len', type=int, default=5,
                    help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous heads')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
                    help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
                    help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
                    help='set same length attention with masking')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'
device = torch.device("cuda" if args.cuda else "cpu")

# Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
                     log_=not args.no_log)

# Load dataset
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
ntokens = len(corpus.vocab)

va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)

# Load the best saved model.
# with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
#     model = torch.load(f)
# model.backward_compatible()
model = TransfoXLModel.from_pretrained(args.model_name)
model = model.to(device)

logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
    args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
    model.clamp_len = args.clamp_len
if args.same_length:
    model.same_length = True
###############################################################################
# Evaluation code
###############################################################################
def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_len, total_loss = 0, 0.
    start_time = time.time()
    with torch.no_grad():
        mems = tuple()
        for idx, (data, target, seq_len) in enumerate(eval_iter):
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            total_loss += seq_len * loss.item()
            total_len += seq_len
        total_time = time.time() - start_time
    logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
        total_time, 1000 * total_time / (idx+1)))
    return total_loss / total_len
# Run on test data.
if args.split == 'all':
    test_loss = evaluate(te_iter)
    valid_loss = evaluate(va_iter)
elif args.split == 'valid':
    valid_loss = evaluate(va_iter)
    test_loss = None
elif args.split == 'test':
    test_loss = evaluate(te_iter)
    valid_loss = None
def format_log(loss, split):
    # NOTE: the parser above defines --model_name, not --dataset; checking
    # model_name keeps the bits-per-character branch for character-level corpora
    if args.model_name in ['enwik8', 'text8']:
        log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
            split, loss, loss / math.log(2))
    else:
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
    return log_str
log_str = ''
if valid_loss is not None:
    log_str += format_log(valid_loss, 'valid')
if test_loss is not None:
    log_str += format_log(test_loss, 'test')

logging('=' * 100)
logging(log_str)
logging('=' * 100)
```
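The script can be run as, e.g., `python examples/eval_transfo_xl.py --model_name transfo-xl-wt103 --work_dir ./eval_logs --cuda` (flags as defined above). For reference, a minimal sketch of the same evaluation flow driven directly from Python, assuming the pretrained wt103 model and corpus download successfully:

```python
import math
import torch
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus

# Load the cached WikiText-103 corpus and the pretrained model
corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()

# Validation iterator with the script's defaults (batch_size=10, tgt_len=5)
va_iter = corpus.get_iterator('valid', 10, 5,
                              device=torch.device('cpu'), ext_len=0)

total_loss, total_len = 0., 0
with torch.no_grad():
    mems = tuple()
    for data, target, seq_len in va_iter:
        ret = model(data, target, *mems)       # returns [loss] + new_mems
        loss, mems = ret[0], ret[1:]
        total_loss += seq_len * loss.mean().item()
        total_len += seq_len

print('valid ppl {:9.3f}'.format(math.exp(total_loss / total_len)))
```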
pytorch_pretrained_bert/__init__.py
```diff
 __version__ = "0.5.0"
 from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
 from .tokenization_openai import OpenAIGPTTokenizer
+from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
 from .modeling import (BertConfig, BertModel, BertForPreTraining,
                        BertForMaskedLM, BertForNextSentencePrediction,
                        BertForSequenceClassification, BertForMultipleChoice,
                        BertForTokenClassification, BertForQuestionAnswering)
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
+from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
 from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
```
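The two added imports expose the Transformer-XL classes at the package root; a minimal sketch of the resulting API surface (downloads assume the URLs in the diffs below resolve):

```python
# The Transformer-XL classes are now importable from the package root
from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel,
                                     TransfoXLTokenizer, TransfoXLCorpus)

# e.g. load the pretrained WikiText-103 model and tokenizer by name
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
```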
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
```diff
@@ -12,23 +12,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Convert OpenAI GPT checkpoint."""
+"""Convert Transformer XL checkpoint and datasets."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import os
 import re
 import sys
 import argparse
+import pickle

 import tensorflow as tf
 import torch
 import numpy as np

 from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
+from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
+
+# We do this to be able to load the python 2 datasets pickles
+# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
+import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
+data_utils.Vocab = data_utils.TransfoXLTokenizer
+data_utils.Corpus = data_utils.TransfoXLCorpus
+sys.modules['data_utils'] = data_utils
+sys.modules['vocabulary'] = data_utils
```
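The `sys.modules` aliasing above lets `pickle` resolve classes that were serialized under the original repo's module names (`data_utils`, `vocabulary`). A self-contained illustration of the same trick; the `legacy_module` name here is hypothetical:

```python
import pickle
import sys
import types

class Thing(object):
    pass

# Pretend Thing was originally defined (and pickled) as `legacy_module.Thing`
Thing.__module__ = 'legacy_module'

legacy = types.ModuleType('legacy_module')
legacy.Thing = Thing                     # map the old name onto the new class
sys.modules['legacy_module'] = legacy    # the same trick as above

payload = pickle.dumps(Thing())          # serialized under the old module path
obj = pickle.loads(payload)              # pickle resolves Thing via the alias
assert isinstance(obj, Thing)
```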
```diff
 def build_tf_to_pytorch_map(model, config):
-    """ A map of modules from TF to PyTorch """
+    """ A map of modules from TF to PyTorch.
+        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
+    """
     tf_to_pt_map = {}
     # Embeddings cutoffs
     for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
```
```diff
@@ -95,88 +108,108 @@ def build_tf_to_pytorch_map(model, config):
 def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                              transfo_xl_config_file,
-                                             pytorch_dump_folder_path):
-    config_path = os.path.abspath(transfo_xl_config_file)
-    tf_path = os.path.abspath(tf_checkpoint_path)
-    print("Converting Transformer XL checkpoint from {} with config at {}".format(
-        tf_path, config_path))
-    # Initialise PyTorch model
-    # Construct model
-    if transfo_xl_config_file == "":
-        config = TransfoXLConfig()
-    else:
-        config = TransfoXLConfig(transfo_xl_config_file)
-    print("Building PyTorch model from configuration: {}".format(str(config)))
-    model = TransfoXLModel(config)
-    # Build TF to PyTorch weights loading map
-    tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
-    # Load weights from TF model
-    init_vars = tf.train.list_variables(tf_path)
-    tf_weights = {}
-    for name, shape in init_vars:
-        print("Loading TF weight {} with shape {}".format(name, shape))
-        array = tf.train.load_variable(tf_path, name)
-        tf_weights[name] = array
-    for name, pointer in tf_to_pt_map.items():
-        assert name in tf_weights
-        array = tf_weights[name]
-        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
-        # which are not required for using pretrained model
-        if 'kernel' in name or 'proj_W' in name:
-            array = np.transpose(array)
-        if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
-            # Here we will split the TF weigths
-            assert len(pointer) == array.shape[0]
-            for i, p_i in enumerate(pointer):
-                arr_i = array[i, ...]
-                try:
-                    assert p_i.shape == arr_i.shape
-                except AssertionError as e:
-                    e.args += (p_i.shape, arr_i.shape)
-                    raise
-                print("Initialize PyTorch weight {} for layer {}".format(name, i))
-                p_i.data = torch.from_numpy(arr_i)
-            continue
-        try:
-            assert pointer.shape == array.shape
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
-        print("Initialize PyTorch weight {}".format(name))
-        pointer.data = torch.from_numpy(array)
-    # Save pytorch-model
-    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
-    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
-    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
-    torch.save(model.state_dict(), pytorch_weights_dump_path)
-    print("Save configuration file to {}".format(pytorch_config_dump_path))
-    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
-        f.write(config.to_json_string())
+                                             pytorch_dump_folder_path,
+                                             transfo_xl_dataset_file):
+    if transfo_xl_dataset_file:
+        with open(transfo_xl_dataset_file, "rb") as fp:
+            corpus = pickle.load(fp, encoding="latin1")
+        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
+        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
+        torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)
+        corpus_dict_no_vocab = corpus.__dict__
+        corpus_dict_no_vocab.pop('vocab', None)
+        pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
+        print("Save dataset to {}".format(pytorch_dataset_dump_path))
+        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
+
+    if tf_checkpoint_path:
+        config_path = os.path.abspath(transfo_xl_config_file)
+        tf_path = os.path.abspath(tf_checkpoint_path)
+        print("Converting Transformer XL checkpoint from {} with config at {}".format(
+            tf_path, config_path))
+        # Initialise PyTorch model
+        # Construct model
+        if transfo_xl_config_file == "":
+            config = TransfoXLConfig()
+        else:
+            config = TransfoXLConfig(transfo_xl_config_file)
+        print("Building PyTorch model from configuration: {}".format(str(config)))
+        model = TransfoXLModel(config)
+        # Build TF to PyTorch weights loading map
+        tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
+        # Load weights from TF model
+        init_vars = tf.train.list_variables(tf_path)
+        tf_weights = {}
+        for name, shape in init_vars:
+            print("Loading TF weight {} with shape {}".format(name, shape))
+            array = tf.train.load_variable(tf_path, name)
+            tf_weights[name] = array
+        for name, pointer in tf_to_pt_map.items():
+            assert name in tf_weights
+            array = tf_weights[name]
+            # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+            # which are not required for using pretrained model
+            if 'kernel' in name or 'proj_W' in name:
+                array = np.transpose(array)
+            if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
+                # Here we will split the TF weigths
+                assert len(pointer) == array.shape[0]
+                for i, p_i in enumerate(pointer):
+                    arr_i = array[i, ...]
+                    try:
+                        assert p_i.shape == arr_i.shape
+                    except AssertionError as e:
+                        e.args += (p_i.shape, arr_i.shape)
+                        raise
+                    print("Initialize PyTorch weight {} for layer {}".format(name, i))
+                    p_i.data = torch.from_numpy(arr_i)
+                continue
+            try:
+                assert pointer.shape == array.shape
+            except AssertionError as e:
+                e.args += (pointer.shape, array.shape)
+                raise
+            print("Initialize PyTorch weight {}".format(name))
+            pointer.data = torch.from_numpy(array)
+        # Save pytorch-model
+        pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
+        pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
+        print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
+        torch.save(model.state_dict(), pytorch_weights_dump_path)
+        print("Save configuration file to {}".format(pytorch_config_dump_path))
+        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+            f.write(config.to_json_string())

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ## Required parameters
-    parser.add_argument("--tf_checkpoint_path",
-                        default = None,
-                        type = str,
-                        required = True,
-                        help = "Path the TensorFlow checkpoint path.")
     parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
                         required = True,
-                        help = "Path to the output PyTorch model.")
+                        help = "Path to the folder to store the PyTorch model or dataset/vocab.")
+    parser.add_argument("--tf_checkpoint_path",
+                        default = "",
+                        type = str,
+                        help = "An optional path to a TensorFlow checkpoint path to be converted.")
     parser.add_argument("--transfo_xl_config_file",
                         default = "",
                         type = str,
-                        help = "The config json file corresponding to the pre-trained BERT model. \n"
+                        help = "An optional config json file corresponding to the pre-trained BERT model. \n"
                             "This specifies the model architecture.")
+    parser.add_argument("--transfo_xl_dataset_file",
+                        default = "",
+                        type = str,
+                        help = "An optional dataset file to be converted in a vocabulary.")
     args = parser.parse_args()
     convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                              args.transfo_xl_config_file,
-                                             args.pytorch_dump_folder_path)
+                                             args.pytorch_dump_folder_path,
+                                             args.transfo_xl_dataset_file)
```
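With the reworked entry point, the converter can be pointed at a TensorFlow checkpoint, a pickled corpus, or both, e.g. `python convert_transfo_xl_checkpoint_to_pytorch.py --pytorch_dump_folder_path ./dump --transfo_xl_dataset_file ./cache.pkl`. A minimal sketch of calling it directly from Python (the paths are placeholders):

```python
from pytorch_pretrained_bert.convert_transfo_xl_checkpoint_to_pytorch import (
    convert_transfo_xl_checkpoint_to_pytorch)

# Convert only the pickled dataset/vocabulary (no TF checkpoint, no config file)
convert_transfo_xl_checkpoint_to_pytorch(
    tf_checkpoint_path="",                  # empty -> skip the TF conversion branch
    transfo_xl_config_file="",              # empty -> default TransfoXLConfig()
    pytorch_dump_folder_path="./dump",      # placeholder output folder
    transfo_xl_dataset_file="./cache.pkl")  # placeholder pre-processed corpus pickle
```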
pytorch_pretrained_bert/modeling_transfo_xl.py
```diff
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Transformer XL model.
-    Directly adapted from https://github.com/kimiyoung/transformer-xl.
+    Adapted from https://github.com/kimiyoung/transformer-xl.
     In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
 """
@@ -40,7 +40,7 @@ from .file_utils import cached_path
 logger = logging.getLogger(__name__)

 PRETRAINED_MODEL_ARCHIVE_MAP = {
-    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl.tar.gz",
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103.tar.gz",
 }
 CONFIG_NAME = 'transfo_xl_config.json'
 WEIGHTS_NAME = 'pytorch_model.bin'
@@ -59,12 +59,13 @@ class TransfoXLConfig(object):
                  div_val=4,
                  pre_lnorm=False,
                  n_layer=18,
-                 tgt_len=256,
+                 tgt_len=128,
                  ext_len=0,
-                 mem_len=256,
-                 same_length=False,
+                 mem_len=1600,
+                 clamp_len=1000,
+                 same_length=True,
+                 proj_share_all_but_first=True,
                  attn_type=0,
-                 clamp_len=-1,
                  sample_softmax=-1,
                  adaptive=True,
                  tie_weight=True,
@@ -93,6 +94,7 @@ class TransfoXLConfig(object):
             ext_len: length of the extended context
             mem_len: length of the retained previous heads
             same_length: use the same attn length for all tokens
+            proj_share_all_but_first: True to share all but first projs, False not to share.
             attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
             clamp_len: use the same pos embeddings after clamp_len
             sample_softmax: number of samples in sampled softmax
@@ -118,7 +120,10 @@ class TransfoXLConfig(object):
         self.cutoffs = []
         self.cutoffs.extend(cutoffs)
         self.tie_weight = tie_weight
-        self.tie_projs = [False] + [True] * len(self.cutoffs)
+        if proj_share_all_but_first:
+            self.tie_projs = [False] + [True] * len(self.cutoffs)
+        else:
+            self.tie_projs = [False] + [False] * len(self.cutoffs)
         self.d_model = d_model
         self.d_embed = d_embed
         self.d_head = d_head
```
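To make the new `tie_projs` branch concrete, here is a tiny standalone sketch mirroring the logic above, assuming the wt103-style default `cutoffs=[20000, 40000, 200000]`:

```python
def tie_projs_for(cutoffs, proj_share_all_but_first):
    # mirrors the TransfoXLConfig.__init__ branch shown above
    if proj_share_all_but_first:
        return [False] + [True] * len(cutoffs)
    return [False] + [False] * len(cutoffs)

cutoffs = [20000, 40000, 200000]        # wt103-style cutoffs (assumed)
print(tie_projs_for(cutoffs, True))     # [False, True, True, True]
print(tie_projs_for(cutoffs, False))    # [False, False, False, False]
```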
```diff
@@ -423,7 +428,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # qlen x n_head x d_head

         #### compute attention score
-        rw_head_q = w_head_q + self.r_w_bias                                   # qlen x bsz x n_head x d_head
+        rw_head_q = w_head_q + self.r_w_bias                                    # qlen x bsz x n_head x d_head
         AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head

         rr_head_q = w_head_q + self.r_r_bias
@@ -915,21 +920,25 @@ class MemTransformerLM(nn.Module):
         return core_out, new_mems

-    def forward(self, data, target, *mems):
+    def forward(self, data, target=None, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
         # So, have to initialize size(0) mems inside the model forward.
         # Moreover, have to return new_mems to allow nn.DataParallel to piece
         # them together.
         if not mems: mems = self.init_mems()

-        tgt_len = target.size(0)
         hidden, new_mems = self._forward(data, mems=mems)

+        if target is None:
+            if new_mems is None:
+                return [hidden]
+            else:
+                return [hidden] + new_mems
+
+        tgt_len = target.size(0)
         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
             assert self.tie_weight
             logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
             loss = -F.log_softmax(logit, -1)[:, :, 0]
         else:
             loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
```
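A sketch of what the relaxed `forward` signature permits: called without a target, the model returns hidden states plus updated memories instead of a loss. This assumes the `TransfoXLModel` wrapper forwards an omitted target the same way as `MemTransformerLM.forward` above:

```python
import torch
from pytorch_pretrained_bert import TransfoXLModel

model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()

data = torch.randint(0, 1000, (5, 2))    # dummy [tgt_len, bsz] token ids

with torch.no_grad():
    # target omitted -> the model returns [hidden] + new_mems rather than [loss] + new_mems
    ret = model(data)
    hidden, mems = ret[0], ret[1:]
    # feed the returned memories back in for the next segment
    ret = model(data, None, *mems)
```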
```diff
@@ -1010,7 +1019,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         pass

     @classmethod
-    def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
+    def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
                         *inputs, **kwargs):
         """
         Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
@@ -1100,7 +1109,7 @@ class TransfoXLPreTrainedModel(nn.Module):
             for name, child in module._modules.items():
                 if child is not None:
                     load(child, prefix + name + '.')
-        load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
+        # load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
         if len(missing_keys) > 0:
             logger.info("Weights of {} not initialized from pretrained model: {}".format(
                 model.__class__.__name__, missing_keys))
@@ -1110,9 +1119,6 @@ class TransfoXLPreTrainedModel(nn.Module):
         if len(error_msgs) > 0:
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                 model.__class__.__name__, "\n\t".join(error_msgs)))
-        # Add additional embeddings for special tokens if needed
-        if num_special_tokens != config.n_special:
-            model.set_num_special_tokens(num_special_tokens)
         if tempdir:
             # Clean up temp dir
             shutil.rmtree(tempdir)
```
pytorch_pretrained_bert/tokenization_transfo_xl.py
```diff
@@ -14,15 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Tokenization classes for Transformer XL model.
-    Directly adapted from https://github.com/kimiyoung/transformer-xl.
+    Adapted from https://github.com/kimiyoung/transformer-xl.
 """

 import os
 import re
 import json
 from tqdm import tqdm
 import glob
 import logging
 import pickle

 import torch
 from collections import Counter, OrderedDict

 from .file_utils import cached_path
@@ -30,16 +29,14 @@ from .file_utils import cached_path
 logger = logging.getLogger(__name__)

 PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
 }
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
-}
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
-    'openai-gpt': 512,
-}
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
+VOCAB_NAME = 'vocab.bin'
+
+PRETRAINED_CORPUS_ARCHIVE_MAP = {
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
+}
+CORPUS_NAME = 'corpus.bin'

 class TransfoXLTokenizer(object):
     """
@@ -49,43 +46,36 @@ class TransfoXLTokenizer(object):
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a TransfoXLTokenizer.
-        Download and cache the vocabulary if needed.
+        The TransfoXLTokenizer.
         """
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
         except FileNotFoundError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                "We assumed '{}' was a path or url but couldn't find files {} "
                 "at this path or url.".format(
                     pretrained_model_name_or_path,
                     ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                     pretrained_model_name_or_path,
-                    vocab_file, merges_file))
+                    vocab_file))
             return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+        if resolved_vocab_file == vocab_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
         else:
             logger.info("loading vocabulary file {} from cache at {}".format(
                 vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        tokenizer = cls(*inputs, **kwargs)
+        vocab_dict = torch.load(resolved_vocab_file)
+        for key, value in vocab_dict.items():
+            tokenizer.__dict__[key] = value
         return tokenizer

     def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
```
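A short sketch of the new tokenizer loading path: the pretrained state is now a `torch`-saved dictionary of tokenizer attributes rather than a vocab/merges text pair. The `tokenize`/`convert_to_tensor` helpers are assumed to carry over from the original `Vocab` class:

```python
from pytorch_pretrained_bert import TransfoXLTokenizer

# Downloads and caches transfo-xl-wt103-vocab.bin, then restores the saved
# attribute dictionary onto a freshly constructed tokenizer (see diff above)
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')

tokens = tokenizer.tokenize('the quick brown fox')   # -> list of symbols
ids = tokenizer.convert_to_tensor(tokens)            # -> torch.LongTensor of indices
```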
```diff
@@ -418,10 +408,53 @@ class LMMultiFileIterator(LMShuffledIterator):
             yield batch

-class Corpus(object):
-    def __init__(self, path, dataset, *args, **kwargs):
+class TransfoXLCorpus(object):
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+        """
+        Instantiate a pre-processed corpus.
+        """
+        vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
+            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
+        else:
+            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
+        # redirect to the cache, if necessary
+        try:
+            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
+        except FileNotFoundError:
+            logger.error(
+                "Model name '{}' was not found in model name list ({}). "
+                "We assumed '{}' was a path or url but couldn't find files {} "
+                "at this path or url.".format(
+                    pretrained_model_name_or_path,
+                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                    pretrained_model_name_or_path,
+                    corpus_file))
+            return None
+        if resolved_corpus_file == corpus_file:
+            logger.info("loading corpus file {}".format(corpus_file))
+        else:
+            logger.info("loading corpus file {} from cache at {}".format(
+                corpus_file, resolved_corpus_file))
+        # Instantiate tokenizer.
+        corpus = cls(*inputs, **kwargs)
+        corpus_dict = torch.load(resolved_corpus_file)
+        for key, value in corpus_dict.items():
+            corpus.__dict__[key] = value
+        corpus.vocab = vocab
+        return corpus
+
+    def __init__(self, *args, **kwargs):
+        self.vocab = TransfoXLTokenizer(*args, **kwargs)
+        self.dataset = None
+        self.train = None
+        self.valid = None
+        self.test = None
+
+    def build_corpus(self, path, dataset):
         self.dataset = dataset
-        self.vocab = Vocab(*args, **kwargs)

         if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
             self.vocab.count_file(os.path.join(path, 'train.txt'))
@@ -443,20 +476,20 @@ class Corpus(object):
                 os.path.join(path, 'train.txt'), ordered=True)
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=True)
-            self.test  = self.vocab.encode_file(
+            self.test = self.vocab.encode_file(
                 os.path.join(path, 'test.txt'), ordered=True)
         elif self.dataset in ['enwik8', 'text8']:
             self.train = self.vocab.encode_file(
                 os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
-            self.test  = self.vocab.encode_file(
+            self.test = self.vocab.encode_file(
                 os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
         elif self.dataset == 'lm1b':
             self.train = train_paths
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
-            self.test  = self.vocab.encode_file(
+            self.test = self.vocab.encode_file(
                 os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)

     def get_iterator(self, split, *args, **kwargs):
```
```diff
@@ -502,7 +535,7 @@ def get_lm_corpus(datadir, dataset):
         elif dataset in ['enwik8', 'text8']:
             pass

-        corpus = Corpus(datadir, dataset, **kwargs)
+        corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
         torch.save(corpus, fn)

     return corpus
```
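Putting the tokenization pieces together, a condensed sketch of how the evaluation script at the top of this commit consumes the new corpus class:

```python
import torch
from pytorch_pretrained_bert import TransfoXLCorpus

# Download the cached WikiText-103 corpus together with its vocabulary
corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
print(len(corpus.vocab))   # vocabulary size of the pretrained corpus

# Ordered evaluation iterator, as used by examples/eval_transfo_xl.py
va_iter = corpus.get_iterator('valid', 10, 5,
                              device=torch.device('cpu'), ext_len=0)
data, target, seq_len = next(iter(va_iter))   # one [tgt_len, bsz] segment
```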