Unverified Commit 0e708178 authored by NielsRogge, committed by GitHub

[Pix2Struct] Add support to resize embeddings (#22394)

* First draft

* Fix integration test

* Remove script

* Fix test and typos

* Fix one more test

* Skip tied embeddings test

* Remove line

* Address comments
parent f6b80a01
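For context, a minimal sketch of the workflow this commit enables: adding tokens to the tokenizer and resizing the Pix2Struct decoder embeddings to match. The checkpoint name comes from the updated docstrings in the diff below; the added special tokens are made-up placeholders, and this is my illustration rather than code from the PR.

```python
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

ckpt = "google/pix2struct-base"
processor = Pix2StructProcessor.from_pretrained(ckpt)
model = Pix2StructForConditionalGeneration.from_pretrained(ckpt)

# Add a couple of placeholder tokens (hypothetical names) to the tokenizer ...
processor.tokenizer.add_tokens(["<row>", "<col>"])

# ... and grow the decoder's input embeddings and (untied) lm_head to match.
model.resize_token_embeddings(len(processor.tokenizer))
print(model.config.text_config.vocab_size)  # == len(processor.tokenizer)
```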
@@ -35,17 +35,16 @@ class Pix2StructTextConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`Pix2StructTextModel`]. It is used to instantiate
     a Pix2Struct text model according to the specified arguments, defining the model architecture. Instantiating a
-    configuration with the defaults will yield a similar configuration to that of the `Pix2StructText` used by the
-    [base architectures](https://huggingface.co/google/pix2struct-textcaps-base).
+    configuration with the defaults will yield a similar configuration to that of the Pix2Struct text decoder used by
+    the [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
         vocab_size (`int`, *optional*, defaults to 50244):
             Vocabulary size of the `Pix2Struct` text model. Defines the number of different tokens that can be
-            represented by the `inputs_ids` passed when calling [`Pix2StructModel`].
+            represented by the `inputs_ids` passed when calling [`Pix2StructTextModel`].
         hidden_size (`int`, *optional*, defaults to 768):
             Dimensionality of the encoder layers and the pooler layer.
         d_kv (`int`, *optional*, defaults to 64):
@@ -83,10 +82,10 @@ class Pix2StructTextConfig(PretrainedConfig):
     ```python
     >>> from transformers import Pix2StructTextConfig, Pix2StructTextModel
 
-    >>> # Initializing a Pix2StructTextConfig with Salesforce/pix2struct-vqa-base style configuration
+    >>> # Initializing a Pix2StructTextConfig with google/pix2struct-base style configuration
     >>> configuration = Pix2StructTextConfig()
 
-    >>> # Initializing a Pix2StructTextModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration
+    >>> # Initializing a Pix2StructTextModel (with random weights) from the google/pix2struct-base style configuration
     >>> model = Pix2StructTextModel(configuration)
 
     >>> # Accessing the model configuration
@@ -118,6 +117,7 @@ class Pix2StructTextConfig(PretrainedConfig):
         use_cache=False,
         pad_token_id=0,
         eos_token_id=1,
+        tie_word_embeddings=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -143,6 +143,7 @@ class Pix2StructTextConfig(PretrainedConfig):
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
             decoder_start_token_id=decoder_start_token_id,
+            tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
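For reference, a small sketch (mine, not from the PR) of what the new `tie_word_embeddings=False` default means at the config level: the flag is stored on the configuration and later tells the model not to tie the decoder's input embeddings to its `lm_head`.

```python
from transformers import Pix2StructTextConfig

config = Pix2StructTextConfig()
# The new default added in this commit: input embeddings and lm_head stay separate.
print(config.tie_word_embeddings)  # False
```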
@@ -168,14 +169,13 @@ class Pix2StructTextConfig(PretrainedConfig):
 class Pix2StructVisionConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`Pix2StructVisionModel`]. It is used to
-    instantiate a PIX2STRUCT vision model according to the specified arguments, defining the model architecture.
+    instantiate a Pix2Struct vision model according to the specified arguments, defining the model architecture.
     Instantiating a configuration defaults will yield a similar configuration to that of the Pix2Struct-base
-    [Salesforce/pix2struct-vqa-base](https://huggingface.co/Salesforce/pix2struct-vqa-base) architecture.
+    [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
         hidden_size (`int`, *optional*, defaults to 768):
             Dimensionality of the encoder layers and the pooler layer.
@@ -223,10 +223,10 @@ class Pix2StructVisionConfig(PretrainedConfig):
     ```python
     >>> from transformers import Pix2StructVisionConfig, Pix2StructVisionModel
 
-    >>> # Initializing a Pix2StructVisionConfig with Salesforce/pix2struct-vqa-base style configuration
+    >>> # Initializing a Pix2StructVisionConfig with google/pix2struct-base style configuration
     >>> configuration = Pix2StructVisionConfig()
 
-    >>> # Initializing a Pix2StructVisionModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration
+    >>> # Initializing a Pix2StructVisionModel (with random weights) from the google/pix2struct-base style configuration
     >>> model = Pix2StructVisionModel(configuration)
 
     >>> # Accessing the model configuration
@@ -301,11 +301,11 @@ class Pix2StructVisionConfig(PretrainedConfig):
 class Pix2StructConfig(PretrainedConfig):
     r"""
-    [`Pix2StructConfig`] is the configuration class to store the configuration of a [`Pix2StructModel`]. It is used to
-    instantiate a PIX2STRUCT model according to the specified arguments, defining the text model and vision model
-    configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
-    PIX2STRUCT-base [Salesforce/pix2struct-vqa-base](https://huggingface.co/Salesforce/pix2struct-vqa-base)
-    architecture.
+    [`Pix2StructConfig`] is the configuration class to store the configuration of a
+    [`Pix2StructForConditionalGeneration`]. It is used to instantiate a Pix2Struct model according to the specified
+    arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will
+    yield a similar configuration to that of the Pix2Struct-base
+    [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -327,20 +327,20 @@ class Pix2StructConfig(PretrainedConfig):
     Example:
 
     ```python
-    >>> from transformers import Pix2StructConfig, Pix2StructModel
+    >>> from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration
 
-    >>> # Initializing a Pix2StructConfig with Salesforce/pix2struct-vqa-base style configuration
+    >>> # Initializing a Pix2StructConfig with google/pix2struct-base style configuration
     >>> configuration = Pix2StructConfig()
 
-    >>> # Initializing a Pix2StructPModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration
-    >>> model = Pix2StructModel(configuration)
+    >>> # Initializing a Pix2StructForConditionalGeneration (with random weights) from the google/pix2struct-base style configuration
+    >>> model = Pix2StructForConditionalGeneration(configuration)
 
     >>> # Accessing the model configuration
     >>> configuration = model.config
 
     >>> # We can also initialize a Pix2StructConfig from a Pix2StructTextConfig and a Pix2StructVisionConfig
-    >>> # Initializing a PIX2STRUCTText and PIX2STRUCTVision configuration
+    >>> # Initializing a Pix2Struct text and Pix2Struct vision configuration
     >>> config_text = Pix2StructTextConfig()
     >>> config_vision = Pix2StructVisionConfig()
...
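The docstring example above is cut off by the diff view. Assuming the usual composite-config helper that `Pix2StructConfig` exposes (a `from_text_vision_configs` classmethod, analogous to other multi-modal configs in transformers — this is my assumption, not shown in the diff), the composition step would look roughly like:

```python
from transformers import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig

config_text = Pix2StructTextConfig()
config_vision = Pix2StructVisionConfig()

# Assumption: `from_text_vision_configs` builds a composite config from the two sub-configs.
config = Pix2StructConfig.from_text_vision_configs(config_text, config_vision)
```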
@@ -1369,6 +1369,12 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
     def set_input_embeddings(self, new_embeddings):
         self.embed_tokens = new_embeddings
 
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
     @add_start_docstrings_to_model_forward(PIX2STRUCT_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
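Exposing `lm_head` through `get_output_embeddings`/`set_output_embeddings` is what lets the generic embedding-resizing machinery in `PreTrainedModel` find and resize the untied output projection. Roughly, and as a simplified sketch rather than the exact library code, the resize path uses these hooks like this:

```python
import torch.nn as nn


def resize_untied_head(model, new_num_tokens: int) -> None:
    """Simplified sketch of how an untied lm_head is grown or shrunk during a resize."""
    old_head: nn.Linear = model.get_output_embeddings()
    old_num_tokens, hidden_size = old_head.weight.shape

    new_head = nn.Linear(hidden_size, new_num_tokens, bias=old_head.bias is not None)
    # Copy the overlapping rows so existing tokens keep their learned weights.
    num_to_copy = min(old_num_tokens, new_num_tokens)
    new_head.weight.data[:num_to_copy] = old_head.weight.data[:num_to_copy]
    if old_head.bias is not None:
        new_head.bias.data[:num_to_copy] = old_head.bias.data[:num_to_copy]

    model.set_output_embeddings(new_head)
```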
@@ -1626,12 +1632,25 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
         self.post_init()
 
     def get_input_embeddings(self):
-        return self.shared
+        return self.decoder.get_input_embeddings()
 
     def set_input_embeddings(self, new_embeddings):
-        self.shared = new_embeddings
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def get_output_embeddings(self) -> nn.Module:
+        return self.decoder.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        self.decoder.set_output_embeddings(new_embeddings)
+
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
+        model_embeds = self.decoder.resize_token_embeddings(new_num_tokens)
+
+        # update vocab size
+        self.config.text_config.vocab_size = new_num_tokens
+
+        return model_embeds
+
     def get_decoder(self):
         return self.decoder
...
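Because `Pix2StructConfig` keeps `vocab_size` on its nested `text_config` rather than at the top level, the model overrides `resize_token_embeddings` to delegate to the decoder and then sync `config.text_config.vocab_size` itself. A small sanity check of that wiring (my own example, assuming the delegation shown in the hunk above; random weights, no checkpoint download needed):

```python
from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration

model = Pix2StructForConditionalGeneration(Pix2StructConfig())

# Input/output embeddings are now delegated to the decoder (no top-level `shared` module).
assert model.get_input_embeddings() is model.decoder.get_input_embeddings()
assert model.get_output_embeddings() is model.decoder.lm_head

# Resizing flows through the decoder and updates the nested text config.
new_embeddings = model.resize_token_embeddings(50300)
assert new_embeddings.weight.shape[0] == 50300
assert model.config.text_config.vocab_size == 50300
```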
@@ -14,7 +14,7 @@
 # limitations under the License.
 """ Testing suite for the PyTorch Pix2Struct model. """
 
+import copy
 import inspect
 import os
 import tempfile
@@ -396,7 +396,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
     test_attention_outputs = False
     test_torchscript = False
@@ -526,6 +526,105 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )
 
+    # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
+    def test_resize_tokens_embeddings(self):
+        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.test_resize_embeddings:
+            return
+
+        for model_class in self.all_model_classes:
+            config = copy.deepcopy(original_config)
+            model = model_class(config)
+            model.to(torch_device)
+
+            if self.model_tester.is_training is False:
+                model.eval()
+
+            model_vocab_size = config.text_config.vocab_size
+            # Retrieve the embeddings and clone theme
+            model_embed = model.resize_token_embeddings(model_vocab_size)
+            cloned_embeddings = model_embed.weight.clone()
+
+            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
+            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
+            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
+            # Check that it actually resizes the embeddings matrix
+            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
+            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
+            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
+            # Check that it actually resizes the embeddings matrix
+            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
+
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            # Decoder input ids should be clamped to the maximum size of the vocabulary
+            if "decoder_input_ids" in inputs_dict:
+                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
+            models_equal = True
+            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
+                if p1.data.ne(p2.data).sum() > 0:
+                    models_equal = False
+
+            self.assertTrue(models_equal)
+
+    # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
+    def test_resize_embeddings_untied(self):
+        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.test_resize_embeddings:
+            return
+
+        original_config.tie_word_embeddings = False
+
+        # if model cannot untied embeddings -> leave test
+        if original_config.tie_word_embeddings:
+            return
+
+        for model_class in self.all_model_classes:
+            config = copy.deepcopy(original_config)
+            model = model_class(config).to(torch_device)
+
+            # if no output embeddings -> leave test
+            if model.get_output_embeddings() is None:
+                continue
+
+            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
+            model_vocab_size = config.text_config.vocab_size
+            model.resize_token_embeddings(model_vocab_size + 10)
+            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
+            output_embeds = model.get_output_embeddings()
+            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
+            # Check bias if present
+            if output_embeds.bias is not None:
+                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
+            model.resize_token_embeddings(model_vocab_size - 15)
+            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
+            # Check that it actually resizes the embeddings matrix
+            output_embeds = model.get_output_embeddings()
+            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
+            # Check bias if present
+            if output_embeds.bias is not None:
+                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
+
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            # Decoder input ids should be clamped to the maximum size of the vocabulary
+            if "decoder_input_ids" in inputs_dict:
+                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
+    @unittest.skip(reason="Pix2Struct doesn't use tied weights")
+    def test_tied_model_weights_key_ignore(self):
+        pass
+
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             return
...