Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f5397ffc
Commit
f5397ffc
authored
Sep 24, 2019
by
thomwolf
Browse files
update loading logics
parent
271f2136
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
35 deletions
+55
-35
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+1
-1
pytorch_transformers/file_utils.py
pytorch_transformers/file_utils.py
+1
-0
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+41
-30
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+12
-4
No files found.
pytorch_transformers/__init__.py
View file @
f5397ffc
...
...
@@ -163,7 +163,7 @@ if _tf_available and _torch_available:
# Files and general utilities
from
.file_utils
import
(
PYTORCH_TRANSFORMERS_CACHE
,
PYTORCH_PRETRAINED_BERT_CACHE
,
cached_path
,
add_start_docstrings
,
add_end_docstrings
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
CONFIG_NAME
)
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
CONFIG_NAME
)
def
is_torch_available
():
return
_torch_available
...
...
pytorch_transformers/file_utils.py
View file @
f5397ffc
...
...
@@ -49,6 +49,7 @@ except (AttributeError, ImportError):
PYTORCH_TRANSFORMERS_CACHE
=
PYTORCH_PRETRAINED_BERT_CACHE
# Kept for backward compatibility
WEIGHTS_NAME
=
"pytorch_model.bin"
TF2_WEIGHTS_NAME
=
'tf_model.h5'
TF_WEIGHTS_NAME
=
'model.ckpt'
CONFIG_NAME
=
"config.json"
...
...
pytorch_transformers/modeling_tf_utils.py
View file @
f5397ffc
...
...
@@ -24,7 +24,7 @@ import os
import
tensorflow
as
tf
from
.configuration_utils
import
PretrainedConfig
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -205,38 +205,49 @@ class TFPreTrainedModel(tf.keras.Model):
model_kwargs
=
kwargs
# Load model
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
archive_file
=
cls
.
pretrained_model_archive_map
[
pretrained_model_name_or_path
]
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
from_pt
:
# Load from a PyTorch checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
)
else
:
archive_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
is
not
None
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
archive_file
=
cls
.
pretrained_model_archive_map
[
pretrained_model_name_or_path
]
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF2_WEIGHTS_NAME
)):
# Load from a TF 2.0 checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF2_WEIGHTS_NAME
)
elif
from_pt
and
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)):
# Load from a PyTorch checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
raise
EnvironmentError
(
"Error no file named {} found in directory {}"
.
format
(
tuple
(
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
),
pretrained_model_name_or_path
))
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
else
:
raise
EnvironmentError
(
"Error file {} not found"
.
format
(
pretrained_model_name_or_path
))
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_model_archive_map
.
keys
()),
archive_file
))
raise
e
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_model_archive_map
.
keys
()),
archive_file
))
return
None
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
else
:
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
resolved_archive_file
=
None
# Instantiate model.
model
=
cls
(
config
,
*
model_args
,
**
model_kwargs
)
...
...
pytorch_transformers/modeling_utils.py
View file @
f5397ffc
...
...
@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
from
torch.nn
import
functional
as
F
from
.configuration_utils
import
PretrainedConfig
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -294,11 +294,19 @@ class PreTrainedModel(nn.Module):
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
archive_file
=
cls
.
pretrained_model_archive_map
[
pretrained_model_name_or_path
]
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
from_tf
:
# Directly load from a TensorFlow checkpoint
if
from_tf
and
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
))
:
# Load from a TF 1.0 checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
else
:
elif
from_tf
and
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF2_WEIGHTS_NAME
)):
# Load from a TF 2.0 checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF2_WEIGHTS_NAME
)
elif
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)):
# Load from a PyTorch checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
raise
EnvironmentError
(
"Error no file named {} found in directory {}"
.
format
(
tuple
(
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
+
".index"
),
pretrained_model_name_or_path
))
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment