Unverified Commit 84f6bee5 authored by Yih-Dar, committed by GitHub

PT <-> TF for composite models (#19732)



* First step of PT->TF for composite models

* Update the tests

* For VisionEncoderDecoderModel

* Fix

* Fix

* Add comment

* Fix

* clean up import

* Save memory

* For (TF)EncoderDecoderModel

* For (TF)EncoderDecoderModel
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 12ce2941
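
In short, this PR makes cross-framework loading work directly on the composite classes: `from_pretrained` now accepts `from_pt`/`from_tf` and converts through the component models internally. A minimal usage sketch (the checkpoint name is taken from the docstrings in this diff; the local save path is illustrative):

```python
from transformers import EncoderDecoderModel, TFEncoderDecoderModel

# PT -> TF: load a PyTorch-only composite checkpoint into the TF class
tf_model = TFEncoderDecoderModel.from_pretrained(
    "patrickvonplaten/bert2bert-cnn_dailymail-fp16", from_pt=True
)

# TF -> PT: save the converted TF model and load it back into the PT class
tf_model.save_pretrained("./bert2bert-tf")
pt_model = EncoderDecoderModel.from_pretrained("./bert2bert-tf", from_tf=True)
```

The vision pair (`VisionEncoderDecoderModel` / `TFVisionEncoderDecoderModel`) follows the same pattern.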
...
@@ -14,6 +14,10 @@
 # limitations under the License.
 """ Classes to support Encoder-Decoder architectures"""

+import gc
+import os
+import tempfile
 import warnings
 from typing import Optional, Tuple, Union
...
@@ -267,7 +271,96 @@ class EncoderDecoderModel(PreTrainedModel):
         return self.decoder.set_output_embeddings(new_embeddings)

     @classmethod
-    def from_pretrained(cls, *args, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import EncoderDecoderModel

+        >>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
+        ```"""
+
+        from_tf = kwargs.pop("from_tf", False)
+        if from_tf:
+            from transformers import TFEncoderDecoderModel
+
+            # a workaround to load from tensorflow checkpoint
+            # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get
+            # extended before saving those components. For example, the name of `_tf_model.encoder.vit` is
+            # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The
+            # [top model name] is handled (stripped) by the conversion method, and the former case gets an extra
+            # `encoder` scope, which should not occur when we want to save the components alone.
+            # There was a (very) ugly potential fix, which wasn't integrated into `transformers`: see
+            # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
+            # (the change in `src/transformers/modeling_tf_utils.py`)
+            _tf_model = TFEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            config = _tf_model.config
+
+            # Using `tf_model` instead
+            encoder = _tf_model.encoder.__class__(_tf_model.config.encoder)
+            decoder = _tf_model.decoder.__class__(_tf_model.config.decoder)
+            # Make sure models are built
+            encoder(encoder.dummy_inputs)
+            decoder(decoder.dummy_inputs)
+
+            # Get the variable correspondence between `_tf_model` and `encoder` and `decoder`
+            encoder_variables = {}
+            for v in encoder.trainable_variables + encoder.non_trainable_variables:
+                encoder_variables["/".join(v.name.split("/")[1:])] = v
+            decoder_variables = {}
+            for v in decoder.trainable_variables + decoder.non_trainable_variables:
+                decoder_variables["/".join(v.name.split("/")[1:])] = v
+
+            _encoder_variables = {}
+            for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables:
+                _encoder_variables["/".join(v.name.split("/")[2:])] = v
+            _decoder_variables = {}
+            for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables:
+                _decoder_variables["/".join(v.name.split("/")[2:])] = v
+
+            # assign weight values to `encoder` and `decoder` from `_tf_model`
+            for name, v in encoder_variables.items():
+                v.assign(_encoder_variables[name])
+            for name, v in decoder_variables.items():
+                v.assign(_decoder_variables[name])
+
+            tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
+
+            # Deal with `enc_to_dec_proj`
+            if hasattr(_tf_model, "enc_to_dec_proj"):
+                tf_model(tf_model.dummy_inputs)
+                tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel)
+                tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                encoder_dir = os.path.join(tmpdirname, "encoder")
+                decoder_dir = os.path.join(tmpdirname, "decoder")
+                tf_model.encoder.save_pretrained(encoder_dir)
+                tf_model.decoder.save_pretrained(decoder_dir)
+
+                if hasattr(tf_model, "enc_to_dec_proj"):
+                    enc_to_dec_proj_weight = torch.transpose(
+                        torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0
+                    )
+                    enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy())
+
+                del _tf_model
+                del tf_model
+                gc.collect()
+
+                model = EncoderDecoderModel.from_encoder_decoder_pretrained(
+                    encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True
+                )
+                # This is only for copying some specific attributes of this particular model.
+                model.config = config
+
+                if hasattr(model, "enc_to_dec_proj"):
+                    model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight
+                    model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias
+
+                return model
+
         # At the moment fast initialization is not supported for composite models
         if kwargs.get("_fast_init", False):
             logger.warning(
...
@@ -275,7 +368,8 @@ class EncoderDecoderModel(PreTrainedModel):
                 "Falling back to slow initialization..."
             )
             kwargs["_fast_init"] = False
-        return super().from_pretrained(*args, **kwargs)
+
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

     @classmethod
     def from_encoder_decoder_pretrained(
...
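
The comment block above (and the loops keyed by `v.name.split("/")[1:]` vs `[2:]`) is the core of the TF -> PT path: the nested submodels carry an extra `encoder`/`decoder` scope in their weight names, so weights are matched by the name suffix that survives after stripping the leading scope components. A minimal self-contained sketch of that matching, using toy variables rather than the transformers API:

```python
import tensorflow as tf


def by_stripped_name(variables, n_strip):
    """Key each variable by its name with the first `n_strip` scope components dropped."""
    return {"/".join(v.name.split("/")[n_strip:]): v for v in variables}


# Toy stand-ins for the real submodels: the nested variable carries the extra
# `encoder` scope described in the comments above.
nested_v = tf.Variable(tf.ones((2, 2)), name="top/encoder/vit/kernel")
standalone_v = tf.Variable(tf.zeros((2, 2)), name="standalone/vit/kernel")

nested = by_stripped_name([nested_v], 2)  # key "vit/kernel:0" after dropping "top/encoder"
standalone = by_stripped_name([standalone_v], 1)  # key "vit/kernel:0" after dropping "standalone"

# Same suffix -> same logical weight, so values can be copied across by name.
for name, v in standalone.items():
    v.assign(nested[name])

assert (standalone_v.numpy() == 1.0).all()
```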
...
@@ -15,6 +15,8 @@
 """ Classes to support TF Encoder-Decoder architectures"""

+import gc
+import os
 import tempfile
 import warnings
 from typing import Optional
...
@@ -291,24 +293,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         r"""
-        Initializing *TFEncoderDecoderModel* from a pytorch checkpoint is not supported currently.
-
-        If there are only pytorch checkpoints for a particular encoder-decoder model, a workaround is:
-
-        ```python
-        >>> # a workaround to load from pytorch checkpoint
-        >>> from transformers import EncoderDecoderModel, TFEncoderDecoderModel

-        >>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")

-        >>> _model.encoder.save_pretrained("./encoder")
-        >>> _model.decoder.save_pretrained("./decoder")

-        >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
-        ...     "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
-        ... )
-        >>> # This is only for copying some specific attributes of this particular model.
-        >>> model.config = _model.config
-        ```
-
         Example:

         ```python
...
@@ -319,12 +303,42 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         from_pt = kwargs.pop("from_pt", False)
         if from_pt:
-            raise ValueError(
-                "Initializing `TFEncoderDecoderModel` from a pytorch checkpoint is not supported currently. Use a"
-                " tensorflow checkpoint instead. If only the pytorch checkpoints are available, create the encoder and"
-                " decoder models separately, and use them to initialize `TFEncoderDecoderModel`. Check"
-                " `TFEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
-            )
+            import torch
+
+            from transformers import EncoderDecoderModel
+
+            # a workaround to load from pytorch checkpoint
+            _model = EncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            config = _model.config
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                encoder_dir = os.path.join(tmpdirname, "encoder")
+                decoder_dir = os.path.join(tmpdirname, "decoder")
+                _model.encoder.save_pretrained(encoder_dir)
+                _model.decoder.save_pretrained(decoder_dir)
+
+                if hasattr(_model, "enc_to_dec_proj"):
+                    enc_to_dec_proj_kernel = tf.transpose(
+                        tf.constant(_model.enc_to_dec_proj.weight.detach().to("cpu").numpy()), perm=(1, 0)
+                    )
+                    enc_to_dec_proj_bias = tf.constant(_model.enc_to_dec_proj.bias.detach().to("cpu").numpy())
+
+                del _model
+                gc.collect()
+                torch.cuda.empty_cache()
+
+                model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
+                    encoder_dir, decoder_dir, encoder_from_pt=True, decoder_from_pt=True
+                )
+                # This is only for copying some specific attributes of this particular model.
+                model.config = config
+
+                if hasattr(model, "enc_to_dec_proj"):
+                    model(model.dummy_inputs)
+                    model.enc_to_dec_proj.kernel.assign(enc_to_dec_proj_kernel)
+                    model.enc_to_dec_proj.bias.assign(enc_to_dec_proj_bias)
+
+                return model
+
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
...
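
One detail shared by both directions of the conversion above: the `enc_to_dec_proj` weight is transposed in transit because `torch.nn.Linear.weight` has shape `(out_features, in_features)` while `tf.keras.layers.Dense.kernel` has shape `(in_features, out_features)`. A small self-contained check of that convention, using plain `torch`/`tf` rather than any transformers code:

```python
import numpy as np
import tensorflow as tf
import torch

linear = torch.nn.Linear(4, 3)  # weight shape: (3, 4)
dense = tf.keras.layers.Dense(3)
dense.build((None, 4))  # kernel shape: (4, 3)

# PT -> TF, as in the `from_pt` branch above: transpose weight into kernel
dense.kernel.assign(tf.transpose(tf.constant(linear.weight.detach().numpy()), perm=(1, 0)))
dense.bias.assign(tf.constant(linear.bias.detach().numpy()))

x = np.random.rand(2, 4).astype("float32")
pt_out = linear(torch.from_numpy(x)).detach().numpy()
tf_out = dense(tf.constant(x)).numpy()
assert np.allclose(pt_out, tf_out, atol=1e-5)
```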
...
@@ -15,6 +15,8 @@
 """ Classes to support TF Vision-Encoder-Text-Decoder architectures"""

+import gc
+import os
 import tempfile
 import warnings
 from typing import Optional
...
@@ -294,22 +296,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         r"""
-        Initializing `TFVisionEncoderDecoderModel` from a pytorch checkpoint is not supported currently.
-
-        If there are only pytorch checkpoints for a particular encoder-decoder model, a workaround is:
-
-        ```python
-        >>> # a workaround to load from pytorch checkpoint
-        >>> _model = VisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")

-        >>> _model.encoder.save_pretrained("./encoder")
-        >>> _model.decoder.save_pretrained("./decoder")

-        >>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
-        ...     "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
-        ... )
-        >>> # This is only for copying some specific attributes of this particular model.
-        >>> model.config = _model.config
-        ```
-
         Example:

         ```python
...
@@ -337,12 +323,42 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         from_pt = kwargs.pop("from_pt", False)
         if from_pt:
-            raise ValueError(
-                "Initializing `TFVisionEncoderDecoderModel` from a pytorch checkpoint is not supported currently. Use"
-                " a tensorflow checkpoint instead. If only the pytorch checkpoints are available, create the encoder"
-                " and decoder models separately, and use them to initialize `TFVisionEncoderDecoderModel`. Check"
-                " `TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
-            )
+            import torch
+
+            from transformers import VisionEncoderDecoderModel
+
+            # a workaround to load from pytorch checkpoint
+            _model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            config = _model.config
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                encoder_dir = os.path.join(tmpdirname, "encoder")
+                decoder_dir = os.path.join(tmpdirname, "decoder")
+                _model.encoder.save_pretrained(encoder_dir)
+                _model.decoder.save_pretrained(decoder_dir)
+
+                if hasattr(_model, "enc_to_dec_proj"):
+                    enc_to_dec_proj_kernel = tf.transpose(
+                        tf.constant(_model.enc_to_dec_proj.weight.detach().to("cpu").numpy()), perm=(1, 0)
+                    )
+                    enc_to_dec_proj_bias = tf.constant(_model.enc_to_dec_proj.bias.detach().to("cpu").numpy())
+
+                del _model
+                gc.collect()
+                torch.cuda.empty_cache()
+
+                model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
+                    encoder_dir, decoder_dir, encoder_from_pt=True, decoder_from_pt=True
+                )
+                # This is only for copying some specific attributes of this particular model.
+                model.config = config
+
+                if hasattr(model, "enc_to_dec_proj"):
+                    model(model.dummy_inputs)
+                    model.enc_to_dec_proj.kernel.assign(enc_to_dec_proj_kernel)
+                    model.enc_to_dec_proj.bias.assign(enc_to_dec_proj_bias)
+
+                return model
+
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
...
@@ -450,7 +466,8 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
         encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

-        # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
+        # Necessary to make `save_pretrained -> from_pretrained` work correctly for the converted PT -> TF model.
+        # See https://github.com/huggingface/transformers/pull/14016#issuecomment-944046313
         if kwargs_encoder.get("from_pt", None):
             del kwargs_encoder["from_pt"]
             with tempfile.TemporaryDirectory() as tmp_dirname:
...
@@ -492,7 +509,8 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
         decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

-        # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
+        # Necessary to make `save_pretrained -> from_pretrained` work correctly for the converted PT -> TF model.
+        # See https://github.com/huggingface/transformers/pull/14016#issuecomment-944046313
         if kwargs_decoder.get("from_pt", None):
             del kwargs_decoder["from_pt"]
             with tempfile.TemporaryDirectory() as tmp_dirname:
...
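
The two comment changes above point at the same subtlety: a component loaded with `from_pt=True` is immediately re-saved and re-loaded in pure TF so its variable names are rebuilt consistently under the composite model's `load_weight_prefix`. Roughly this pattern, sketched here with a hypothetical PyTorch-only checkpoint name:

```python
import tempfile

from transformers import TFAutoModel

# "pt-only-encoder" is a placeholder for a checkpoint that only has PyTorch weights.
encoder = TFAutoModel.from_pretrained("pt-only-encoder", from_pt=True)

with tempfile.TemporaryDirectory() as tmp_dirname:
    # Save in TF format and reload, so the later `save_pretrained ->
    # from_pretrained` cycle on the composite model sees consistent weight names.
    encoder.save_pretrained(tmp_dirname)
    encoder = TFAutoModel.from_pretrained(tmp_dirname)
```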
...
@@ -15,6 +15,9 @@
 """ Classes to support Vision-Encoder-Text-Decoder architectures"""

+import gc
+import os
+import tempfile
 from typing import Optional

 import torch
...
@@ -240,7 +243,115 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         return self.decoder.set_output_embeddings(new_embeddings)

     @classmethod
-    def from_pretrained(cls, *args, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer
+        >>> from PIL import Image
+        >>> import requests

+        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("ydshieh/vit-gpt2-coco-en")
+        >>> decoder_tokenizer = GPT2Tokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en")
+        >>> model = VisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")

+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> img = Image.open(requests.get(url, stream=True).raw)
+        >>> pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values  # Batch size 1

+        >>> output_ids = model.generate(
+        ...     pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
+        ... ).sequences

+        >>> preds = decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        >>> preds = [pred.strip() for pred in preds]

+        >>> assert preds == ["a cat laying on top of a couch next to another cat"]
+        ```"""
+
+        from_tf = kwargs.pop("from_tf", False)
+        if from_tf:
+            from transformers import TFVisionEncoderDecoderModel
+
+            # a workaround to load from tensorflow checkpoint
+            # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get
+            # extended before saving those components. For example, the name of `_tf_model.encoder.vit` is
+            # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The
+            # [top model name] is handled (stripped) by the conversion method, and the former case gets an extra
+            # `encoder` scope, which should not occur when we want to save the components alone.
+            # There was a (very) ugly potential fix, which wasn't integrated into `transformers`: see
+            # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
+            # (the change in `src/transformers/modeling_tf_utils.py`)
+            _tf_model = TFVisionEncoderDecoderModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, **kwargs
+            )
+            config = _tf_model.config
+
+            # Using `tf_model` instead
+            encoder = _tf_model.encoder.__class__(_tf_model.config.encoder)
+            decoder = _tf_model.decoder.__class__(_tf_model.config.decoder)
+            # Make sure models are built
+            encoder(encoder.dummy_inputs)
+            decoder(decoder.dummy_inputs)
+
+            # Get the variable correspondence between `_tf_model` and `encoder` and `decoder`
+            encoder_variables = {}
+            for v in encoder.trainable_variables + encoder.non_trainable_variables:
+                encoder_variables["/".join(v.name.split("/")[1:])] = v
+            decoder_variables = {}
+            for v in decoder.trainable_variables + decoder.non_trainable_variables:
+                decoder_variables["/".join(v.name.split("/")[1:])] = v
+
+            _encoder_variables = {}
+            for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables:
+                _encoder_variables["/".join(v.name.split("/")[2:])] = v
+            _decoder_variables = {}
+            for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables:
+                _decoder_variables["/".join(v.name.split("/")[2:])] = v
+
+            # assign weight values to `encoder` and `decoder` from `_tf_model`
+            for name, v in encoder_variables.items():
+                v.assign(_encoder_variables[name])
+            for name, v in decoder_variables.items():
+                v.assign(_decoder_variables[name])
+
+            tf_model = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
+
+            # Deal with `enc_to_dec_proj`
+            if hasattr(_tf_model, "enc_to_dec_proj"):
+                tf_model(tf_model.dummy_inputs)
+                tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel)
+                tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                encoder_dir = os.path.join(tmpdirname, "encoder")
+                decoder_dir = os.path.join(tmpdirname, "decoder")
+                tf_model.encoder.save_pretrained(encoder_dir)
+                tf_model.decoder.save_pretrained(decoder_dir)
+
+                if hasattr(tf_model, "enc_to_dec_proj"):
+                    enc_to_dec_proj_weight = torch.transpose(
+                        torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0
+                    )
+                    enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy())
+
+                del _tf_model
+                del tf_model
+                gc.collect()
+
+                model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
+                    encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True
+                )
+                # This is only for copying some specific attributes of this particular model.
+                model.config = config
+
+                if hasattr(model, "enc_to_dec_proj"):
+                    model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight
+                    model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias
+
+                return model
+
         # At the moment fast initialization is not supported for composite models
         if kwargs.get("_fast_init", False):
             logger.warning(
...
@@ -248,7 +359,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
                 "Falling back to slow initialization..."
             )
             kwargs["_fast_init"] = False
-        return super().from_pretrained(*args, **kwargs)
+
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

     @classmethod
     def from_encoder_decoder_pretrained(
...
...
@@ -523,15 +523,9 @@ class TFEncoderDecoderMixin:
         self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)

         # PT -> TF
-        with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
-            pt_model.encoder.save_pretrained(encoder_tmp_dirname)
-            pt_model.decoder.save_pretrained(decoder_tmp_dirname)
-            tf_model_loaded = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
-                encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
-            )
-            # This is only for copying some specific attributes of this particular model.
-            tf_model_loaded.config = pt_model.config
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pt_model.save_pretrained(tmpdirname)
+            tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

         self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
...
@@ -546,15 +540,9 @@
         pt_model = EncoderDecoderModel(encoder_decoder_config)

-        with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
-            pt_model.encoder.save_pretrained(encoder_tmp_dirname)
-            pt_model.decoder.save_pretrained(decoder_tmp_dirname)
-            tf_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
-                encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
-            )
-            # This is only for copying some specific attributes of this particular model.
-            tf_model.config = pt_model.config
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pt_model.save_pretrained(tmpdirname)
+            tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

         self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
...
@@ -567,33 +555,13 @@
         # TODO: A generalizable way to determine this attribute
         encoder_decoder_config.output_attentions = True

-        # Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
-        # the encoder/decoder models.
-        # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see
-        # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
-        # (the change in `src/transformers/modeling_tf_utils.py`)
-        _tf_model = TFEncoderDecoderModel(encoder_decoder_config)
-        # Make sure model is built
-        _tf_model(**tf_inputs_dict)
-
-        # Using `tf_model` to pass the test.
-        encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
-        decoder = _tf_model.decoder.__class__(encoder_decoder_config.decoder)
-        # Make sure models are built
-        encoder(encoder.dummy_inputs)
-        decoder(decoder.dummy_inputs)
-        tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
-        tf_model.config = encoder_decoder_config
-
-        with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
-            tf_model.encoder.save_pretrained(encoder_tmp_dirname)
-            tf_model.decoder.save_pretrained(decoder_tmp_dirname)
-            pt_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
-                encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_tf=True, decoder_from_tf=True
-            )
-            # This is only for copying some specific attributes of this particular model.
-            pt_model.config = tf_model.config
+        tf_model = TFEncoderDecoderModel(encoder_decoder_config)
+        # Make sure model is built before saving
+        tf_model(**tf_inputs_dict)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tf_model.save_pretrained(tmpdirname)
+            pt_model = EncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)

         self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
...
@@ -696,20 +664,11 @@
         self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
         self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict_with_labels)

-        # This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
-        # which randomly initialize `enc_to_dec_proj`.
         # check `enc_to_dec_proj` work as expected
-        # decoder_config.hidden_size = decoder_config.hidden_size * 2
-        # self.assertTrue(config.hidden_size != decoder_config.hidden_size)
-        # self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
-        # self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
-
-        # Let's just check `enc_to_dec_proj` can run for now
         decoder_config.hidden_size = decoder_config.hidden_size * 2
         self.assertTrue(config.hidden_size != decoder_config.hidden_size)
-        encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
-        model = TFEncoderDecoderModel(encoder_decoder_config)
-        model(tf_inputs_dict)
+        self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
+        self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)

     def test_model_save_load_from_pretrained(self):
         model_2 = self.get_pretrained_model()
...
...
@@ -456,20 +456,13 @@ class TFVisionEncoderDecoderMixin:
         self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)

         # PT -> TF
-        with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
-            pt_model.encoder.save_pretrained(encoder_tmp_dirname)
-            pt_model.decoder.save_pretrained(decoder_tmp_dirname)
-            tf_model_loaded = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
-                encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
-            )
-            # This is only for copying some specific attributes of this particular model.
-            tf_model_loaded.config = pt_model.config
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pt_model.save_pretrained(tmpdirname)
+            tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

         self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)

     def check_pt_to_tf_equivalence(self, config, decoder_config, tf_inputs_dict):
-        """EncoderDecoderModel requires special way to cross load (PT -> TF)"""
         encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
         # Output all for aggressive testing
...
@@ -479,20 +472,13 @@
         pt_model = VisionEncoderDecoderModel(encoder_decoder_config)

-        with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
-            pt_model.encoder.save_pretrained(encoder_tmp_dirname)
-            pt_model.decoder.save_pretrained(decoder_tmp_dirname)
-            tf_model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
-                encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
-            )
-            # This is only for copying some specific attributes of this particular model.
-            tf_model.config = pt_model.config
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pt_model.save_pretrained(tmpdirname)
+            tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

         self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)

     def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict):
-        """EncoderDecoderModel requires special way to cross load (TF -> PT)"""
         encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
         # Output all for aggressive testing
...
@@ -500,33 +486,13 @@
         # TODO: A generalizable way to determine this attribute
         encoder_decoder_config.output_attentions = True

-        # Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
-        # the encoder/decoder models.
-        # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see
-        # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
-        # (the change in `src/transformers/modeling_tf_utils.py`)
-        _tf_model = TFVisionEncoderDecoderModel(encoder_decoder_config)
-        # Make sure model is built
-        _tf_model(**tf_inputs_dict)
-
-        # Using `tf_model` to pass the test.
-        encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
-        decoder = _tf_model.decoder.__class__(encoder_decoder_config.decoder)
-        # Make sure models are built
-        encoder(encoder.dummy_inputs)
-        decoder(decoder.dummy_inputs)
-        tf_model = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
-        tf_model.config = encoder_decoder_config
-
-        with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
-            tf_model.encoder.save_pretrained(encoder_tmp_dirname)
-            tf_model.decoder.save_pretrained(decoder_tmp_dirname)
-            pt_model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
-                encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_tf=True, decoder_from_tf=True
-            )
-            # This is only for copying some specific attributes of this particular model.
-            pt_model.config = tf_model.config
+        tf_model = TFVisionEncoderDecoderModel(encoder_decoder_config)
+        # Make sure model is built before saving
+        tf_model(**tf_inputs_dict)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tf_model.save_pretrained(tmpdirname)
+            pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)

         self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
...
@@ -624,20 +590,11 @@
         self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
         self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict_with_labels)

-        # This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
-        # which randomly initialize `enc_to_dec_proj`.
         # check `enc_to_dec_proj` work as expected
-        # decoder_config.hidden_size = decoder_config.hidden_size * 2
-        # self.assertTrue(config.hidden_size != decoder_config.hidden_size)
-        # self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
-        # self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
-
-        # Let's just check `enc_to_dec_proj` can run for now
         decoder_config.hidden_size = decoder_config.hidden_size * 2
         self.assertTrue(config.hidden_size != decoder_config.hidden_size)
-        encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
-        model = TFVisionEncoderDecoderModel(encoder_decoder_config)
-        model(tf_inputs_dict)
+        self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
+        self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)

     @slow
     def test_real_model_save_load_from_pretrained(self):
...