"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c8b6052ff681e3ca8dab168dfd524b9fbbceb5bd"
Unverified Commit f5b0c1ec authored by Suraj Patil, committed by GitHub

[Flax] Fix hybrid clip (#12519)

* fix saving and loading

* update readme
parent 7d6285a9
@@ -68,6 +68,34 @@ export MODEL_DIR="./clip-roberta-base
ln -s ~/transformers/examples/flax/summarization/run_hybrid_clip.py run_hybrid_clip.py
```
## How to use the `FlaxHybridCLIP` model:
The `FlaxHybridCLIP` class lets you load any text and vision encoder model to create a dual encoder.
Here is an example of how to load the model using pre-trained text and vision models.
```python
from modeling_hybrid_clip import FlaxHybridCLIP
model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32")
# save the model
model.save_pretrained("bert-clip")
# load the saved model
model = FlaxHybridCLIP.from_pretrained("bert-clip")
```
If the checkpoints are in PyTorch, pass `text_from_pt=True` and `vision_from_pt=True`. This will load the PyTorch
checkpoints, convert them to Flax, and load the model.
```python
model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32", text_from_pt=True, vision_from_pt=True)
```
This loads both the text and vision encoders using pre-trained weights. The projection layers are randomly
initialized, except for CLIP's vision model: if you use CLIP to initialize the vision model, the vision projection
weights are also loaded from the pre-trained weights.
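Once loaded, the dual encoder can embed captions and images into the shared projection space. The snippet below is a minimal sketch rather than part of the official example: it assumes the model exposes `get_text_features`/`get_image_features` (mirroring `FlaxCLIPModel`), a local image file named `cat.jpg`, and channels-last pixel values; adjust these to your setup.
```python
import jax.numpy as jnp
from PIL import Image
from transformers import AutoTokenizer, CLIPFeatureExtractor

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# hypothetical caption and local image file
text_inputs = tokenizer(["a photo of a cat"], padding="max_length", max_length=64, return_tensors="np")
image_inputs = feature_extractor(images=Image.open("cat.jpg"), return_tensors="np")
pixel_values = jnp.transpose(image_inputs["pixel_values"], (0, 2, 3, 1))  # assumed channels-last input; drop if not needed

# `model` is the FlaxHybridCLIP instance created above
text_embeds = model.get_text_features(text_inputs["input_ids"], attention_mask=text_inputs["attention_mask"])
image_embeds = model.get_image_features(pixel_values)

# cosine similarity between the projected embeddings
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
print(jnp.matmul(text_embeds, image_embeds.T))
```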
## Prepare the dataset
We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search.
@@ -124,7 +152,7 @@ with open("coco_dataset/valid_dataset.json", "w") as f:
Next we can run the example script to train the model:
```bash
-python run_clip.py \
+python run_hybrid_clip.py \
--output_dir ${MODEL_DIR} \
--text_model_name_or_path="roberta-base" \
--vision_model_name_or_path="openai/clip-vit-base-patch32" \
......
@@ -25,31 +25,58 @@ class HybridCLIPConfig(PretrainedConfig):
Dimensionality of text and vision projection layers.
kwargs (`optional`):
Dictionary of keyword arguments.
Examples::
>>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
>>> # Initializing a BERT and CLIP configuration
>>> config_text = BertConfig()
>>> config_vision = CLIPConfig()
>>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
>>> # Initializing a FlaxHybridCLIP model from the BERT and CLIP configurations
>>> model = FlaxHybridCLIP(config=config)
>>> # Accessing the model configuration
>>> config_text = model.config.text_config
>>> config_vision = model.config.vision_config
>>> # Saving the model, including its configuration
>>> model.save_pretrained('my-model')
>>> # loading model and config from pretrained folder
>>> config = HybridCLIPConfig.from_pretrained('my-model')
>>> model = FlaxHybridCLIP.from_pretrained('my-model', config=config)
""" """
model_type = "hybrid-clip" model_type = "hybrid-clip"
is_composition = True is_composition = True
def __init__(self, text_config_dict, vision_config_dict, projection_dim=512, **kwargs): def __init__(self, projection_dim=512, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if text_config_dict is None: if "text_config" not in kwargs:
raise ValueError("`text_config_dict` can not be `None`.") raise ValueError("`text_config` can not be `None`.")
if "vision_config" not in kwargs:
raise ValueError("`vision_config` can not be `None`.")
if vision_config_dict is None: text_config = kwargs.pop("text_config")
raise ValueError("`vision_config_dict` can not be `None`.") vision_config = kwargs.pop("vision_config")
text_model_type = text_config_dict.pop("model_type") text_model_type = text_config.pop("model_type")
vision_model_type = vision_config_dict.pop("model_type") vision_model_type = vision_config.pop("model_type")
from transformers import AutoConfig from transformers import AutoConfig
self.text_config = AutoConfig.for_model(text_model_type, **text_config_dict) self.text_config = AutoConfig.for_model(text_model_type, **text_config)
if vision_model_type == "clip": if vision_model_type == "clip":
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict).vision_config self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
else: else:
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict) self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
self.projection_dim = projection_dim self.projection_dim = projection_dim
self.initializer_factor = 1.0 self.initializer_factor = 1.0
@@ -64,7 +91,7 @@ class HybridCLIPConfig(PretrainedConfig):
            :class:`HybridCLIPConfig`: An instance of a configuration object
        """
-        return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs)
+        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

    def to_dict(self):
        """
......
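The constructor change above matters for reloading: `PretrainedConfig.from_pretrained` passes every key of the saved `config.json` back to `__init__` as keyword arguments, so the nested configs arrive as `text_config`/`vision_config` kwargs rather than as the old positional `*_config_dict` arguments. Below is a minimal round-trip sketch, assuming the class lives in a local `configuration_hybrid_clip.py` and using a ViT vision config purely for illustration.
```python
from transformers import BertConfig, ViTConfig
from configuration_hybrid_clip import HybridCLIPConfig  # assumed local module name

config = HybridCLIPConfig.from_text_vision_configs(BertConfig(), ViTConfig(), projection_dim=512)
config.save_pretrained("hybrid-clip-config")  # writes config.json with nested text_config / vision_config

# on reload, the nested dicts come back through **kwargs, which the fixed __init__ now accepts
reloaded = HybridCLIPConfig.from_pretrained("hybrid-clip-config")
assert reloaded.projection_dim == 512
```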
@@ -123,7 +123,7 @@ class FlaxHybridCLIPModule(nn.Module):

class FlaxHybridCLIP(FlaxPreTrainedModel):
-    config: HybridCLIPConfig
+    config_class = HybridCLIPConfig
    module_class = FlaxHybridCLIPModule

    def __init__(
@@ -304,6 +304,58 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:
"""
Params:
text_model_name_or_path (:obj: `str`, `optional`):
Information necessary to initiate the text model. Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
Information necessary to initiate the vision model. Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
model_args (remaining positional arguments, `optional`):
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`).
- To update the text configuration, use the prefix `text_` for each configuration parameter.
- To update the vision configuration, use the prefix `vision_` for each configuration parameter.
- To update the parent model configuration, do not use a prefix for each configuration parameter.
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
Example::
>>> from transformers import FlaxHybridCLIP
>>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
>>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
>>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
>>> # saving model after fine-tuning
>>> model.save_pretrained("./bert-clip")
>>> # load fine-tuned model
>>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
"""
        kwargs_text = {
            argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
@@ -333,9 +385,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
                text_config = AutoConfig.from_pretrained(text_model_name_or_path)
                kwargs_text["config"] = text_config

-            text_model = FlaxAutoModel.from_pretrained(
-                text_model_name_or_path, *model_args, from_pt=True, **kwargs_text
-            )
+            text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)

        vision_model = kwargs_vision.pop("model", None)
        if vision_model is None:
......
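The two modeling changes above are what make saving and loading round-trip: `FlaxPreTrainedModel.from_pretrained` resolves the configuration through the `config_class` attribute when no config is passed, so replacing the bare `config:` annotation with `config_class = HybridCLIPConfig` lets a saved `FlaxHybridCLIP` checkpoint be reloaded directly; and dropping the hardcoded `from_pt=True` makes PyTorch conversion opt-in through the `text_from_pt`/`vision_from_pt` kwargs. A minimal sketch of the effect, with assumptions noted inline rather than reproducing library internals:
```python
from modeling_hybrid_clip import FlaxHybridCLIP  # local module from this example directory

# `config_class = HybridCLIPConfig` lets `from_pretrained` rebuild the composite config
# from a saved checkpoint directory without passing `config=` explicitly:
model = FlaxHybridCLIP.from_pretrained("bert-clip")  # directory from the README's save example

# `from_pt=True` is no longer forced for the text encoder, so PyTorch conversion is
# opt-in via the kwargs forwarded by the training script:
model = FlaxHybridCLIP.from_text_vision_pretrained(
    "roberta-base",
    "openai/clip-vit-base-patch32",
    text_from_pt=True,    # assuming the text checkpoint on the Hub is PyTorch-only
    vision_from_pt=True,  # likewise for the CLIP vision checkpoint
)
```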
@@ -87,6 +87,10 @@ class ModelArguments:
            "Don't set if you want to train a model from scratch."
        },
    )
+    from_pt: bool = field(
+        default=True,
+        metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
+    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
@@ -332,6 +336,8 @@ def main():
        model_args.vision_model_name_or_path,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
+        text_from_pt=model_args.from_pt,
+        vision_from_pt=model_args.from_pt,
    )
    config = model.config
    # set seed for torch dataloaders
......