chenpangpang / transformers

Commit 448937c0, authored Feb 06, 2019 by thomwolf

    python 2 compatibility

Parent: ba37ddc5

Showing 17 changed files with 245 additions and 183 deletions (+245 -183)
    examples/eval_transfo_xl.py                                           +16  -16
    examples/run_classifier.py                                            +16  -12
    examples/run_lm_finetuning.py                                         +21  -22
    examples/run_squad.py                                                 +19  -11
    examples/run_swag.py                                                  +32  -24
    pytorch_pretrained_bert/__main__.py                                    +2   -2
    pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py        +8   -4
    pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py   +16  -10
    pytorch_pretrained_bert/file_utils.py                                 +41  -40
    pytorch_pretrained_bert/modeling.py                                   +18  -13
    pytorch_pretrained_bert/modeling_openai.py                             +7   -3
    pytorch_pretrained_bert/modeling_transfo_xl.py                        +12   -7
    pytorch_pretrained_bert/tokenization.py                                +5   -6
    pytorch_pretrained_bert/tokenization_openai.py                        +11   -5
    pytorch_pretrained_bert/tokenization_transfo_xl.py                    +16   -5
    setup.py                                                               +2   -1
    tests/tokenization_test.py                                             +3   -2
examples/eval_transfo_xl.py

@@ -17,26 +17,26 @@
 Adapted from https://github.com/kimiyoung/transformer-xl.
 In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
 """
+from __future__ import absolute_import, division, print_function, unicode_literals

 import os
-import sys
 import functools
 import argparse
+import logging
 import time
 import math
+import sys
+from io import open

 import torch

 from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus

-def logging(s, log_path, print_=True, log_=True):
-    if print_:
-        print(s)
-    if log_:
-        with open(log_path, 'a+') as f_log:
-            f_log.write(s + '\n')
-
-def get_logger(log_path, **kwargs):
-    return functools.partial(logging, log_path=log_path, **kwargs)
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.INFO)
+logger = logging.getLogger(__name__)

 parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
 # parser.add_argument('--data', type=str, default='../data/wikitext-103',

@@ -71,8 +71,8 @@ assert args.ext_len >= 0, 'extended context length must be non-negative'
 device = torch.device("cuda" if args.cuda else "cpu")

 # Get logger
-logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
-                     log_=not args.no_log)
+# logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
+#                      log_=not args.no_log)

 # Load dataset
 corpus = TransfoXLCorpus.from_pretrained(args.model_name)

@@ -90,7 +90,7 @@ te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
 model = TransfoXLModel.from_pretrained(args.model_name)
 model = model.to(device)

-logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
+logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
     args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

 model.reset_length(args.tgt_len, args.ext_len, args.mem_len)

@@ -116,7 +116,7 @@ def evaluate(eval_iter):
         total_loss += seq_len * loss.item()
         total_len += seq_len
     total_time = time.time() - start_time
-    logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
+    logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
         total_time, 1000 * total_time / (idx+1)))
     return total_loss / total_len

@@ -146,6 +146,6 @@ if valid_loss is not None:
 if test_loss is not None:
     log_str += format_log(test_loss, 'test')

-logging('=' * 100)
-logging(log_str)
-logging('=' * 100)
+logger.info('=' * 100)
+logger.info(log_str)
+logger.info('=' * 100)
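
A note on the change above: the script previously defined its own logging(...) function, which shadowed the stdlib module of the same name and appended every message to a file in work_dir. The replacement uses the standard logging module, which as configured here writes only to the console. If the file copy were still wanted, a handler could be attached; a minimal sketch, not part of this commit (the file name is a stand-in):

import logging

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

# Hypothetical: re-attach the file output the old helper provided.
file_handler = logging.FileHandler('log.txt')  # 'log.txt' is a stand-in path
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
logger.addHandler(file_handler)

logger.info('Time : %.2fs, %.2fms/segment', 1.23, 4.56)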
examples/run_classifier.py

@@ -15,26 +15,27 @@
 # limitations under the License.
 """BERT finetuning runner."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function

+import argparse
 import csv
-import os
 import logging
-import argparse
+import os
 import random
-from tqdm import tqdm, trange
+import sys
+from io import open

 import numpy as np
 import torch
-from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
+                              TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm, trange

-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
 from pytorch_pretrained_bert.modeling import BertForSequenceClassification
 from pytorch_pretrained_bert.optimization import BertAdam
-from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
+from pytorch_pretrained_bert.tokenization import BertTokenizer

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt='%m/%d/%Y %H:%M:%S',

@@ -91,10 +92,12 @@ class DataProcessor(object):
     @classmethod
     def _read_tsv(cls, input_file, quotechar=None):
         """Reads a tab separated value file."""
-        with open(input_file, "r", encoding='utf-8') as f:
+        with open(input_file, "rb") as f:
             reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
             lines = []
             for line in reader:
+                if sys.version_info[0] == 2:
+                    line = list(unicode(cell, 'utf-8') for cell in line)
                 lines.append(line)
             return lines

@@ -429,7 +432,8 @@ def main():
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
-    os.makedirs(args.output_dir, exist_ok=True)
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)

     task_name = args.task_name.lower()

@@ -451,7 +455,7 @@ def main():
     # Prepare model
     model = BertForSequenceClassification.from_pretrained(args.bert_model,
-              cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
+              cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
               num_labels=num_labels)
     if args.fp16:
         model.half()
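
The _read_tsv change is the subtle one in this file: Python 2's csv module operates on byte strings (hence the binary open plus the per-cell unicode(...) decode), while Python 3's csv wants a text-mode file. A version-branching reader that makes the split explicit might look like the sketch below; this is for illustration, not the commit's code:

import csv
import sys
from io import open


def read_tsv(input_file, quotechar=None):
    """Read a .tsv into a list of unicode rows on both Python 2 and 3."""
    if sys.version_info[0] == 2:
        # Python 2: csv works on bytes; decode each cell afterwards.
        with open(input_file, "rb") as f:
            return [[cell.decode("utf-8") for cell in row]
                    for row in csv.reader(f, delimiter=b"\t", quotechar=quotechar)]
    else:
        # Python 3: csv works on text; decode at the file layer instead.
        with open(input_file, "r", encoding="utf-8") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))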
examples/run_lm_finetuning.py

@@ -15,26 +15,23 @@
 # limitations under the License.
 """BERT finetuning runner."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function, unicode_literals

-import os
-import logging
 import argparse
-from tqdm import tqdm, trange
+import logging
+import os
+import random
+from io import open

 import numpy as np
 import torch
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm, trange

-from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.modeling import BertForPreTraining
 from pytorch_pretrained_bert.optimization import BertAdam
-
-from torch.utils.data import Dataset
-import random
+from pytorch_pretrained_bert.tokenization import BertTokenizer

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt='%m/%d/%Y %H:%M:%S',

@@ -185,16 +182,16 @@ class BERTDataset(Dataset):
         if self.line_buffer is None:
             # read first non-empty line of file
             while t1 == "":
-                t1 = self.file.__next__().strip()
-            t2 = self.file.__next__().strip()
+                t1 = next(self.file).strip()
+            t2 = next(self.file).strip()
         else:
             # use t2 from previous iteration as new t1
             t1 = self.line_buffer
-            t2 = self.file.__next__().strip()
+            t2 = next(self.file).strip()
             # skip empty rows that are used for separating documents and keep track of current doc id
             while t2 == "" or t1 == "":
-                t1 = self.file.__next__().strip()
-                t2 = self.file.__next__().strip()
+                t1 = next(self.file).strip()
+                t2 = next(self.file).strip()
                 self.current_doc = self.current_doc + 1
         self.line_buffer = t2

@@ -228,15 +225,15 @@ class BERTDataset(Dataset):
     def get_next_line(self):
         """ Gets next line of random_file and starts over when reaching end of file"""
         try:
-            line = self.random_file.__next__().strip()
+            line = next(self.random_file).strip()
             #keep track of which document we are currently looking at to later avoid having the same doc as t1
             if line == "":
                 self.current_random_doc = self.current_random_doc + 1
-                line = self.random_file.__next__().strip()
+                line = next(self.random_file).strip()
         except StopIteration:
             self.random_file.close()
             self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
-            line = self.random_file.__next__().strip()
+            line = next(self.random_file).strip()
         return line

@@ -425,6 +422,7 @@ def main():
                         help="The output directory where the model checkpoints will be written.")

     ## Other parameters
+    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
     parser.add_argument("--max_seq_length",
                         default=128,
                         type=int,

@@ -513,7 +511,8 @@ def main():
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
-    os.makedirs(args.output_dir, exist_ok=True)
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)

     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

@@ -579,7 +578,7 @@ def main():
     if args.local_rank == -1:
         train_sampler = RandomSampler(train_dataset)
     else:
-        #TODO: check if this works with current data generator from disk that relies on file.__next__
+        #TODO: check if this works with current data generator from disk that relies on next(file)
         # (it doesn't return item back by index)
         train_sampler = DistributedSampler(train_dataset)
     train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
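
Every self.file.__next__() call above becomes next(self.file) because the dunder only exists on Python 3 iterators; Python 2 file objects implement .next() instead, and the builtin next() dispatches to whichever exists. A two-line illustration (the file name is hypothetical):

from io import open

with open("corpus.txt", "r", encoding="utf-8") as f:  # "corpus.txt" is a stand-in
    first = next(f).strip()         # portable across Python 2 and 3
    # first = f.__next__().strip()  # AttributeError on Python 2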
examples/run_squad.py

@@ -15,29 +15,36 @@
 # limitations under the License.
 """Run BERT on SQuAD."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function

 import argparse
 import collections
-import logging
 import json
+import logging
 import math
 import os
 import random
-import pickle
-from tqdm import tqdm, trange
+import sys
+from io import open

 import numpy as np
 import torch
-from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
+                              TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm, trange

-from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
+from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
 from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
 from pytorch_pretrained_bert.optimization import BertAdam
-from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
+from pytorch_pretrained_bert.tokenization import (BasicTokenizer, BertTokenizer,
+                                                  whitespace_tokenize)
+
+if sys.version_info[0] == 2:
+    import cPickle as pickle
+else:
+    import pickle

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt='%m/%d/%Y %H:%M:%S',

@@ -784,7 +791,8 @@ def main():
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
         raise ValueError("Output directory () already exists and is not empty.")
-    os.makedirs(args.output_dir, exist_ok=True)
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)

     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

@@ -798,7 +806,7 @@ def main():
     # Prepare model
     model = BertForQuestionAnswering.from_pretrained(args.bert_model,
-                cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
+                cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)))

     if args.fp16:
         model.half()
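
The cPickle shim added here is the standard idiom: cPickle is the fast C implementation on Python 2 and does not exist on Python 3, where plain pickle is already C-backed. If pickles must cross interpreter versions, pinning the protocol matters too; a sketch with illustrative data:

import sys

if sys.version_info[0] == 2:
    import cPickle as pickle  # C-accelerated pickle, Python 2 only
else:
    import pickle             # Python 3's pickle already uses the C implementation

blob = pickle.dumps({"answer": 42}, protocol=2)  # protocol 2 is the highest Python 2 reads
assert pickle.loads(blob) == {"answer": 42}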
examples/run_swag.py

@@ -15,22 +15,25 @@
 # limitations under the License.
 """BERT finetuning runner."""

+import argparse
+import csv
 import logging
 import os
-import argparse
 import random
-from tqdm import tqdm, trange
-import csv
+import sys
+from io import open

 import numpy as np
 import torch
-from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
+                              TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm, trange

-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
 from pytorch_pretrained_bert.modeling import BertForMultipleChoice
 from pytorch_pretrained_bert.optimization import BertAdam
-from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
+from pytorch_pretrained_bert.tokenization import BertTokenizer

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt='%m/%d/%Y %H:%M:%S',

@@ -65,17 +68,17 @@ class SwagExample(object):
     def __repr__(self):
         l = [
-            f"swag_id: {self.swag_id}",
-            f"context_sentence: {self.context_sentence}",
-            f"start_ending: {self.start_ending}",
-            f"ending_0: {self.endings[0]}",
-            f"ending_1: {self.endings[1]}",
-            f"ending_2: {self.endings[2]}",
-            f"ending_3: {self.endings[3]}",
+            "swag_id: {}".format(self.swag_id),
+            "context_sentence: {}".format(self.context_sentence),
+            "start_ending: {}".format(self.start_ending),
+            "ending_0: {}".format(self.endings[0]),
+            "ending_1: {}".format(self.endings[1]),
+            "ending_2: {}".format(self.endings[2]),
+            "ending_3: {}".format(self.endings[3]),
         ]

         if self.label is not None:
-            l.append(f"label: {self.label}")
+            l.append("label: {}".format(self.label))

         return ", ".join(l)

@@ -102,7 +105,11 @@ class InputFeatures(object):
 def read_swag_examples(input_file, is_training):
     with open(input_file, 'r', encoding='utf-8') as f:
         reader = csv.reader(f)
-        lines = list(reader)
+        lines = []
+        for line in reader:
+            if sys.version_info[0] == 2:
+                line = list(unicode(cell, 'utf-8') for cell in line)
+            lines.append(line)

     if is_training and lines[0][-1] != 'label':
         raise ValueError(

@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         label = example.label
         if example_index < 5:
             logger.info("*** Example ***")
-            logger.info(f"swag_id: {example.swag_id}")
+            logger.info("swag_id: {}".format(example.swag_id))
             for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
-                logger.info(f"choice: {choice_idx}")
-                logger.info(f"tokens: {' '.join(tokens)}")
-                logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
-                logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
-                logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
+                logger.info("choice: {}".format(choice_idx))
+                logger.info("tokens: {}".format(' '.join(tokens)))
+                logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
+                logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
+                logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
             if is_training:
-                logger.info(f"label: {label}")
+                logger.info("label: {}".format(label))

         features.append(
             InputFeatures(

@@ -349,7 +356,8 @@ def main():
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
-    os.makedirs(args.output_dir, exist_ok=True)
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)

     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

@@ -362,7 +370,7 @@ def main():
     # Prepare model
     model = BertForMultipleChoice.from_pretrained(args.bert_model,
-        cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
+        cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
         num_choices=4)
     if args.fp16:
         model.half()
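
run_swag.py was the only example using f-strings, which are a syntax error on any interpreter before Python 3.6 -- the whole module fails at compile time, not at the call site. The rewrite to str.format() is mechanical (values below are hypothetical, for illustration):

swag_id = "abc123"  # stand-in values
label = 2

# f-strings (PEP 498) do not parse on Python 2:
#   msg = f"swag_id: {swag_id}, label: {label}"      # Python >= 3.6 only
msg = "swag_id: {}, label: {}".format(swag_id, label)  # Python 2 and 3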
pytorch_pretrained_bert/__main__.py

@@ -15,7 +15,7 @@ def main():
     if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
         try:
             import tensorflow as tf
-        except ModuleNotFoundError:
+        except ImportError:
             print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
                   "In that case, it requires TensorFlow to be installed. Please see "
                   "https://www.tensorflow.org/install/ for installation instructions.")

@@ -43,7 +43,7 @@ def main():
     else:
         try:
             import tensorflow as tf
-        except ModuleNotFoundError:
+        except ImportError:
             print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
                   "In that case, it requires TensorFlow to be installed. Please see "
                   "https://www.tensorflow.org/install/ for installation instructions.")
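
The except ModuleNotFoundError -> except ImportError swap recurs across the repo: ModuleNotFoundError first appeared in Python 3.6 as a subclass of ImportError, so on Python 2 the name itself raises NameError inside the except clause. Catching the base class is portable and covers the same failure:

# ImportError exists on every Python version and is the parent of
# ModuleNotFoundError on 3.6+, so this clause works everywhere.
try:
    import tensorflow as tf
except ImportError:
    tf = None  # degrade gracefully when the optional dependency is absent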
pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py

@@ -14,14 +14,18 @@
 # limitations under the License.
 """Convert OpenAI GPT checkpoint."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function

 import argparse
+from io import open

 import torch

-from pytorch_pretrained_bert.modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
+from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
+                                                     OpenAIGPTConfig,
+                                                     OpenAIGPTModel,
+                                                     load_tf_weights_in_openai_gpt)


 def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
     # Construct model
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py

@@ -14,25 +14,31 @@
 # limitations under the License.
 """Convert Transformer XL checkpoint and datasets."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function

+import argparse
 import os
 import sys
-import argparse
-import pickle
+from io import open

-import tensorflow as tf
 import torch
-import numpy as np

-from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME, load_tf_weights_in_transfo_xl
-import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
-from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
+from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
+                                                         WEIGHTS_NAME,
+                                                         TransfoXLConfig,
+                                                         TransfoXLModel,
+                                                         load_tf_weights_in_transfo_xl)
+from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
+                                                             VOCAB_NAME)
+
+if sys.version_info[0] == 2:
+    import cPickle as pickle
+else:
+    import pickle

 # We do this to be able to load the python 2 datasets pickles
 # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
+import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils

 data_utils.Vocab = data_utils.TransfoXLTokenizer
 data_utils.Corpus = data_utils.TransfoXLCorpus
 sys.modules['data_utils'] = data_utils
pytorch_pretrained_bert/file_utils.py

@@ -3,31 +3,39 @@ Utilities for working with the local dataset cache.
 This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
 Copyright by the AllenNLP authors.
 """
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)

-import os
+import json
 import logging
+import os
 import shutil
 import tempfile
-import json
-from urllib.parse import urlparse
-from pathlib import Path
-from typing import Optional, Tuple, Union, IO, Callable, Set
-from hashlib import sha256
 from functools import wraps
+from hashlib import sha256
+from io import open

-from tqdm import tqdm
-
 import boto3
-from botocore.exceptions import ClientError
 import requests
+from botocore.exceptions import ClientError
+from tqdm import tqdm

-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+try:
+    from urllib.parse import urlparse
+except ImportError:
+    from urlparse import urlparse

-PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
-                                               Path.home() / '.pytorch_pretrained_bert'))
+try:
+    from pathlib import Path
+    PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                                   Path.home() / '.pytorch_pretrained_bert'))
+except ImportError:
+    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                              os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

-def url_to_filename(url: str, etag: str = None) -> str:
+def url_to_filename(url, etag=None):
     """
     Convert `url` into a hashed filename in a repeatable way.
     If `etag` is specified, append its hash to the url's, delimited

@@ -45,25 +53,23 @@ def url_to_filename(url: str, etag: str = None) -> str:
     return filename


-def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
+def filename_to_url(filename, cache_dir=None):
     """
     Return the url and etag (which may be ``None``) stored for `filename`.
-    Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
+    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
     """
     if cache_dir is None:
         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
-    if isinstance(cache_dir, Path):
-        cache_dir = str(cache_dir)

     cache_path = os.path.join(cache_dir, filename)
     if not os.path.exists(cache_path):
-        raise FileNotFoundError("file {} not found".format(cache_path))
+        raise EnvironmentError("file {} not found".format(cache_path))

     meta_path = cache_path + '.json'
     if not os.path.exists(meta_path):
-        raise FileNotFoundError("file {} not found".format(meta_path))
+        raise EnvironmentError("file {} not found".format(meta_path))

-    with open(meta_path) as meta_file:
+    with open(meta_path, encoding="utf-8") as meta_file:
         metadata = json.load(meta_file)
     url = metadata['url']
     etag = metadata['etag']

@@ -71,7 +77,7 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag


-def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
+def cached_path(url_or_filename, cache_dir=None):
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and

@@ -80,10 +86,6 @@ def cached_path(url_or_filename, cache_dir=None):
     """
     if cache_dir is None:
         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
-    if isinstance(url_or_filename, Path):
-        url_or_filename = str(url_or_filename)
-    if isinstance(cache_dir, Path):
-        cache_dir = str(cache_dir)

     parsed = urlparse(url_or_filename)

@@ -95,13 +97,13 @@ def cached_path(url_or_filename, cache_dir=None):
         return url_or_filename
     elif parsed.scheme == '':
         # File, but it doesn't exist.
-        raise FileNotFoundError("file {} not found".format(url_or_filename))
+        raise EnvironmentError("file {} not found".format(url_or_filename))
     else:
         # Something unknown
         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


-def split_s3_path(url: str) -> Tuple[str, str]:
+def split_s3_path(url):
     """Split a full s3 path into the bucket name and path."""
     parsed = urlparse(url)
     if not parsed.netloc or not parsed.path:

@@ -114,19 +116,19 @@ def split_s3_path(url):
     return bucket_name, s3_path


-def s3_request(func: Callable):
+def s3_request(func):
     """
     Wrapper function for s3 requests in order to create more helpful error
     messages.
     """

     @wraps(func)
-    def wrapper(url: str, *args, **kwargs):
+    def wrapper(url, *args, **kwargs):
         try:
             return func(url, *args, **kwargs)
         except ClientError as exc:
             if int(exc.response["Error"]["Code"]) == 404:
-                raise FileNotFoundError("file {} not found".format(url))
+                raise EnvironmentError("file {} not found".format(url))
             else:
                 raise

@@ -134,7 +136,7 @@ def s3_request(func):
 @s3_request
-def s3_etag(url: str) -> Optional[str]:
+def s3_etag(url):
     """Check ETag on S3 object."""
     s3_resource = boto3.resource("s3")
     bucket_name, s3_path = split_s3_path(url)

@@ -143,14 +145,14 @@ def s3_etag(url):
 @s3_request
-def s3_get(url: str, temp_file: IO) -> None:
+def s3_get(url, temp_file):
     """Pull a file directly from S3."""
     s3_resource = boto3.resource("s3")
     bucket_name, s3_path = split_s3_path(url)
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


-def http_get(url: str, temp_file: IO) -> None:
+def http_get(url, temp_file):
     req = requests.get(url, stream=True)
     content_length = req.headers.get('Content-Length')
     total = int(content_length) if content_length is not None else None

@@ -162,17 +164,16 @@ def http_get(url, temp_file):
     progress.close()


-def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
+def get_from_cache(url, cache_dir=None):
     """
     Given a URL, look for the corresponding dataset in the local cache.
     If it's not there, download it. Then return the path to the cached file.
     """
     if cache_dir is None:
         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
-    if isinstance(cache_dir, Path):
-        cache_dir = str(cache_dir)

-    os.makedirs(cache_dir, exist_ok=True)
+    if not os.path.exists(cache_dir):
+        os.makedirs(cache_dir)

     # Get eTag to add to filename, if it exists.
     if url.startswith("s3://"):

@@ -213,7 +214,7 @@ def get_from_cache(url, cache_dir=None):
             logger.info("creating metadata file for %s", cache_path)
             meta = {'url': url, 'etag': etag}
             meta_path = cache_path + '.json'
-            with open(meta_path, 'w') as meta_file:
+            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                 json.dump(meta, meta_file)

             logger.info("removing temp file %s", temp_file.name)

@@ -221,7 +222,7 @@ def get_from_cache(url, cache_dir=None):
     return cache_path


-def read_set_from_file(filename: str) -> Set[str]:
+def read_set_from_file(filename):
     '''
     Extract a de-duped collection (set) of text from a file.
     Expected file format is one item per line.

@@ -233,7 +234,7 @@ def read_set_from_file(filename):
     return collection


-def get_file_extension(path: str, dot=True, lower: bool = True):
+def get_file_extension(path, dot=True, lower=True):
     ext = os.path.splitext(path)[1]
     ext = ext if dot else ext[1:]
     return ext.lower() if lower else ext
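
Three portable idioms carry this file: type annotations are dropped (they are Python 3 syntax), FileNotFoundError becomes EnvironmentError (FileNotFoundError is Python 3 only, while EnvironmentError exists on both and on Python 3 is an alias of OSError, the base of FileNotFoundError, so Python 3 callers still catch the same failures), and the two stdlib relocations are wrapped in try/except. A condensed sketch of the import shims (the cache name below is a stand-in, not the library's):

import os

# Python 3 moved urlparse into urllib.parse; Python 2 keeps a top-level module.
try:
    from urllib.parse import urlparse
except ImportError:               # Python 2
    from urlparse import urlparse

# pathlib is stdlib only on Python 3; fall back to plain string paths.
try:
    from pathlib import Path
    CACHE = Path(os.getenv('DEMO_CACHE', Path.home() / '.demo_cache'))
except ImportError:               # Python 2
    CACHE = os.getenv('DEMO_CACHE',
                      os.path.join(os.path.expanduser("~"), '.demo_cache'))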
pytorch_pretrained_bert/modeling.py

@@ -15,18 +15,18 @@
 # limitations under the License.
 """PyTorch BERT model."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function, unicode_literals

-import os
 import copy
 import json
-import math
 import logging
+import math
+import os
+import shutil
 import tarfile
 import tempfile
-import shutil
+import sys
+from io import open

 import torch
 from torch import nn

@@ -56,7 +56,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
         import re
         import numpy as np
         import tensorflow as tf
-    except ModuleNotFoundError:
+    except ImportError:
         print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
               "https://www.tensorflow.org/install/ for installation instructions.")
         raise

@@ -164,7 +164,8 @@ class BertConfig(object):
             initializer_range: The sttdev of the truncated_normal_initializer for
                 initializing all weight matrices.
         """
-        if isinstance(vocab_size_or_config_json_file, str):
+        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
+                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():

@@ -343,8 +344,10 @@ class BertIntermediate(nn.Module):
     def __init__(self, config):
         super(BertIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
-            if isinstance(config.hidden_act, str) else config.hidden_act
+        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act

     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)

@@ -416,8 +419,10 @@ class BertPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super(BertPredictionHeadTransform, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.transform_act_fn = ACT2FN[config.hidden_act] \
-            if isinstance(config.hidden_act, str) else config.hidden_act
+        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

     def forward(self, hidden_states):

@@ -542,7 +547,7 @@ class BertPreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-        except FileNotFoundError:
+        except EnvironmentError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find any file "
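
The recurring isinstance(x, str) or (sys.version_info[0] == 2 and isinstance(x, unicode)) test works because `and` short-circuits: Python 3 never evaluates the bare name unicode, which would otherwise raise NameError. Wrapped as a helper, a sketch of the same idea:

import sys

def is_text(value):
    """True for text on both interpreters: str on Python 3, str/unicode on Python 2."""
    if sys.version_info[0] == 2:
        # 'unicode' only exists on Python 2; the version guard above keeps
        # Python 3 from ever evaluating the name.
        return isinstance(value, (str, unicode))  # noqa: F821
    return isinstance(value, str)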
pytorch_pretrained_bert/modeling_openai.py

@@ -24,6 +24,8 @@ import os
 import shutil
 import tarfile
 import tempfile
+import sys
+from io import open

 import torch
 import torch.nn as nn

@@ -160,7 +162,8 @@ class OpenAIGPTConfig(object):
             initializer_range: The sttdev of the truncated_normal_initializer for
                 initializing all weight matrices.
         """
-        if isinstance(vocab_size_or_config_json_file, str):
+        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
+                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():

@@ -442,7 +445,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-        except FileNotFoundError:
+        except EnvironmentError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find any file "

@@ -641,7 +644,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
         for block in self.h:
             hidden_states = block(hidden_states)
-        return hidden_states.view(*input_shape, hidden_states.size(-1))
+        output_shape = input_shape + (hidden_states.size(-1),)
+        return hidden_states.view(*output_shape)


 class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
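
The view(*input_shape, hidden_states.size(-1)) rewrite is a syntax issue, not a behavior change: Python 2 forbids a positional argument after *-unpacking in a call (the form is only legal since Python 3.5). Building the full shape tuple first sidesteps the construct entirely. A self-contained sketch:

import torch

x = torch.zeros(2, 3, 4)
input_shape = x.size()[:-1]           # torch.Size([2, 3])

# Python 2 rejects a positional argument after *-unpacking:
#   y = x.view(*input_shape, x.size(-1))   # SyntaxError on Python 2
output_shape = input_shape + (x.size(-1),)  # torch.Size is a tuple subclass
y = x.view(*output_shape)                   # works on Python 2 and 3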
pytorch_pretrained_bert/modeling_transfo_xl.py

@@ -27,6 +27,8 @@ import tarfile
 import tempfile
 import shutil
 import collections
+import sys
+from io import open

 import torch
 import torch.nn as nn

@@ -124,7 +126,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
     try:
         import numpy as np
         import tensorflow as tf
-    except ModuleNotFoundError:
+    except ImportError:
         print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
               "https://www.tensorflow.org/install/ for installation instructions.")
         raise

@@ -239,7 +241,8 @@ class TransfoXLConfig(object):
             proj_init_std: parameters initialized by N(0, init_std)
             init_std: parameters initialized by N(0, init_std)
         """
-        if isinstance(vocab_size_or_config_json_file, str):
+        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
+                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():

@@ -503,11 +506,12 @@ class RelMultiHeadAttn(nn.Module):
         return x

     def _rel_shift(self, x, zero_triu=False):
-        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
-                               device=x.device, dtype=x.dtype)
+        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
+        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
         x_padded = torch.cat([zero_pad, x], dim=1)

-        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
+        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
+        x_padded = x_padded.view(*x_padded_shape)

         x = x_padded[1:].view_as(x)

@@ -797,7 +801,8 @@ class AdaptiveEmbedding(nn.Module):
             emb_flat.index_copy_(0, indices_i, emb_i)

-        embed = emb_flat.view(*inp.size(), self.d_proj)
+        embed_shape = inp.size() + (self.d_proj,)
+        embed = emb_flat.view(embed_shape)

         embed.mul_(self.emb_scale)

@@ -905,7 +910,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
             resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
-        except FileNotFoundError:
+        except EnvironmentError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find files {} and {} "
pytorch_pretrained_bert/tokenization.py

@@ -14,14 +14,13 @@
 # limitations under the License.
 """Tokenization classes."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function, unicode_literals

 import collections
-import unicodedata
-import os
 import logging
+import os
+import unicodedata
+from io import open

 from .file_utils import cached_path

@@ -129,7 +128,7 @@ class BertTokenizer(object):
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-        except FileNotFoundError:
+        except EnvironmentError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find any file "
pytorch_pretrained_bert/tokenization_openai.py

@@ -13,11 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for OpenAI GPT."""
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)

+import json
+import logging
 import os
 import re
-import json
+import sys
+from io import open

 from tqdm import tqdm
-import logging

 from .file_utils import cached_path

@@ -82,7 +88,7 @@ class OpenAIGPTTokenizer(object):
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
             resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
-        except FileNotFoundError:
+        except EnvironmentError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find files {} and {} "

@@ -119,7 +125,7 @@ class OpenAIGPTTokenizer(object):
         self.max_len = max_len if max_len is not None else int(1e12)
         self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
         self.fix_text = ftfy.fix_text
-        self.encoder = json.load(open(vocab_file))
+        self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v: k for k, v in self.encoder.items()}
         merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         merges = [tuple(merge.split()) for merge in merges]

@@ -196,7 +202,7 @@ class OpenAIGPTTokenizer(object):
     def convert_tokens_to_ids(self, tokens):
         """Converts a sequence of tokens into ids using the vocab."""
         ids = []
-        if isinstance(tokens, str):
+        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
             if tokens in self.special_tokens:
                 return self.special_tokens[tokens]
             else:
pytorch_pretrained_bert/tokenization_transfo_xl.py

@@ -16,16 +16,27 @@
 """ Tokenization classes for Transformer XL model.
     Adapted from https://github.com/kimiyoung/transformer-xl.
 """
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)

-import os
 import glob
 import logging
-import pickle
-import torch
+import os
+import sys
 from collections import Counter, OrderedDict
+from io import open

+import torch
 import numpy as np

 from .file_utils import cached_path

+if sys.version_info[0] == 2:
+    import cPickle as pickle
+else:
+    import pickle
+

 logger = logging.getLogger(__name__)

 PRETRAINED_VOCAB_ARCHIVE_MAP = {

@@ -55,7 +66,7 @@ class TransfoXLTokenizer(object):
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-        except FileNotFoundError:
+        except EnvironmentError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find files {} "

@@ -422,7 +433,7 @@ class TransfoXLCorpus(object):
         # redirect to the cache, if necessary
         try:
             resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
-        except FileNotFoundError:
+        except EnvironmentError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find files {} "
setup.py

@@ -33,6 +33,7 @@ To create the package for pypi.
 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.

 """
+from io import open
 from setuptools import find_packages, setup

 setup(

@@ -58,7 +59,7 @@ setup(
             "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main"
         ]
     },
-    python_requires='>=3.5.0',
+    # python_requires='>=3.5.0',
     tests_require=['pytest'],
     classifiers=[
         'Intended Audience :: Science/Research',
tests/tokenization_test.py

@@ -18,6 +18,7 @@ from __future__ import print_function
 import os
 import unittest
+from io import open

 from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
                                                   _is_whitespace, _is_control, _is_punctuation)

@@ -30,7 +31,7 @@ class TokenizationTest(unittest.TestCase):
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
             "##ing", ","
         ]
-        with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
+        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

             vocab_file = vocab_writer.name

@@ -49,7 +50,7 @@ class TokenizationTest(unittest.TestCase):
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
             "##ing", ","
         ]
-        with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
+        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

             vocab_file = vocab_writer.name
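
The encoding='utf-8' arguments threaded through the whole commit depend on `from io import open`: Python 2's builtin open() takes no encoding parameter, while io.open does, and on Python 3 io.open is the very function the builtin open names. One spelling, both interpreters; a small sketch (the path is a stand-in):

from io import open

with open("/tmp/demo.txt", "w", encoding="utf-8") as f:
    f.write(u"caf\u00e9\n")   # u'' literals are valid on Python 2 and 3.3+

with open("/tmp/demo.txt", "r", encoding="utf-8") as f:
    assert f.read() == u"caf\u00e9\n"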