Unverified Commit 23ab0b69 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[examples/flax] clip style image-text training example (#12491)

* clip style example

* fix post init

* add requirements

* update readme, few small fixes
parent 89a8739f
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Vision-Text dual encoder model training examples
> Note: This example is experimental and might not give the best possible results
The following example showcases how to train a CLIP like vision-text dual encoder model
using a pre-trained vision and text encoder using the JAX/Flax backend.
Such a model can be used for natural language image search and potentially zero-shot image classification.
The model is inspired by the [CLIP](https://openai.com/blog/clip/) approach, introduced by Alec Radford et al.
The idea is to train a vision encoder and a text encoder jointly to project the representation of images and their
captions into the same embedding space, such that the caption embeddings are located near the embeddings
of the images they describe.
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
In this example we will use the vision model from [CLIP](https://huggingface.co/models?filter=clip)
as the image encoder and [`roberta-base`](https://huggingface.co/roberta-base) as the text encoder.
Note that one can also use the [ViT](https://huggingface.co/models?filter=vit) model as image encoder and any other BERT or ROBERTa model as text encoder.
To train the model on languages other than English one should choose a text encoder trained on the desired
language and a image-text dataset in that language. One such dataset is [WIT](https://github.com/google-research-datasets/wit).
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"clip-roberta-base"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create clip-roberta-base
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/clip-roberta-base
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
```
cd clip-roberta-base
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_hybrid_clip.py`.
```bash
export MODEL_DIR="./clip-roberta-base
ln -s ~/transformers/examples/flax/summarization/run_hybrid_clip.py run_hybrid_clip.py
```
## Prepare the dataset
We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search.
### Download and extract the data.
It consists of two compressed folders: one with images, and the other—with associated image captions. Note that the compressed images folder is 13GB in size.
```bash
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
wget http://images.cocodataset.org/zips/train2014.zip
unzip annotations_trainval2014.zip
unzip train2014.zip
mkdir coco_dataset
mv train2014 coco_dataset/
mv annotations coco_dataset/
```
### Prepare dataset files and split the dataset.
```python
import json
import collections
images_dir = "coco_dataset/train2014"
annotation_file = "coco_dataset/annotations/captions_train2014.json"
with open(annotation_file, "r") as f:
annotations = json.load(f)["annotations"]
image_path_to_caption = collections.defaultdict(list)
for element in annotations:
caption = f"{element['caption'].lower().rstrip('.')}"
image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
image_path_to_caption[image_path].append(caption)
lines = []
for image_path, captions in image_path_to_caption.items():
lines.append(json.dumps({"image_path": image_path, "captions": captions}))
train_lines = lines[:-8000]
valid_line = lines[-8000:]
with open("coco_dataset/train_dataset.json", "w") as f:
f.write("\n".join(train_lines))
with open("coco_dataset/valid_dataset.json", "w") as f:
f.write("\n".join(valid_line))
```
> Note: The data loading and processing part of this script can still be improved for maximum performance. In particular one should decode the images beforehand and use those instead decoding them each time. If the dataset is small or if you have huge disk space the you could also pre-process all the dataset beforehand and then use it.
## Train the model
Next we can run the example script to train the model:
```bash
python run_clip.py \
--output_dir ${MODEL_DIR} \
--text_model_name_or_path="roberta-base" \
--vision_model_name_or_path="openai/clip-vit-base-patch32" \
--tokenizer_name="roberta-base" \
--train_file="coco_dataset/train_dataset.json" \
--validation_file="coco_dataset/validation_dataset.json" \
--do_train --do_eval \
--num_train_epochs="40" --max_seq_length 96 \
--per_device_train_batch_size="64" \
--per_device_eval_batch_size="64" \
--learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
--overwrite_output_dir \
--preprocessing_num_workers 32 \
--push_to_hub
```
This should finish in ~1h50 mins with min validation loss 2.43. Training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experiment/RUNPYd1yRgSD5kZSb9hDig/#scalars)
\ No newline at end of file
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class HybridCLIPConfig(PretrainedConfig):
r"""
:class:`HybridCLIPConfig` is the configuration class to store the configuration of a
:class:`~HybridCLIPModel`. It is used to instantiate HybridCLIPModel model according to the specified arguments,
defining the text model and vision model configs.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
text_config_dict (:obj:`dict`):
Dictionary of configuration options that defines text model config.
vision_config_dict (:obj:`dict`):
Dictionary of configuration options that defines vison model config.
projection_dim (:obj:`int`, `optional`, defaults to 512):
Dimentionality of text and vision projection layers.
kwargs (`optional`):
Dictionary of keyword arguments.
"""
model_type = "hybrid-clip"
is_composition = True
def __init__(self, text_config_dict, vision_config_dict, projection_dim=512, **kwargs):
super().__init__(**kwargs)
if text_config_dict is None:
raise ValueError("`text_config_dict` can not be `None`.")
if vision_config_dict is None:
raise ValueError("`vision_config_dict` can not be `None`.")
text_model_type = text_config_dict.pop("model_type")
vision_model_type = vision_config_dict.pop("model_type")
from transformers import AutoConfig
self.text_config = AutoConfig.for_model(text_model_type, **text_config_dict)
if vision_model_type == "clip":
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict).vision_config
else:
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict)
self.projection_dim = projection_dim
self.initializer_factor = 1.0
@classmethod
def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
r"""
Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
vision model configuration.
Returns:
:class:`HybridCLIPConfig`: An instance of a configuration object
"""
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default
:meth:`~transformers.PretrainedConfig.to_dict`.
Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from configuration_hybrid_clip import HybridCLIPConfig
from flax.core.frozen_dict import FrozenDict
from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
from transformers.utils import logging
logger = logging.get_logger(__name__)
class FlaxHybridCLIPModule(nn.Module):
config: HybridCLIPConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
text_config = self.config.text_config
vision_config = self.config.vision_config
self.projection_dim = self.config.projection_dim
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
self.text_model = text_module(text_config, dtype=self.dtype)
self.vision_model = vision_module(vision_config, dtype=self.dtype)
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
use_bias=False,
)
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
def __call__(
self,
input_ids=None,
pixel_values=None,
attention_mask=None,
position_ids=None,
token_type_ids=None,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)
# normalized features
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
# cosine similarity as logits
logit_scale = jnp.exp(self.logit_scale)
logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
logits_per_image = logits_per_text.T
if not return_dict:
return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return FlaxCLIPOutput(
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
class FlaxHybridCLIP(FlaxPreTrainedModel):
config: HybridCLIPConfig
module_class = FlaxHybridCLIPModule
def __init__(
self,
config: HybridCLIPConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
if input_shape is None:
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensor
input_ids = jnp.zeros(input_shape[0], dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
token_type_ids = jnp.ones_like(input_ids)
attention_mask = jnp.ones_like(input_ids)
pixel_values = jax.random.normal(rng, input_shape[1])
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
def __call__(
self,
input_ids,
pixel_values,
attention_mask=None,
position_ids=None,
token_type_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(pixel_values, dtype=jnp.float32),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
def get_text_features(
self,
input_ids,
attention_mask=None,
position_ids=None,
token_type_ids=None,
dropout_rng: jax.random.PRNGKey = None,
train=False,
):
r"""
Args:
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
Returns:
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
obtained by applying the projection layer to the pooled output of text model.
"""
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
text_outputs = module.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
token_type_ids=token_type_ids,
deterministic=deterministic,
)
pooled_output = text_outputs[1]
text_features = module.text_projection(pooled_output)
return text_features
return self.module.apply(
{"params": self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
not train,
method=_get_features,
rngs=rngs,
)
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
r"""
Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
using :class:`~transformers.ImageFeatureExtractionMixin`. See
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
Returns:
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
obtained by applying the projection layer to the pooled output of vision model.
"""
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _get_features(module, pixel_values, deterministic):
vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
pooled_output = vision_outputs[1] # pooled_output
image_features = module.visual_projection(pooled_output)
return image_features
return self.module.apply(
{"params": self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
method=_get_features,
rngs=rngs,
)
@classmethod
def from_text_vision_pretrained(
cls,
text_model_name_or_path: str = None,
vision_model_name_or_path: str = None,
*model_args,
**kwargs,
) -> FlaxPreTrainedModel:
kwargs_text = {
argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
}
kwargs_vision = {
argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
}
# remove text, vision kwargs from kwargs
for key in kwargs_text.keys():
del kwargs["text_" + key]
for key in kwargs_vision.keys():
del kwargs["vision_" + key]
# Load and initialize the text and vision model
text_model = kwargs_text.pop("model", None)
if text_model is None:
assert (
text_model_name_or_path is not None
), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
from transformers import FlaxAutoModel
if "config" not in kwargs_text:
from transformers import AutoConfig
text_config = AutoConfig.from_pretrained(text_model_name_or_path)
kwargs_text["config"] = text_config
text_model = FlaxAutoModel.from_pretrained(
text_model_name_or_path, *model_args, from_pt=True, **kwargs_text
)
vision_model = kwargs_vision.pop("model", None)
if vision_model is None:
assert (
vision_model_name_or_path is not None
), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
from transformers import FlaxAutoModel
if "config" not in kwargs_vision:
from transformers import AutoConfig
vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
kwargs_vision["config"] = vision_config
vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
# instantiate config with corresponding kwargs
dtype = kwargs.pop("dtype", jnp.float32)
config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
# init model
model = cls(config, *model_args, dtype=dtype, **kwargs)
if vision_config.model_type == "clip":
model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
else:
model.params["vision_model"] = vision_model.params
model.params["text_model"] = text_model.params
return model
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.4
optax>=0.0.8
-f https://download.pytorch.org/whl/torch_stable.html
torch==1.9.0+cpu
-f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.10.0+cpu
\ No newline at end of file
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training a CLIP like dual encoder models using text and vision encoders in the library.
The script can be used to train CLIP like models for languages other than english by using
a text encoder pre-trained in the desired language. Currently this script support the following vision
and text models:
Vision models: ViT(https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
Text models: BERT, ROBERTa (https://huggingface.co/models?filter=masked-lm)
"""
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
import torch
from torchvision.datasets import VisionDataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, shard, shard_prng_key
from modeling_hybrid_clip import FlaxHybridCLIP
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
logger = logging.getLogger(__name__)
# Cache the result
has_tensorboard = is_tensorboard_available()
if has_tensorboard:
try:
from flax.metrics.tensorboard import SummaryWriter
except ImportError as ie:
has_tensorboard = False
print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
else:
print(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
text_model_name_or_path: str = field(
metadata={
"help": "The text model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
},
)
vision_model_name_or_path: str = field(
metadata={
"help": "The vision model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
)
max_seq_length: Optional[int] = field(
default=72,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
def __post_init__(self):
if self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension == "json", "`train_file` should be a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension == "json", "`validation_file` should be a json file."
# We use torchvision for faster image pre-processing.
# We need to ensure faster processing speed as it can become a bottleneck on TPU
class Transform(torch.nn.Module):
def __init__(self, image_size):
super().__init__()
self.transforms = torch.nn.Sequential(
Resize([image_size], interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
ConvertImageDtype(torch.float),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
x = self.transforms(x)
return x
class ImageTextDataset(VisionDataset):
"""
Dtaset for loading image-text data for tasks like CLIP training, Image Captioning.
Args:
root: (string): The root path where the dataset is stored
file_path: (string): Path to the file containing the image_paths and associated captions.
The expected format is jsonlines where each line is a json object containing to keys.
`image_path`: The path to the image.
`captions`: An `array` of captions.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry
and returns a transformed version.
"""
def __init__(
self,
root: str,
file_path: str,
captions_per_image=2,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super().__init__(root, transforms, transform, target_transform)
with open(file_path, "r") as f:
examples = [json.loads(line) for line in f.readlines()]
self.captions = []
self.image_paths = []
for example in examples:
self.captions.extend(example["captions"][:captions_per_image])
self.image_paths.extend([example["image_path"]] * captions_per_image)
def _load_image(self, idx: int):
path = self.image_paths[idx]
return read_image(path, mode=ImageReadMode.RGB)
def _load_target(self, idx):
return self.captions[idx]
def __getitem__(self, index: int):
image = self._load_image(index)
target = self._load_target(index)
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
def __len__(self) -> int:
return len(self.captions)
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty."
"Use --overwrite_output_dir to overcome."
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
transformers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info(f"Training/evaluation parameters {training_args}")
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
)
elif model_args.text_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
model = FlaxHybridCLIP.from_text_vision_pretrained(
model_args.text_model_name_or_path,
model_args.vision_model_name_or_path,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
)
config = model.config
# set seed for torch dataloaders
set_seed(training_args.seed)
# Initialize torchvision transforms and jit them for faster processing
preprocess = Transform(config.vision_config.image_size)
preprocess = torch.jit.script(preprocess)
# Initialize the image-text dataset
train_dataset = ImageTextDataset(
data_args.data_dir,
data_args.train_file,
captions_per_image=2,
transform=preprocess,
)
eval_dataset = ImageTextDataset(
data_args.data_dir,
data_args.validation_file,
captions_per_image=1,
transform=preprocess,
)
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
# Use collate function to tokenizer the text and convert the processed images to numpy
def collate_fn(examples):
pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
captions = [example[1] for example in examples]
inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
batch = {
"pixel_values": pixel_values,
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
}
return batch
# Create data loaders
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=True,
num_workers=data_args.preprocessing_num_workers,
persistent_workers=True,
drop_last=True,
collate_fn=collate_fn,
)
eval_loader = torch.utils.data.DataLoader(
eval_dataset,
batch_size=eval_batch_size,
shuffle=False,
num_workers=data_args.preprocessing_num_workers,
persistent_workers=True,
drop_last=True,
collate_fn=collate_fn,
)
# Enable tensorboard only on the master node
if has_tensorboard and jax.process_index() == 0:
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
rng, dropout_rng = jax.random.split(rng)
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
len(train_dataset),
train_batch_size,
training_args.num_train_epochs,
training_args.warmup_steps,
training_args.learning_rate,
)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
def cross_entropy(logits, axis):
logprobs = jax.nn.log_softmax(logits, axis=axis)
nll = jnp.diag(logprobs)
ce = -jnp.mean(nll)
return ce
def clip_loss(similarity):
loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
return loss
# Define gradient update step fn
def train_step(state, batch):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = clip_loss(logits)
return loss
grad_fn = jax.value_and_grad(compute_loss)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return new_state, metrics
# Define eval fn
def eval_step(params, batch):
logits = model(**batch, params=params, train=False)[0]
loss = clip_loss(logits)
# summarize metrics
metrics = {"loss": loss}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return metrics
# Create parallel version of the train and eval step
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
p_eval_step = jax.pmap(eval_step, "batch")
# Replicate the train state on each device
state = state.replicate()
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_epochs}")
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
logger.info(f" Total optimization steps = {total_train_steps}")
train_time = 0
# Create sampling rng
rng, input_rng = jax.random.split(rng)
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
# Create sampling rng
rng, input_rng = jax.random.split(rng)
train_metrics = []
steps_per_epoch = len(train_dataset) // train_batch_size
train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
# train
for batch in train_loader:
batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
train_step_progress_bar.update(1)
train_time += time.time() - train_start
train_metric = unreplicate(train_metric)
train_step_progress_bar.close()
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
eval_metrics = []
eval_steps = len(eval_dataset) // eval_batch_size
eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
for batch in eval_loader:
# Model forward
batch = shard(batch)
metrics = p_eval_step(state.params, batch)
eval_metrics.append(metrics)
eval_step_progress_bar.update(1)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
# Print metrics and update progress bar
eval_step_progress_bar.close()
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
epochs.write(desc)
epochs.desc = desc
# Save metrics
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(train_dataset) // train_batch_size)
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(unreplicate(state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}",
)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment