Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
448937c0
Commit
448937c0
authored
Feb 06, 2019
by
thomwolf
Browse files
python 2 compatibility
parent
ba37ddc5
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
245 additions
and
183 deletions
+245
-183
examples/eval_transfo_xl.py
examples/eval_transfo_xl.py
+16
-16
examples/run_classifier.py
examples/run_classifier.py
+16
-12
examples/run_lm_finetuning.py
examples/run_lm_finetuning.py
+21
-22
examples/run_squad.py
examples/run_squad.py
+19
-11
examples/run_swag.py
examples/run_swag.py
+32
-24
pytorch_pretrained_bert/__main__.py
pytorch_pretrained_bert/__main__.py
+2
-2
pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py
...h_pretrained_bert/convert_openai_checkpoint_to_pytorch.py
+8
-4
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
...etrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
+16
-10
pytorch_pretrained_bert/file_utils.py
pytorch_pretrained_bert/file_utils.py
+41
-40
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+18
-13
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+7
-3
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+12
-7
pytorch_pretrained_bert/tokenization.py
pytorch_pretrained_bert/tokenization.py
+5
-6
pytorch_pretrained_bert/tokenization_openai.py
pytorch_pretrained_bert/tokenization_openai.py
+11
-5
pytorch_pretrained_bert/tokenization_transfo_xl.py
pytorch_pretrained_bert/tokenization_transfo_xl.py
+16
-5
setup.py
setup.py
+2
-1
tests/tokenization_test.py
tests/tokenization_test.py
+3
-2
No files found.
examples/eval_transfo_xl.py
View file @
448937c0
...
...
@@ -17,26 +17,26 @@
Adapted from https://github.com/kimiyoung/transformer-xl.
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
"""
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
os
import
sys
import
functools
import
argparse
import
logging
import
time
import
math
import
sys
from
io
import
open
import
torch
from
pytorch_pretrained_bert
import
TransfoXLModel
,
TransfoXLCorpus
def
logging
(
s
,
log_path
,
print_
=
True
,
log_
=
True
):
if
print_
:
print
(
s
)
if
log_
:
with
open
(
log_path
,
'a+'
)
as
f_log
:
f_log
.
write
(
s
+
'
\n
'
)
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
level
=
logging
.
INFO
)
logger
=
logging
.
getLogger
(
__name__
)
def
get_logger
(
log_path
,
**
kwargs
):
return
functools
.
partial
(
logging
,
log_path
=
log_path
,
**
kwargs
)
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch Transformer Language Model'
)
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
...
...
@@ -71,8 +71,8 @@ assert args.ext_len >= 0, 'extended context length must be non-negative'
device
=
torch
.
device
(
"cuda"
if
args
.
cuda
else
"cpu"
)
# Get logger
logging
=
get_logger
(
os
.
path
.
join
(
args
.
work_dir
,
'log.txt'
),
log_
=
not
args
.
no_log
)
#
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
#
log_=not args.no_log)
# Load dataset
corpus
=
TransfoXLCorpus
.
from_pretrained
(
args
.
model_name
)
...
...
@@ -90,7 +90,7 @@ te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
model
=
TransfoXLModel
.
from_pretrained
(
args
.
model_name
)
model
=
model
.
to
(
device
)
logg
ing
(
'Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'
.
format
(
logg
er
.
info
(
'Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'
.
format
(
args
.
batch_size
,
args
.
tgt_len
,
args
.
ext_len
,
args
.
mem_len
,
args
.
clamp_len
))
model
.
reset_length
(
args
.
tgt_len
,
args
.
ext_len
,
args
.
mem_len
)
...
...
@@ -116,7 +116,7 @@ def evaluate(eval_iter):
total_loss
+=
seq_len
*
loss
.
item
()
total_len
+=
seq_len
total_time
=
time
.
time
()
-
start_time
logg
ing
(
'Time : {:.2f}s, {:.2f}ms/segment'
.
format
(
logg
er
.
info
(
'Time : {:.2f}s, {:.2f}ms/segment'
.
format
(
total_time
,
1000
*
total_time
/
(
idx
+
1
)))
return
total_loss
/
total_len
...
...
@@ -146,6 +146,6 @@ if valid_loss is not None:
if
test_loss
is
not
None
:
log_str
+=
format_log
(
test_loss
,
'test'
)
logg
ing
(
'='
*
100
)
logg
ing
(
log_str
)
logg
ing
(
'='
*
100
)
logg
er
.
info
(
'='
*
100
)
logg
er
.
info
(
log_str
)
logg
er
.
info
(
'='
*
100
)
examples/run_classifier.py
View file @
448937c0
...
...
@@ -15,26 +15,27 @@
# limitations under the License.
"""BERT finetuning runner."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
import
csv
import
os
import
logging
import
argparse
import
os
import
random
from
tqdm
import
tqdm
,
trange
import
sys
from
io
import
open
import
numpy
as
np
import
torch
from
torch.utils.data
import
TensorDataset
,
DataLoader
,
RandomSampler
,
SequentialSampler
from
torch.utils.data
import
(
DataLoader
,
RandomSampler
,
SequentialSampler
,
TensorDataset
)
from
torch.utils.data.distributed
import
DistributedSampler
from
tqdm
import
tqdm
,
trange
from
pytorch_pretrained_bert.
tokenization
import
BertTokenizer
from
pytorch_pretrained_bert.
file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
from
pytorch_pretrained_bert.modeling
import
BertForSequenceClassification
from
pytorch_pretrained_bert.optimization
import
BertAdam
from
pytorch_pretrained_bert.
file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
from
pytorch_pretrained_bert.
tokenization
import
BertTokenizer
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
...
...
@@ -91,10 +92,12 @@ class DataProcessor(object):
@
classmethod
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
with
open
(
input_file
,
"r"
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
input_file
,
"r
b
"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
)
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
'utf-8'
)
for
cell
in
line
)
lines
.
append
(
line
)
return
lines
...
...
@@ -429,7 +432,8 @@ def main():
if
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
)
and
args
.
do_train
:
raise
ValueError
(
"Output directory ({}) already exists and is not empty."
.
format
(
args
.
output_dir
))
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
if
not
os
.
path
.
exists
(
args
.
output_dir
):
os
.
makedirs
(
args
.
output_dir
)
task_name
=
args
.
task_name
.
lower
()
...
...
@@ -451,7 +455,7 @@ def main():
# Prepare model
model
=
BertForSequenceClassification
.
from_pretrained
(
args
.
bert_model
,
cache_dir
=
PYTORCH_PRETRAINED_BERT_CACHE
/
'distributed_{}'
.
format
(
args
.
local_rank
),
cache_dir
=
os
.
path
.
join
(
PYTORCH_PRETRAINED_BERT_CACHE
,
'distributed_{}'
.
format
(
args
.
local_rank
)
)
,
num_labels
=
num_labels
)
if
args
.
fp16
:
model
.
half
()
...
...
examples/run_lm_finetuning.py
View file @
448937c0
...
...
@@ -15,26 +15,23 @@
# limitations under the License.
"""BERT finetuning runner."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
os
import
logging
import
argparse
from
tqdm
import
tqdm
,
trange
import
logging
import
os
import
random
from
io
import
open
import
numpy
as
np
import
torch
from
torch.utils.data
import
DataLoader
,
RandomSampler
from
torch.utils.data
import
DataLoader
,
Dataset
,
RandomSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
tqdm
import
tqdm
,
trange
from
pytorch_pretrained_bert.tokenization
import
BertTokenizer
from
pytorch_pretrained_bert.modeling
import
BertForPreTraining
from
pytorch_pretrained_bert.optimization
import
BertAdam
from
torch.utils.data
import
Dataset
import
random
from
pytorch_pretrained_bert.tokenization
import
BertTokenizer
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
...
...
@@ -185,16 +182,16 @@ class BERTDataset(Dataset):
if
self
.
line_buffer
is
None
:
# read first non-empty line of file
while
t1
==
""
:
t1
=
self
.
file
.
__next__
(
).
strip
()
t2
=
self
.
file
.
__next__
(
).
strip
()
t1
=
next
(
self
.
file
).
strip
()
t2
=
next
(
self
.
file
).
strip
()
else
:
# use t2 from previous iteration as new t1
t1
=
self
.
line_buffer
t2
=
self
.
file
.
__next__
(
).
strip
()
t2
=
next
(
self
.
file
).
strip
()
# skip empty rows that are used for separating documents and keep track of current doc id
while
t2
==
""
or
t1
==
""
:
t1
=
self
.
file
.
__next__
(
).
strip
()
t2
=
self
.
file
.
__next__
(
).
strip
()
t1
=
next
(
self
.
file
).
strip
()
t2
=
next
(
self
.
file
).
strip
()
self
.
current_doc
=
self
.
current_doc
+
1
self
.
line_buffer
=
t2
...
...
@@ -228,15 +225,15 @@ class BERTDataset(Dataset):
def
get_next_line
(
self
):
""" Gets next line of random_file and starts over when reaching end of file"""
try
:
line
=
self
.
random_file
.
__next__
(
).
strip
()
line
=
next
(
self
.
random_file
).
strip
()
#keep track of which document we are currently looking at to later avoid having the same doc as t1
if
line
==
""
:
self
.
current_random_doc
=
self
.
current_random_doc
+
1
line
=
self
.
random_file
.
__next__
(
).
strip
()
line
=
next
(
self
.
random_file
).
strip
()
except
StopIteration
:
self
.
random_file
.
close
()
self
.
random_file
=
open
(
self
.
corpus_path
,
"r"
,
encoding
=
self
.
encoding
)
line
=
self
.
random_file
.
__next__
(
).
strip
()
line
=
next
(
self
.
random_file
).
strip
()
return
line
...
...
@@ -425,6 +422,7 @@ def main():
help
=
"The output directory where the model checkpoints will be written."
)
## Other parameters
parser
.
add_argument
(
"--do_lower_case"
,
action
=
'store_true'
,
help
=
"Set this flag if you are using an uncased model."
)
parser
.
add_argument
(
"--max_seq_length"
,
default
=
128
,
type
=
int
,
...
...
@@ -513,7 +511,8 @@ def main():
if
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
):
raise
ValueError
(
"Output directory ({}) already exists and is not empty."
.
format
(
args
.
output_dir
))
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
if
not
os
.
path
.
exists
(
args
.
output_dir
):
os
.
makedirs
(
args
.
output_dir
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
args
.
bert_model
,
do_lower_case
=
args
.
do_lower_case
)
...
...
@@ -579,7 +578,7 @@ def main():
if
args
.
local_rank
==
-
1
:
train_sampler
=
RandomSampler
(
train_dataset
)
else
:
#TODO: check if this works with current data generator from disk that relies on
file.__next__
#TODO: check if this works with current data generator from disk that relies on
next(file)
# (it doesn't return item back by index)
train_sampler
=
DistributedSampler
(
train_dataset
)
train_dataloader
=
DataLoader
(
train_dataset
,
sampler
=
train_sampler
,
batch_size
=
args
.
train_batch_size
)
...
...
@@ -643,4 +642,4 @@ def accuracy(out, labels):
if
__name__
==
"__main__"
:
main
()
\ No newline at end of file
main
()
examples/run_squad.py
View file @
448937c0
...
...
@@ -15,29 +15,36 @@
# limitations under the License.
"""Run BERT on SQuAD."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
import
collections
import
logging
import
json
import
logging
import
math
import
os
import
random
import
pickle
from
tqdm
import
tqdm
,
trange
import
sys
from
io
import
open
import
numpy
as
np
import
torch
from
torch.utils.data
import
TensorDataset
,
DataLoader
,
RandomSampler
,
SequentialSampler
from
torch.utils.data
import
(
DataLoader
,
RandomSampler
,
SequentialSampler
,
TensorDataset
)
from
torch.utils.data.distributed
import
DistributedSampler
from
tqdm
import
tqdm
,
trange
from
pytorch_pretrained_bert.
tokenization
import
whitespace_tokenize
,
BasicTokenizer
,
BertTokenizer
from
pytorch_pretrained_bert.
file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
from
pytorch_pretrained_bert.modeling
import
BertForQuestionAnswering
from
pytorch_pretrained_bert.optimization
import
BertAdam
from
pytorch_pretrained_bert.file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
from
pytorch_pretrained_bert.tokenization
import
(
BasicTokenizer
,
BertTokenizer
,
whitespace_tokenize
)
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
else
:
import
pickle
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
...
...
@@ -784,7 +791,8 @@ def main():
if
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
):
raise
ValueError
(
"Output directory () already exists and is not empty."
)
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
if
not
os
.
path
.
exists
(
args
.
output_dir
):
os
.
makedirs
(
args
.
output_dir
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
args
.
bert_model
,
do_lower_case
=
args
.
do_lower_case
)
...
...
@@ -798,7 +806,7 @@ def main():
# Prepare model
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
,
cache_dir
=
PYTORCH_PRETRAINED_BERT_CACHE
/
'distributed_{}'
.
format
(
args
.
local_rank
))
cache_dir
=
os
.
path
.
join
(
PYTORCH_PRETRAINED_BERT_CACHE
,
'distributed_{}'
.
format
(
args
.
local_rank
))
)
if
args
.
fp16
:
model
.
half
()
...
...
examples/run_swag.py
View file @
448937c0
...
...
@@ -15,22 +15,25 @@
# limitations under the License.
"""BERT finetuning runner."""
import
argparse
import
csv
import
logging
import
os
import
argparse
import
random
from
tqdm
import
tqdm
,
trange
import
csv
import
sys
from
io
import
open
import
numpy
as
np
import
torch
from
torch.utils.data
import
TensorDataset
,
DataLoader
,
RandomSampler
,
SequentialSampler
from
torch.utils.data
import
(
DataLoader
,
RandomSampler
,
SequentialSampler
,
TensorDataset
)
from
torch.utils.data.distributed
import
DistributedSampler
from
tqdm
import
tqdm
,
trange
from
pytorch_pretrained_bert.
tokenization
import
BertTokenizer
from
pytorch_pretrained_bert.
file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
from
pytorch_pretrained_bert.modeling
import
BertForMultipleChoice
from
pytorch_pretrained_bert.optimization
import
BertAdam
from
pytorch_pretrained_bert.
file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
from
pytorch_pretrained_bert.
tokenization
import
BertTokenizer
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
...
...
@@ -65,17 +68,17 @@ class SwagExample(object):
def
__repr__
(
self
):
l
=
[
f
"swag_id:
{
self
.
swag_id
}
"
,
f
"context_sentence:
{
self
.
context_sentence
}
"
,
f
"start_ending:
{
self
.
start_ending
}
"
,
f
"ending_0:
{
self
.
endings
[
0
]
}
"
,
f
"ending_1:
{
self
.
endings
[
1
]
}
"
,
f
"ending_2:
{
self
.
endings
[
2
]
}
"
,
f
"ending_3:
{
self
.
endings
[
3
]
}
"
,
"swag_id: {
}"
.
format
(
self
.
swag_id
)
,
"context_sentence: {
}"
.
format
(
self
.
context_sentence
)
,
"start_ending: {
}"
.
format
(
self
.
start_ending
)
,
"ending_0: {
}"
.
format
(
self
.
endings
[
0
]
)
,
"ending_1: {
}"
.
format
(
self
.
endings
[
1
]
)
,
"ending_2: {
}"
.
format
(
self
.
endings
[
2
]
)
,
"ending_3: {
}"
.
format
(
self
.
endings
[
3
]
)
,
]
if
self
.
label
is
not
None
:
l
.
append
(
f
"label:
{
self
.
label
}
"
)
l
.
append
(
"label: {
}"
.
format
(
self
.
label
)
)
return
", "
.
join
(
l
)
...
...
@@ -102,7 +105,11 @@ class InputFeatures(object):
def
read_swag_examples
(
input_file
,
is_training
):
with
open
(
input_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
reader
=
csv
.
reader
(
f
)
lines
=
list
(
reader
)
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
'utf-8'
)
for
cell
in
line
)
lines
.
append
(
line
)
if
is_training
and
lines
[
0
][
-
1
]
!=
'label'
:
raise
ValueError
(
...
...
@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
label
=
example
.
label
if
example_index
<
5
:
logger
.
info
(
"*** Example ***"
)
logger
.
info
(
f
"swag_id:
{
example
.
swag_id
}
"
)
logger
.
info
(
"swag_id: {
}"
.
format
(
example
.
swag_id
)
)
for
choice_idx
,
(
tokens
,
input_ids
,
input_mask
,
segment_ids
)
in
enumerate
(
choices_features
):
logger
.
info
(
f
"choice:
{
choice_idx
}
"
)
logger
.
info
(
f
"tokens:
{
' '
.
join
(
tokens
)
}
"
)
logger
.
info
(
f
"input_ids:
{
' '
.
join
(
map
(
str
,
input_ids
))
}
"
)
logger
.
info
(
f
"input_mask:
{
' '
.
join
(
map
(
str
,
input_mask
))
}
"
)
logger
.
info
(
f
"segment_ids:
{
' '
.
join
(
map
(
str
,
segment_ids
))
}
"
)
logger
.
info
(
"choice: {
}"
.
format
(
choice_idx
)
)
logger
.
info
(
"tokens: {
}"
.
format
(
' '
.
join
(
tokens
)
)
)
logger
.
info
(
"input_ids: {
}"
.
format
(
' '
.
join
(
map
(
str
,
input_ids
))
)
)
logger
.
info
(
"input_mask: {
}"
.
format
(
' '
.
join
(
map
(
str
,
input_mask
))
)
)
logger
.
info
(
"segment_ids: {
}"
.
format
(
' '
.
join
(
map
(
str
,
segment_ids
))
)
)
if
is_training
:
logger
.
info
(
f
"label:
{
label
}
"
)
logger
.
info
(
"label: {
}"
.
format
(
label
)
)
features
.
append
(
InputFeatures
(
...
...
@@ -349,7 +356,8 @@ def main():
if
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
):
raise
ValueError
(
"Output directory ({}) already exists and is not empty."
.
format
(
args
.
output_dir
))
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
if
not
os
.
path
.
exists
(
args
.
output_dir
):
os
.
makedirs
(
args
.
output_dir
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
args
.
bert_model
,
do_lower_case
=
args
.
do_lower_case
)
...
...
@@ -362,7 +370,7 @@ def main():
# Prepare model
model
=
BertForMultipleChoice
.
from_pretrained
(
args
.
bert_model
,
cache_dir
=
PYTORCH_PRETRAINED_BERT_CACHE
/
'distributed_{}'
.
format
(
args
.
local_rank
),
cache_dir
=
os
.
path
.
join
(
PYTORCH_PRETRAINED_BERT_CACHE
,
'distributed_{}'
.
format
(
args
.
local_rank
)
)
,
num_choices
=
4
)
if
args
.
fp16
:
model
.
half
()
...
...
pytorch_pretrained_bert/__main__.py
View file @
448937c0
...
...
@@ -15,7 +15,7 @@ def main():
if
sys
.
argv
[
1
]
==
"convert_tf_checkpoint_to_pytorch"
:
try
:
import
tensorflow
as
tf
except
ModuleNotFound
Error
:
except
Import
Error
:
print
(
"pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
...
...
@@ -43,7 +43,7 @@ def main():
else
:
try
:
import
tensorflow
as
tf
except
ModuleNotFound
Error
:
except
Import
Error
:
print
(
"pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
...
...
pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py
View file @
448937c0
...
...
@@ -14,14 +14,18 @@
# limitations under the License.
"""Convert OpenAI GPT checkpoint."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
from
io
import
open
import
torch
from
pytorch_pretrained_bert.modeling_openai
import
load_tf_weights_in_openai_gpt
,
OpenAIGPTConfig
,
OpenAIGPTModel
,
CONFIG_NAME
,
WEIGHTS_NAME
from
pytorch_pretrained_bert.modeling_openai
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
OpenAIGPTConfig
,
OpenAIGPTModel
,
load_tf_weights_in_openai_gpt
)
def
convert_openai_checkpoint_to_pytorch
(
openai_checkpoint_folder_path
,
openai_config_file
,
pytorch_dump_folder_path
):
# Construct model
...
...
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
View file @
448937c0
...
...
@@ -14,25 +14,31 @@
# limitations under the License.
"""Convert Transformer XL checkpoint and datasets."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
import
os
import
sys
import
argparse
import
pickle
from
io
import
open
import
tensorflow
as
tf
import
torch
import
numpy
as
np
from
pytorch_pretrained_bert.modeling_transfo_xl
import
TransfoXLConfig
,
TransfoXLModel
,
CONFIG_NAME
,
WEIGHTS_NAME
,
load_tf_weights_in_transfo_xl
from
pytorch_pretrained_bert.tokenization_transfo_xl
import
VOCAB_NAME
,
CORPUS_NAME
import
pytorch_pretrained_bert.tokenization_transfo_xl
as
data_utils
from
pytorch_pretrained_bert.modeling_transfo_xl
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
TransfoXLConfig
,
TransfoXLModel
,
load_tf_weights_in_transfo_xl
)
from
pytorch_pretrained_bert.tokenization_transfo_xl
import
(
CORPUS_NAME
,
VOCAB_NAME
)
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
else
:
import
pickle
# We do this to be able to load the python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
import
pytorch_pretrained_bert.tokenization_transfo_xl
as
data_utils
data_utils
.
Vocab
=
data_utils
.
TransfoXLTokenizer
data_utils
.
Corpus
=
data_utils
.
TransfoXLCorpus
sys
.
modules
[
'data_utils'
]
=
data_utils
...
...
pytorch_pretrained_bert/file_utils.py
View file @
448937c0
...
...
@@ -3,31 +3,39 @@ Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from
__future__
import
(
absolute_import
,
division
,
print_function
,
unicode_literals
)
import
os
import
json
import
logging
import
os
import
shutil
import
tempfile
import
json
from
urllib.parse
import
urlparse
from
pathlib
import
Path
from
typing
import
Optional
,
Tuple
,
Union
,
IO
,
Callable
,
Set
from
hashlib
import
sha256
from
functools
import
wraps
from
tqdm
import
tqdm
from
hashlib
import
sha256
from
io
import
open
import
boto3
from
botocore.exceptions
import
ClientError
import
requests
from
botocore.exceptions
import
ClientError
from
tqdm
import
tqdm
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
try
:
from
urllib.parse
import
urlparse
except
ImportError
:
from
urlparse
import
urlparse
try
:
from
pathlib
import
Path
PYTORCH_PRETRAINED_BERT_CACHE
=
Path
(
os
.
getenv
(
'PYTORCH_PRETRAINED_BERT_CACHE'
,
Path
.
home
()
/
'.pytorch_pretrained_bert'
))
except
ImportError
:
PYTORCH_PRETRAINED_BERT_CACHE
=
os
.
getenv
(
'PYTORCH_PRETRAINED_BERT_CACHE'
,
os
.
path
.
join
(
os
.
path
.
expanduser
(
"~"
),
'.pytorch_pretrained_bert'
))
PYTORCH_PRETRAINED_BERT_CACHE
=
Path
(
os
.
getenv
(
'PYTORCH_PRETRAINED_BERT_CACHE'
,
Path
.
home
()
/
'.pytorch_pretrained_bert'
))
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
def
url_to_filename
(
url
:
str
,
etag
:
str
=
None
)
->
str
:
def
url_to_filename
(
url
,
etag
=
None
)
:
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
...
...
@@ -45,25 +53,23 @@ def url_to_filename(url: str, etag: str = None) -> str:
return
filename
def
filename_to_url
(
filename
:
str
,
cache_dir
:
Union
[
str
,
Path
]
=
None
)
->
Tuple
[
str
,
str
]
:
def
filename_to_url
(
filename
,
cache_dir
=
None
)
:
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``
FileNotFound
Error`` if `filename` or its stored metadata do not exist.
Raise ``
Environment
Error`` if `filename` or its stored metadata do not exist.
"""
if
cache_dir
is
None
:
cache_dir
=
PYTORCH_PRETRAINED_BERT_CACHE
if
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
cache_path
=
os
.
path
.
join
(
cache_dir
,
filename
)
if
not
os
.
path
.
exists
(
cache_path
):
raise
FileNotFound
Error
(
"file {} not found"
.
format
(
cache_path
))
raise
Environment
Error
(
"file {} not found"
.
format
(
cache_path
))
meta_path
=
cache_path
+
'.json'
if
not
os
.
path
.
exists
(
meta_path
):
raise
FileNotFound
Error
(
"file {} not found"
.
format
(
meta_path
))
raise
Environment
Error
(
"file {} not found"
.
format
(
meta_path
))
with
open
(
meta_path
)
as
meta_file
:
with
open
(
meta_path
,
encoding
=
"utf-8"
)
as
meta_file
:
metadata
=
json
.
load
(
meta_file
)
url
=
metadata
[
'url'
]
etag
=
metadata
[
'etag'
]
...
...
@@ -71,7 +77,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[
return
url
,
etag
def
cached_path
(
url_or_filename
:
Union
[
str
,
Path
],
cache_dir
:
Union
[
str
,
Path
]
=
None
)
->
str
:
def
cached_path
(
url_or_filename
,
cache_dir
=
None
)
:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
...
...
@@ -80,10 +86,6 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
"""
if
cache_dir
is
None
:
cache_dir
=
PYTORCH_PRETRAINED_BERT_CACHE
if
isinstance
(
url_or_filename
,
Path
):
url_or_filename
=
str
(
url_or_filename
)
if
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
parsed
=
urlparse
(
url_or_filename
)
...
...
@@ -95,13 +97,13 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
return
url_or_filename
elif
parsed
.
scheme
==
''
:
# File, but it doesn't exist.
raise
FileNotFound
Error
(
"file {} not found"
.
format
(
url_or_filename
))
raise
Environment
Error
(
"file {} not found"
.
format
(
url_or_filename
))
else
:
# Something unknown
raise
ValueError
(
"unable to parse {} as a URL or as a local path"
.
format
(
url_or_filename
))
def
split_s3_path
(
url
:
str
)
->
Tuple
[
str
,
str
]
:
def
split_s3_path
(
url
)
:
"""Split a full s3 path into the bucket name and path."""
parsed
=
urlparse
(
url
)
if
not
parsed
.
netloc
or
not
parsed
.
path
:
...
...
@@ -114,19 +116,19 @@ def split_s3_path(url: str) -> Tuple[str, str]:
return
bucket_name
,
s3_path
def
s3_request
(
func
:
Callable
):
def
s3_request
(
func
):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@
wraps
(
func
)
def
wrapper
(
url
:
str
,
*
args
,
**
kwargs
):
def
wrapper
(
url
,
*
args
,
**
kwargs
):
try
:
return
func
(
url
,
*
args
,
**
kwargs
)
except
ClientError
as
exc
:
if
int
(
exc
.
response
[
"Error"
][
"Code"
])
==
404
:
raise
FileNotFound
Error
(
"file {} not found"
.
format
(
url
))
raise
Environment
Error
(
"file {} not found"
.
format
(
url
))
else
:
raise
...
...
@@ -134,7 +136,7 @@ def s3_request(func: Callable):
@
s3_request
def
s3_etag
(
url
:
str
)
->
Optional
[
str
]
:
def
s3_etag
(
url
)
:
"""Check ETag on S3 object."""
s3_resource
=
boto3
.
resource
(
"s3"
)
bucket_name
,
s3_path
=
split_s3_path
(
url
)
...
...
@@ -143,14 +145,14 @@ def s3_etag(url: str) -> Optional[str]:
@
s3_request
def
s3_get
(
url
:
str
,
temp_file
:
IO
)
->
None
:
def
s3_get
(
url
,
temp_file
)
:
"""Pull a file directly from S3."""
s3_resource
=
boto3
.
resource
(
"s3"
)
bucket_name
,
s3_path
=
split_s3_path
(
url
)
s3_resource
.
Bucket
(
bucket_name
).
download_fileobj
(
s3_path
,
temp_file
)
def
http_get
(
url
:
str
,
temp_file
:
IO
)
->
None
:
def
http_get
(
url
,
temp_file
)
:
req
=
requests
.
get
(
url
,
stream
=
True
)
content_length
=
req
.
headers
.
get
(
'Content-Length'
)
total
=
int
(
content_length
)
if
content_length
is
not
None
else
None
...
...
@@ -162,17 +164,16 @@ def http_get(url: str, temp_file: IO) -> None:
progress
.
close
()
def
get_from_cache
(
url
:
str
,
cache_dir
:
Union
[
str
,
Path
]
=
None
)
->
str
:
def
get_from_cache
(
url
,
cache_dir
=
None
)
:
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if
cache_dir
is
None
:
cache_dir
=
PYTORCH_PRETRAINED_BERT_CACHE
if
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
if
not
os
.
path
.
exists
(
cache_dir
):
os
.
makedirs
(
cache_dir
)
# Get eTag to add to filename, if it exists.
if
url
.
startswith
(
"s3://"
):
...
...
@@ -213,7 +214,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
logger
.
info
(
"creating metadata file for %s"
,
cache_path
)
meta
=
{
'url'
:
url
,
'etag'
:
etag
}
meta_path
=
cache_path
+
'.json'
with
open
(
meta_path
,
'w'
)
as
meta_file
:
with
open
(
meta_path
,
'w'
,
encoding
=
"utf-8"
)
as
meta_file
:
json
.
dump
(
meta
,
meta_file
)
logger
.
info
(
"removing temp file %s"
,
temp_file
.
name
)
...
...
@@ -221,7 +222,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
return
cache_path
def
read_set_from_file
(
filename
:
str
)
->
Set
[
str
]
:
def
read_set_from_file
(
filename
)
:
'''
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
...
...
@@ -233,7 +234,7 @@ def read_set_from_file(filename: str) -> Set[str]:
return
collection
def
get_file_extension
(
path
:
str
,
dot
=
True
,
lower
:
bool
=
True
):
def
get_file_extension
(
path
,
dot
=
True
,
lower
=
True
):
ext
=
os
.
path
.
splitext
(
path
)[
1
]
ext
=
ext
if
dot
else
ext
[
1
:]
return
ext
.
lower
()
if
lower
else
ext
pytorch_pretrained_bert/modeling.py
View file @
448937c0
...
...
@@ -15,18 +15,18 @@
# limitations under the License.
"""PyTorch BERT model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
os
import
copy
import
json
import
math
import
logging
import
math
import
os
import
shutil
import
tarfile
import
tempfile
import
shutil
import
sys
from
io
import
open
import
torch
from
torch
import
nn
...
...
@@ -56,7 +56,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
import
re
import
numpy
as
np
import
tensorflow
as
tf
except
ModuleNotFound
Error
:
except
Import
Error
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
...
...
@@ -164,7 +164,8 @@ class BertConfig(object):
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if
isinstance
(
vocab_size_or_config_json_file
,
str
):
if
isinstance
(
vocab_size_or_config_json_file
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
vocab_size_or_config_json_file
,
unicode
)):
with
open
(
vocab_size_or_config_json_file
,
"r"
,
encoding
=
'utf-8'
)
as
reader
:
json_config
=
json
.
loads
(
reader
.
read
())
for
key
,
value
in
json_config
.
items
():
...
...
@@ -343,8 +344,10 @@ class BertIntermediate(nn.Module):
def
__init__
(
self
,
config
):
super
(
BertIntermediate
,
self
).
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
intermediate_act_fn
=
ACT2FN
[
config
.
hidden_act
]
\
if
isinstance
(
config
.
hidden_act
,
str
)
else
config
.
hidden_act
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)):
self
.
intermediate_act_fn
=
ACT2FN
[
config
.
hidden_act
]
else
:
self
.
intermediate_act_fn
=
config
.
hidden_act
def
forward
(
self
,
hidden_states
):
hidden_states
=
self
.
dense
(
hidden_states
)
...
...
@@ -416,8 +419,10 @@ class BertPredictionHeadTransform(nn.Module):
def
__init__
(
self
,
config
):
super
(
BertPredictionHeadTransform
,
self
).
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
transform_act_fn
=
ACT2FN
[
config
.
hidden_act
]
\
if
isinstance
(
config
.
hidden_act
,
str
)
else
config
.
hidden_act
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)):
self
.
transform_act_fn
=
ACT2FN
[
config
.
hidden_act
]
else
:
self
.
transform_act_fn
=
config
.
hidden_act
self
.
LayerNorm
=
BertLayerNorm
(
config
.
hidden_size
,
eps
=
1e-12
)
def
forward
(
self
,
hidden_states
):
...
...
@@ -542,7 +547,7 @@ class BertPreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
FileNotFound
Error
:
except
Environment
Error
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
448937c0
...
...
@@ -24,6 +24,8 @@ import os
import
shutil
import
tarfile
import
tempfile
import
sys
from
io
import
open
import
torch
import
torch.nn
as
nn
...
...
@@ -160,7 +162,8 @@ class OpenAIGPTConfig(object):
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if
isinstance
(
vocab_size_or_config_json_file
,
str
):
if
isinstance
(
vocab_size_or_config_json_file
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
vocab_size_or_config_json_file
,
unicode
)):
with
open
(
vocab_size_or_config_json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
json_config
=
json
.
loads
(
reader
.
read
())
for
key
,
value
in
json_config
.
items
():
...
...
@@ -442,7 +445,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
FileNotFound
Error
:
except
Environment
Error
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
...
...
@@ -641,7 +644,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
hidden_states
=
inputs_embeds
+
position_embeds
+
token_type_embeds
for
block
in
self
.
h
:
hidden_states
=
block
(
hidden_states
)
return
hidden_states
.
view
(
*
input_shape
,
hidden_states
.
size
(
-
1
))
output_shape
=
input_shape
+
(
hidden_states
.
size
(
-
1
),)
return
hidden_states
.
view
(
*
output_shape
)
class
OpenAIGPTLMHeadModel
(
OpenAIGPTPreTrainedModel
):
...
...
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
448937c0
...
...
@@ -27,6 +27,8 @@ import tarfile
import
tempfile
import
shutil
import
collections
import
sys
from
io
import
open
import
torch
import
torch.nn
as
nn
...
...
@@ -124,7 +126,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
try
:
import
numpy
as
np
import
tensorflow
as
tf
except
ModuleNotFound
Error
:
except
Import
Error
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
...
...
@@ -239,7 +241,8 @@ class TransfoXLConfig(object):
proj_init_std: parameters initialized by N(0, init_std)
init_std: parameters initialized by N(0, init_std)
"""
if
isinstance
(
vocab_size_or_config_json_file
,
str
):
if
isinstance
(
vocab_size_or_config_json_file
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
vocab_size_or_config_json_file
,
unicode
)):
with
open
(
vocab_size_or_config_json_file
,
"r"
,
encoding
=
'utf-8'
)
as
reader
:
json_config
=
json
.
loads
(
reader
.
read
())
for
key
,
value
in
json_config
.
items
():
...
...
@@ -503,11 +506,12 @@ class RelMultiHeadAttn(nn.Module):
return
x
def
_rel_shift
(
self
,
x
,
zero_triu
=
False
):
zero_pad
=
torch
.
zeros
(
(
x
.
size
(
0
),
1
,
*
x
.
size
()[
2
:]
),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
zero_pad
_shape
=
(
x
.
size
(
0
),
1
)
+
x
.
size
()[
2
:]
zero_pad
=
torch
.
zeros
(
zero_pad_shape
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
x_padded
=
torch
.
cat
([
zero_pad
,
x
],
dim
=
1
)
x_padded
=
x_padded
.
view
(
x
.
size
(
1
)
+
1
,
x
.
size
(
0
),
*
x
.
size
()[
2
:])
x_padded_shape
=
(
x
.
size
(
1
)
+
1
,
x
.
size
(
0
))
+
x
.
size
()[
2
:]
x_padded
=
x_padded
.
view
(
*
x_padded_shape
)
x
=
x_padded
[
1
:].
view_as
(
x
)
...
...
@@ -797,7 +801,8 @@ class AdaptiveEmbedding(nn.Module):
emb_flat
.
index_copy_
(
0
,
indices_i
,
emb_i
)
embed
=
emb_flat
.
view
(
*
inp
.
size
(),
self
.
d_proj
)
embed_shape
=
inp
.
size
()
+
(
self
.
d_proj
,)
embed
=
emb_flat
.
view
(
embed_shape
)
embed
.
mul_
(
self
.
emb_scale
)
...
...
@@ -905,7 +910,7 @@ class TransfoXLPreTrainedModel(nn.Module):
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
FileNotFound
Error
:
except
Environment
Error
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
...
...
pytorch_pretrained_bert/tokenization.py
View file @
448937c0
...
...
@@ -14,14 +14,13 @@
# limitations under the License.
"""Tokenization classes."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
collections
import
unicodedata
import
os
import
logging
import
os
import
unicodedata
from
io
import
open
from
.file_utils
import
cached_path
...
...
@@ -129,7 +128,7 @@ class BertTokenizer(object):
# redirect to the cache, if necessary
try
:
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
except
FileNotFound
Error
:
except
Environment
Error
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
...
...
pytorch_pretrained_bert/tokenization_openai.py
View file @
448937c0
...
...
@@ -13,11 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from
__future__
import
(
absolute_import
,
division
,
print_function
,
unicode_literals
)
import
json
import
logging
import
os
import
re
import
json
import
sys
from
io
import
open
from
tqdm
import
tqdm
import
logging
from
.file_utils
import
cached_path
...
...
@@ -82,7 +88,7 @@ class OpenAIGPTTokenizer(object):
try
:
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
resolved_merges_file
=
cached_path
(
merges_file
,
cache_dir
=
cache_dir
)
except
FileNotFound
Error
:
except
Environment
Error
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
...
...
@@ -119,7 +125,7 @@ class OpenAIGPTTokenizer(object):
self
.
max_len
=
max_len
if
max_len
is
not
None
else
int
(
1e12
)
self
.
nlp
=
spacy
.
load
(
'en'
,
disable
=
[
'parser'
,
'tagger'
,
'ner'
,
'textcat'
])
self
.
fix_text
=
ftfy
.
fix_text
self
.
encoder
=
json
.
load
(
open
(
vocab_file
))
self
.
encoder
=
json
.
load
(
open
(
vocab_file
,
encoding
=
"utf-8"
))
self
.
decoder
=
{
v
:
k
for
k
,
v
in
self
.
encoder
.
items
()}
merges
=
open
(
merges_file
,
encoding
=
'utf-8'
).
read
().
split
(
'
\n
'
)[
1
:
-
1
]
merges
=
[
tuple
(
merge
.
split
())
for
merge
in
merges
]
...
...
@@ -196,7 +202,7 @@ class OpenAIGPTTokenizer(object):
def
convert_tokens_to_ids
(
self
,
tokens
):
"""Converts a sequence of tokens into ids using the vocab."""
ids
=
[]
if
isinstance
(
tokens
,
str
):
if
isinstance
(
tokens
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
tokens
,
unicode
))
:
if
tokens
in
self
.
special_tokens
:
return
self
.
special_tokens
[
tokens
]
else
:
...
...
pytorch_pretrained_bert/tokenization_transfo_xl.py
View file @
448937c0
...
...
@@ -16,16 +16,27 @@
""" Tokenization classes for Transformer XL model.
Adapted from https://github.com/kimiyoung/transformer-xl.
"""
from
__future__
import
(
absolute_import
,
division
,
print_function
,
unicode_literals
)
import
os
import
glob
import
logging
import
pickle
import
torch
import
os
import
sys
from
collections
import
Counter
,
OrderedDict
from
io
import
open
import
torch
import
numpy
as
np
from
.file_utils
import
cached_path
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
else
:
import
pickle
logger
=
logging
.
getLogger
(
__name__
)
PRETRAINED_VOCAB_ARCHIVE_MAP
=
{
...
...
@@ -55,7 +66,7 @@ class TransfoXLTokenizer(object):
# redirect to the cache, if necessary
try
:
resolved_vocab_file
=
cached_path
(
vocab_file
,
cache_dir
=
cache_dir
)
except
FileNotFound
Error
:
except
Environment
Error
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
...
...
@@ -422,7 +433,7 @@ class TransfoXLCorpus(object):
# redirect to the cache, if necessary
try
:
resolved_corpus_file
=
cached_path
(
corpus_file
,
cache_dir
=
cache_dir
)
except
FileNotFound
Error
:
except
Environment
Error
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
...
...
setup.py
View file @
448937c0
...
...
@@ -33,6 +33,7 @@ To create the package for pypi.
7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
"""
from
io
import
open
from
setuptools
import
find_packages
,
setup
setup
(
...
...
@@ -58,7 +59,7 @@ setup(
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main"
,
]
},
python_requires
=
'>=3.5.0'
,
#
python_requires='>=3.5.0',
tests_require
=
[
'pytest'
],
classifiers
=
[
'Intended Audience :: Science/Research'
,
...
...
tests/tokenization_test.py
View file @
448937c0
...
...
@@ -18,6 +18,7 @@ from __future__ import print_function
import
os
import
unittest
from
io
import
open
from
pytorch_pretrained_bert.tokenization
import
(
BertTokenizer
,
BasicTokenizer
,
WordpieceTokenizer
,
_is_whitespace
,
_is_control
,
_is_punctuation
)
...
...
@@ -30,7 +31,7 @@ class TokenizationTest(unittest.TestCase):
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"want"
,
"##want"
,
"##ed"
,
"wa"
,
"un"
,
"runn"
,
"##ing"
,
","
]
with
open
(
"/tmp/bert_tokenizer_test.txt"
,
"w"
)
as
vocab_writer
:
with
open
(
"/tmp/bert_tokenizer_test.txt"
,
"w"
,
encoding
=
'utf-8'
)
as
vocab_writer
:
vocab_writer
.
write
(
""
.
join
([
x
+
"
\n
"
for
x
in
vocab_tokens
]))
vocab_file
=
vocab_writer
.
name
...
...
@@ -49,7 +50,7 @@ class TokenizationTest(unittest.TestCase):
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"want"
,
"##want"
,
"##ed"
,
"wa"
,
"un"
,
"runn"
,
"##ing"
,
","
]
with
open
(
"/tmp/bert_tokenizer_test.txt"
,
"w"
)
as
vocab_writer
:
with
open
(
"/tmp/bert_tokenizer_test.txt"
,
"w"
,
encoding
=
'utf-8'
)
as
vocab_writer
:
vocab_writer
.
write
(
""
.
join
([
x
+
"
\n
"
for
x
in
vocab_tokens
]))
vocab_file
=
vocab_writer
.
name
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment