chenpangpang / transformers

Commit d6f06c03 authored Nov 30, 2018 by thomwolf

fixed loading pre-trained tokenizer from directory

parent 532a81d3
Showing 2 changed files with 15 additions and 12 deletions

pytorch_pretrained_bert/modeling.py        +1  -1
pytorch_pretrained_bert/tokenization.py    +14 -11
pytorch_pretrained_bert/modeling.py

@@ -478,7 +478,7 @@ class PreTrainedBertModel(nn.Module):
                 "associated to this path or url.".format(
                     pretrained_model_name,
                     ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                    pretrained_model_name))
+                    archive_file))
             return None
         if resolved_archive_file == archive_file:
             logger.info("loading archive file {}".format(archive_file))
pytorch_pretrained_bert/tokenization.py

@@ -39,6 +39,7 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
 }
+VOCAB_NAME = 'vocab.txt'
 
 def load_vocab(vocab_file):

@@ -100,7 +101,7 @@ class BertTokenizer(object):
         return tokens
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name, do_lower_case=True):
+    def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a PreTrainedBertModel from a pre-trained model file.
         Download and cache the pre-trained model file if needed.

@@ -109,16 +110,11 @@ class BertTokenizer(object):
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
         else:
             vocab_file = pretrained_model_name
+        if os.path.isdir(vocab_file):
+            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
         # redirect to the cache, if necessary
         try:
-            resolved_vocab_file = cached_path(vocab_file)
-            if resolved_vocab_file == vocab_file:
-                logger.info("loading vocabulary file {}".format(vocab_file))
-            else:
-                logger.info("loading vocabulary file {} from cache at {}".format(
-                    vocab_file, resolved_vocab_file))
-            # Instantiate tokenizer.
-            tokenizer = cls(resolved_vocab_file, do_lower_case)
+            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
         except FileNotFoundError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "

@@ -126,8 +122,15 @@ class BertTokenizer(object):
                 "associated to this path or url.".format(
                     pretrained_model_name,
                     ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    pretrained_model_name))
-            tokenizer = None
+                    vocab_file))
+            return None
+        if resolved_vocab_file == vocab_file:
+            logger.info("loading vocabulary file {}".format(vocab_file))
+        else:
+            logger.info("loading vocabulary file {} from cache at {}".format(
+                vocab_file, resolved_vocab_file))
+        # Instantiate tokenizer.
+        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
         return tokenizer
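With this change, BertTokenizer.from_pretrained accepts a local directory (vocab.txt is appended automatically when the path is a directory), an explicit cache_dir is forwarded to cached_path, and extra arguments such as do_lower_case are passed through to the tokenizer constructor. A minimal usage sketch of the fixed method; the directory and cache paths below are hypothetical examples, not part of this commit:

    from pytorch_pretrained_bert.tokenization import BertTokenizer

    # Load by shortcut name: the vocabulary file is downloaded and cached,
    # optionally into an explicit cache directory (hypothetical path).
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir='/tmp/bert_cache')

    # Load from a local directory containing a vocab.txt (hypothetical path);
    # do_lower_case is forwarded to the BertTokenizer constructor via **kwargs.
    local_tokenizer = BertTokenizer.from_pretrained('/path/to/saved_model/', do_lower_case=True)

    tokens = local_tokenizer.tokenize("Hello, world!")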