Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f5397ffc
Commit
f5397ffc
authored
Sep 24, 2019
by
thomwolf
Browse files
update loading logics
parent
271f2136
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
35 deletions
+55
-35
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+1
-1
pytorch_transformers/file_utils.py
pytorch_transformers/file_utils.py
+1
-0
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+41
-30
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+12
-4
No files found.
pytorch_transformers/__init__.py
View file @
f5397ffc
...
@@ -163,7 +163,7 @@ if _tf_available and _torch_available:
...
@@ -163,7 +163,7 @@ if _tf_available and _torch_available:
# Files and general utilities
# Files and general utilities
from
.file_utils
import
(
PYTORCH_TRANSFORMERS_CACHE
,
PYTORCH_PRETRAINED_BERT_CACHE
,
from
.file_utils
import
(
PYTORCH_TRANSFORMERS_CACHE
,
PYTORCH_PRETRAINED_BERT_CACHE
,
cached_path
,
add_start_docstrings
,
add_end_docstrings
,
cached_path
,
add_start_docstrings
,
add_end_docstrings
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
CONFIG_NAME
)
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
CONFIG_NAME
)
def
is_torch_available
():
def
is_torch_available
():
return
_torch_available
return
_torch_available
...
...
pytorch_transformers/file_utils.py
View file @
f5397ffc
...
@@ -49,6 +49,7 @@ except (AttributeError, ImportError):
...
@@ -49,6 +49,7 @@ except (AttributeError, ImportError):
PYTORCH_TRANSFORMERS_CACHE
=
PYTORCH_PRETRAINED_BERT_CACHE
# Kept for backward compatibility
PYTORCH_TRANSFORMERS_CACHE
=
PYTORCH_PRETRAINED_BERT_CACHE
# Kept for backward compatibility
WEIGHTS_NAME
=
"pytorch_model.bin"
WEIGHTS_NAME
=
"pytorch_model.bin"
TF2_WEIGHTS_NAME
=
'tf_model.h5'
TF_WEIGHTS_NAME
=
'model.ckpt'
TF_WEIGHTS_NAME
=
'model.ckpt'
CONFIG_NAME
=
"config.json"
CONFIG_NAME
=
"config.json"
...
...
pytorch_transformers/modeling_tf_utils.py
View file @
f5397ffc
...
@@ -24,7 +24,7 @@ import os
...
@@ -24,7 +24,7 @@ import os
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.configuration_utils
import
PretrainedConfig
from
.configuration_utils
import
PretrainedConfig
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -205,38 +205,49 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -205,38 +205,49 @@ class TFPreTrainedModel(tf.keras.Model):
model_kwargs
=
kwargs
model_kwargs
=
kwargs
# Load model
# Load model
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
if
pretrained_model_name_or_path
is
not
None
:
archive_file
=
cls
.
pretrained_model_archive_map
[
pretrained_model_name_or_path
]
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
from_pt
:
# Load from a PyTorch checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
)
else
:
archive_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
logger
.
error
(
archive_file
=
cls
.
pretrained_model_archive_map
[
pretrained_model_name_or_path
]
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
archive_file
))
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF2_WEIGHTS_NAME
)):
# Load from a TF 2.0 checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF2_WEIGHTS_NAME
)
elif
from_pt
and
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)):
# Load from a PyTorch checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
raise
EnvironmentError
(
"Error no file named {} found in directory {}"
.
format
(
tuple
(
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
),
pretrained_model_name_or_path
))
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
else
:
raise
EnvironmentError
(
"Error file {} not found"
.
format
(
pretrained_model_name_or_path
))
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_model_archive_map
.
keys
()),
archive_file
))
raise
e
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
else
:
else
:
logger
.
error
(
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
"Model name '{}' was not found in model name list ({}). "
archive_file
,
resolved_archive_file
))
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_model_archive_map
.
keys
()),
archive_file
))
return
None
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
else
:
else
:
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
resolved_archive_file
=
None
archive_file
,
resolved_archive_file
))
# Instantiate model.
# Instantiate model.
model
=
cls
(
config
,
*
model_args
,
**
model_kwargs
)
model
=
cls
(
config
,
*
model_args
,
**
model_kwargs
)
...
...
pytorch_transformers/modeling_utils.py
View file @
f5397ffc
...
@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
...
@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
.configuration_utils
import
PretrainedConfig
from
.configuration_utils
import
PretrainedConfig
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -294,11 +294,19 @@ class PreTrainedModel(nn.Module):
...
@@ -294,11 +294,19 @@ class PreTrainedModel(nn.Module):
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
archive_file
=
cls
.
pretrained_model_archive_map
[
pretrained_model_name_or_path
]
archive_file
=
cls
.
pretrained_model_archive_map
[
pretrained_model_name_or_path
]
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
from_tf
:
if
from_tf
and
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
))
:
#
Directly l
oad from a T
ensorFlow
checkpoint
#
L
oad from a T
F 1.0
checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
else
:
elif
from_tf
and
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF2_WEIGHTS_NAME
)):
# Load from a TF 2.0 checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF2_WEIGHTS_NAME
)
elif
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)):
# Load from a PyTorch checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
raise
EnvironmentError
(
"Error no file named {} found in directory {}"
.
format
(
tuple
(
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
+
".index"
),
pretrained_model_name_or_path
))
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
archive_file
=
pretrained_model_name_or_path
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment