Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c37815f1
Commit
c37815f1
authored
Dec 20, 2019
by
thomwolf
Browse files
clean up PT <=> TF 2.0 conversion and config loading
parent
73fcebf7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
14 deletions
+29
-14
transformers/convert_pytorch_checkpoint_to_tf2.py
transformers/convert_pytorch_checkpoint_to_tf2.py
+5
-4
transformers/modeling_tf_utils.py
transformers/modeling_tf_utils.py
+12
-5
transformers/modeling_utils.py
transformers/modeling_utils.py
+12
-5
No files found.
transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
c37815f1
...
@@ -32,7 +32,7 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
...
@@ -32,7 +32,7 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RobertaConfig
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RobertaConfig
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DistilBertConfig
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DistilBertConfig
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForSequenceClassification
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
AlbertConfig
,
TFAlbertForMaskedLM
,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
AlbertConfig
,
TFAlbertForMaskedLM
,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
T5Config
,
TFT5WithLMHeadModel
,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP
)
T5Config
,
TFT5WithLMHeadModel
,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP
)
...
@@ -47,7 +47,7 @@ if is_torch_available():
...
@@ -47,7 +47,7 @@ if is_torch_available():
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DistilBertForSequenceClassification
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
T5WithLMHeadModel
,
T5_PRETRAINED_MODEL_ARCHIVE_MAP
)
T5WithLMHeadModel
,
T5_PRETRAINED_MODEL_ARCHIVE_MAP
)
...
@@ -59,7 +59,7 @@ else:
...
@@ -59,7 +59,7 @@ else:
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
T5WithLMHeadModel
,
T5_PRETRAINED_MODEL_ARCHIVE_MAP
)
=
(
T5WithLMHeadModel
,
T5_PRETRAINED_MODEL_ARCHIVE_MAP
)
=
(
...
@@ -70,7 +70,7 @@ else:
...
@@ -70,7 +70,7 @@ else:
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
)
None
,
None
)
...
@@ -93,6 +93,7 @@ MODEL_CLASSES = {
...
@@ -93,6 +93,7 @@ MODEL_CLASSES = {
'roberta-large-mnli'
:
(
RobertaConfig
,
TFRobertaForSequenceClassification
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'roberta-large-mnli'
:
(
RobertaConfig
,
TFRobertaForSequenceClassification
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert'
:
(
DistilBertConfig
,
TFDistilBertForMaskedLM
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert'
:
(
DistilBertConfig
,
TFDistilBertForMaskedLM
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert-base-uncased-distilled-squad'
:
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert-base-uncased-distilled-squad'
:
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert-base-uncased-distilled-squad'
:
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'ctrl'
:
(
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'ctrl'
:
(
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'albert'
:
(
AlbertConfig
,
TFAlbertForMaskedLM
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'albert'
:
(
AlbertConfig
,
TFAlbertForMaskedLM
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
't5'
:
(
T5Config
,
TFT5WithLMHeadModel
,
T5WithLMHeadModel
,
T5_PRETRAINED_MODEL_ARCHIVE_MAP
,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP
),
't5'
:
(
T5Config
,
TFT5WithLMHeadModel
,
T5WithLMHeadModel
,
T5_PRETRAINED_MODEL_ARCHIVE_MAP
,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP
),
...
...
transformers/modeling_tf_utils.py
View file @
c37815f1
...
@@ -184,7 +184,9 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -184,7 +184,9 @@ class TFPreTrainedModel(tf.keras.Model):
model_args: (`optional`) Sequence of positional arguments:
model_args: (`optional`) Sequence of positional arguments:
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
config: (`optional`) one of:
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
...
@@ -236,10 +238,11 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -236,10 +238,11 @@ class TFPreTrainedModel(tf.keras.Model):
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
# Load config
# Load config if we don't provide a configuration
if
config
is
None
:
if
not
isinstance
(
config
,
PretrainedConfig
):
config_path
=
config
if
config
is
not
None
else
pretrained_model_name_or_path
config
,
model_kwargs
=
cls
.
config_class
.
from_pretrained
(
config
,
model_kwargs
=
cls
.
config_class
.
from_pretrained
(
pretrained_model_name_or
_path
,
*
model_args
,
config
_path
,
*
model_args
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
force_download
=
force_download
,
force_download
=
force_download
,
resume_download
=
resume_download
,
resume_download
=
resume_download
,
...
@@ -310,7 +313,11 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -310,7 +313,11 @@ class TFPreTrainedModel(tf.keras.Model):
assert
os
.
path
.
isfile
(
resolved_archive_file
),
"Error retrieving file {}"
.
format
(
resolved_archive_file
)
assert
os
.
path
.
isfile
(
resolved_archive_file
),
"Error retrieving file {}"
.
format
(
resolved_archive_file
)
# 'by_name' allows us to do transfer learning by skipping/adding layers
# 'by_name' allows us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
model
.
load_weights
(
resolved_archive_file
,
by_name
=
True
)
try
:
model
.
load_weights
(
resolved_archive_file
,
by_name
=
True
)
except
OSError
:
raise
OSError
(
"Unable to load weights from h5 file. "
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
)
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# Make sure restore ops are run
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# Make sure restore ops are run
...
...
transformers/modeling_utils.py
View file @
c37815f1
...
@@ -281,7 +281,9 @@ class PreTrainedModel(nn.Module):
...
@@ -281,7 +281,9 @@ class PreTrainedModel(nn.Module):
model_args: (`optional`) Sequence of positional arguments:
model_args: (`optional`) Sequence of positional arguments:
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
config: (`optional`) one of:
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
...
@@ -336,10 +338,11 @@ class PreTrainedModel(nn.Module):
...
@@ -336,10 +338,11 @@ class PreTrainedModel(nn.Module):
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
# Load config
# Load config if we don't provide a configuration
if
config
is
None
:
if
not
isinstance
(
config
,
PretrainedConfig
):
config_path
=
config
if
config
is
not
None
else
pretrained_model_name_or_path
config
,
model_kwargs
=
cls
.
config_class
.
from_pretrained
(
config
,
model_kwargs
=
cls
.
config_class
.
from_pretrained
(
pretrained_model_name_or
_path
,
*
model_args
,
config
_path
,
*
model_args
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
force_download
=
force_download
,
force_download
=
force_download
,
resume_download
=
resume_download
,
resume_download
=
resume_download
,
...
@@ -408,7 +411,11 @@ class PreTrainedModel(nn.Module):
...
@@ -408,7 +411,11 @@ class PreTrainedModel(nn.Module):
model
=
cls
(
config
,
*
model_args
,
**
model_kwargs
)
model
=
cls
(
config
,
*
model_args
,
**
model_kwargs
)
if
state_dict
is
None
and
not
from_tf
:
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
try
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
except
:
raise
OSError
(
"Unable to load weights from pytorch checkpoint file. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)
missing_keys
=
[]
missing_keys
=
[]
unexpected_keys
=
[]
unexpected_keys
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment