Commit c501623c authored by chenych

add vlmo

parent 4538607b
(The following contents are from [the ViLT repo](https://github.com/dandelin/ViLT/blob/master/DATA.md).)
# Dataset Preparation
We utilize seven datasets: Google Conceptual Captions (GCC), Stony Brook University Captions (SBU), Visual Genome (VG), COCO Captions (COCO), Flickr 30K Captions (F30K), Visual Question Answering v2 (VQAv2), and Natural Language for Visual Reasoning 2 (NLVR2).
We do not distribute the datasets because of licensing issues; please download them yourself.
We use `pyarrow` to serialize the datasets; the conversion scripts are located in `vilt/utils/write_*.py`.
Please organize each dataset as shown below and run the corresponding `make_arrow` function to convert it into pyarrow binary files.
## GCC
https://ai.google.com/research/ConceptualCaptions/download
GCC provides tuples of image URL and caption; note that a sizable portion of the URLs are no longer accessible.
Write your own download script and organize the dataset in the following structure (a minimal annotation-building sketch follows the layout).
```
root
├── images_train
│   ├── 0000              # First four letters of image name
│   │   ├── 0000000       # Image Binary
│   │   ├── 0000001
│   │   └── ...
│   ├── 0001
│   │   ├── 0001000
│   │   ├── 0001001
│   │   └── ...
│   └── ...
├── images_val
│   ├── 0000
│   │   └── ...
│   └── ...
├── train_annot.json      # List of (image_file_path, caption) tuple
└── val_annot.json        # List of (image_file_path, caption) tuple
```
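For reference, here is a minimal sketch of how the annotation files could be built once the images are downloaded. The TSV path, the zero-padded-index file naming, and the skipping of failed downloads are assumptions about your own download script, not part of the official tooling.
```python
import csv
import json
import os

root = "/path/to/gcc"                 # dataset root from the layout above (assumed path)
tsv_path = "/path/to/gcc_train.tsv"   # the GCC TSV of (caption, url) rows (assumed path)

with open(tsv_path) as f:
    captions = [row[0] for row in csv.reader(f, delimiter="\t")]

annot = []
for idx, caption in enumerate(captions):
    name = f"{idx:07d}"                                   # e.g. 0000123
    rel_path = os.path.join("images_train", name[:4], name)
    if os.path.exists(os.path.join(root, rel_path)):      # keep only successfully downloaded images
        annot.append([rel_path, caption])

with open(os.path.join(root, "train_annot.json"), "w") as f:
    json.dump(annot, f)
```
Then run the `make_arrow` conversion: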
```python
from vlmo.utils.write_conceptual_caption import make_arrow
make_arrow(root, arrows_root)
```
## SBU
http://www.cs.virginia.edu/~vicente/sbucaptions/
Similar to GCC, SBU provides tuples of image URL and caption, and likewise a sizable portion of the URLs are no longer accessible.
Write your own download script (the GCC sketch above applies here as well) and organize the dataset in the following structure.
```
root
├── images_train
│   ├── 0000              # First four letters of image name
│   │   ├── 0000000       # Image Binary
│   │   ├── 0000001
│   │   └── ...
│   ├── 0001
│   │   ├── 0001000
│   │   ├── 0001001
│   │   └── ...
│   └── ...
└── annot.json            # List of (image_file_path, caption) tuple
```
```python
from vlmo.utils.write_sbu import make_arrow
make_arrow(root, arrows_root)
```
## VG
http://visualgenome.org/api/v0/api_home.html
Download [image part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [image part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip) and [region descriptions](http://visualgenome.org/static/data/dataset/region_descriptions.json.zip), then organize them in the following structure.
```
root
├── images
│   ├── VG_100K
│   │   ├── 10.jpg
│   │   ├── 107899.jpg
│   │   └── ...
│   ├── VG_100K_2
│   │   ├── 1.jpg
│   │   ├── 100.jpg
│   │   └── ...
│   └── ...
└── annotations
    └── region_descriptions.json
```
```python
from vlmo.utils.write_vg import make_arrow
make_arrow(root, arrows_root)
```
## COCO
https://cocodataset.org/#download
Download [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip) and the [karpathy split](https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip), then organize them in the following structure.
```
root
├── train2014
│   ├── COCO_train2014_000000000009.jpg
│   └── ...
├── val2014
│   ├── COCO_val2014_000000000042.jpg
│   └── ...
└── karpathy
    └── dataset_coco.json
```
```python
from vlmo.utils.write_coco_karpathy import make_arrow
make_arrow(root, arrows_root)
```
## F30K
http://bryanplummer.com/Flickr30kEntities/
Sign the [flickr images request form](https://forms.illinois.edu/sec/229675) and download the [karpathy split](https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip), then organize them in the following structure.
```
root
├── flickr30k-images
│   ├── 1000092795.jpg
│   └── ...
└── karpathy
    └── dataset_flickr30k.json
```
```python
from vlmo.utils.write_f30k_karpathy import make_arrow
make_arrow(root, arrows_root)
```
## VQAv2
https://visualqa.org/download.html
Download COCO [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip), [2015 test images](http://images.cocodataset.org/zips/test2015.zip), annotations ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip)), and questions ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip), [test](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip)), then organize them in the following structure.
```
root
├── train2014
│   ├── COCO_train2014_000000000009.jpg
│   └── ...
├── val2014
│   ├── COCO_val2014_000000000042.jpg
│   └── ...
├── test2015
│   ├── COCO_test2015_000000000001.jpg
│   └── ...
├── v2_OpenEnded_mscoco_train2014_questions.json
├── v2_OpenEnded_mscoco_val2014_questions.json
├── v2_OpenEnded_mscoco_test2015_questions.json
├── v2_OpenEnded_mscoco_test-dev2015_questions.json
├── v2_mscoco_train2014_annotations.json
└── v2_mscoco_val2014_annotations.json
```
```python
from vlmo.utils.write_vqa import make_arrow
make_arrow(root, arrows_root)
```
## NLVR2
Clone the [repository](https://github.com/lil-lab/nlvr) and sign the [request form](https://goo.gl/forms/yS29stWnFWzrDBFH3) to download the images.
```
root
├── images/train
│   ├── 0
│   │   ├── train-10108-0-img0.png
│   │   └── ...
│   ├── 1
│   │   ├── train-10056-0-img0.png
│   │   └── ...
│   └── ...
├── dev
│   ├── dev-0-0-img0.png
│   └── ...
├── test1
│   ├── test1-0-0-img0.png
│   └── ...
├── nlvr
├── nlvr2
└── README.md
```
```python
from vlmo.utils.write_nlvr2 import make_arrow
make_arrow(root, arrows_root)
```
## WikiBK (Text-only Data)
WikiBK is the text-only corpus (English Wikipedia and BookCorpus) used for the language pre-training stage.
```python
from vlmo.utils.write_wikibk import make_arrow
make_arrow(root, arrows_root)
```
# VLMo - General-purpose Multimodal Pre-training
Paper: [VLMo: Unified Vision-Language Pre-Training with Mixture-of-Modality-Experts](https://arxiv.org/abs/2111.02358).
Official PyTorch implementation and pre-trained models of VLMo.
- Dec, 2022: Code & model release.
- Sep, 2022: [**VLMo**](https://arxiv.org/pdf/2111.02358.pdf) was accepted by NeurIPS 2022.
- May 30th, 2022: new version of [**VLMo** paper on arXiv](https://arxiv.org/pdf/2111.02358.pdf).
- November 24th, 2021: **VLMo** Large (**single** model) became the new SOTA on the [VQA Challenge](https://eval.ai/web/challenges/challenge-page/830/leaderboard/2278).
- November 2021: released the preprint on [arXiv](https://arxiv.org/abs/2111.02358).
## Pre-trained Models
We provide three VLMo weights pre-trained on COCO, VG, SBU, and GCC. The models were pre-trained at 224x224 resolution. A quick way to sanity-check a downloaded checkpoint is shown after the list.
- [`VLMo-base`](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16; #VL_FFN=2 (#parameters: 175M)
- [`VLMo-base_plus`](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=24; hidden=544; FFN factor=4x; #head=16; patch=16x16; #VL_FFN=3 (#parameters: 167M)
- [`VLMo-large`](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=24; hidden=1024; FFN factor=4x; #head=16; patch=16x16; #VL_FFN=3 (#parameters: 562M)
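After downloading, you can load a checkpoint on CPU and count its parameters (a minimal sketch; the released files may or may not wrap the tensors under a `state_dict` key, so the snippet handles both cases):
```python
import torch

ckpt = torch.load("vlmo_base_patch16_224.pt", map_location="cpu")
# Unwrap if the weights are stored under a "state_dict" key (assumption; adjust if your file differs).
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
num_params = sum(v.numel() for v in state_dict.values() if torch.is_tensor(v))
print(f"{len(state_dict)} tensors, ~{num_params / 1e6:.0f}M parameters")
```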
## Setup
You can run the experiments in a PyTorch Docker container:
```bash
alias=`whoami | cut -d'.' -f2`; docker run -it --rm --runtime=nvidia --ipc=host --privileged -v /home/${alias}:/home/${alias} pytorch/pytorch:1.8.0-cuda11.1-cudnn8-devel bash
```
Clone the repo and install the required packages:
```
git clone https://github.com/microsoft/unilm.git
cd unilm/vlmo
pip install -r requirements.txt
```
## Dataset Preparation
We process the pre-training and fine-tuning data into the same format as [ViLT](DATA.md).
## Pre-training
Replace `<ARROW_ROOT>` with your data directory in the following commands.
### Step 1: Vision Pre-Training
Download the pre-trained vision model weights from the [BEiT repo](https://github.com/microsoft/unilm/tree/master/beit).
### Step 2: Language Pre-Training (VLMo-Base)
```bash
# download from https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D
export INIT_CKPT=/path/to/save/beit_base_checkpoint
python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_textmlm_base whole_word_masking=True step200k per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path=$INIT_CKPT log_dir=<YOUR_OUTPUT_PATH>
```
Alternatively, you can download our pre-trained checkpoints for this stage:
- [`VLMo-base-stage2`](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_224_stage2.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)
- [`VLMo-base_plus-stage2`](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_224_stage2.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)
- [`VLMo-large-stage2`](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_224_stage2.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)
### Step 3: Vision-Language Pre-Training (VLMo-Base)
```bash
export INIT_CKPT=/path/to/save/last_stage_ckpt
python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_mlm_itm_itc_base whole_word_masking=True step200k per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path=$INIT_CKPT log_dir=<YOUR_OUTPUT_PATH>
```
## Fine-Tuning on Downstream Tasks
## Commands
```bash
python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> "<CONFIG_NAME>" per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="<VLMo_WEIGHT>" log_dir=<YOUR_OUTPUT_PATH>
```
To reduce GPU memory usage, use [DeepSpeed](https://pytorch-lightning.readthedocs.io/en/stable/advanced/model_parallel.html#deepspeed-zero-stage-1) and [activation checkpointing](https://fairscale.readthedocs.io/en/stable/api/nn/checkpoint/checkpoint_activations.html); a minimal activation-checkpointing sketch is shown below.
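As a rough illustration of activation checkpointing with the `fairscale` package already listed in `requirements.txt` (this toy example wraps plain linear layers, not the actual VLMo blocks, and only sketches the mechanism):
```python
import torch
from fairscale.nn import checkpoint_wrapper

# Wrap sub-modules so their activations are recomputed during the backward pass
# instead of being stored, trading extra compute for lower memory.
blocks = torch.nn.ModuleList(
    [checkpoint_wrapper(torch.nn.Linear(768, 768)) for _ in range(12)]
)
x = torch.randn(2, 768, requires_grad=True)
for block in blocks:
    x = block(x)
x.sum().backward()  # inner activations are recomputed here rather than read from memory
```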
## Configs
You can find the `<CONFIG_NAME>` for each task in the tables below:
### VQAv2
| <CONFIG_NAME> | initialized checkpoint | finetuned weight | test-dev |
|---------------|:----------------------:|:----------------:|:-----------:|
|task_finetune_vqa_base_image480|[VLMo-base](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_480_vqa.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|76.6|
|task_finetune_vqa_base_plus_image480|[VLMo-base_plus](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_480_vqa.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|78.5|
|task_finetune_vqa_large_image480|[VLMo-large](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_480_vqa.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|79.9|
### NLVR2
| <CONFIG_NAME> | initialized checkpoint | finetuned weight | test-P |
|---------------|:----------------------:|:----------------:|:-----------:|
|task_finetune_nlvr2_base_image384|[VLMo-base](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_384_nlvr2.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|83.3|
|task_finetune_nlvr2_base_plus_image384|[VLMo-base_plus](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_384_nlvr2.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|85.1|
|task_finetune_nlvr2_large_image384|[VLMo-large](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_384_nlvr2.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|86.9|
### COCO
| <CONFIG_NAME> | initialized checkpoint | finetuned weight | TR@1 | IR@1 |
|---------------|:----------------------:|:----------------:|:-----------:|:---:|
|task_finetune_irtr_coco_base_image384|[VLMo-base](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_384_coco.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|74.8|57.2|
|task_finetune_irtr_coco_base_plus_image384|[VLMo-base_plus](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_384_coco.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|76.3|58.6|
|task_finetune_irtr_coco_large_image384|[VLMo-large](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_384_coco.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|78.2|60.6|
### F30K
| <CONFIG_NAME> | initialized checkpoint | finetuned weight | TR@1 | IR@1 |
|---------------|:----------------------:|:----------------:|:-----------:|:---:|
|task_finetune_irtr_f30k_base_image384|[VLMo-base_coco_finetuned](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_384_coco.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_patch16_384_f30k.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|92.3|79.3|
|task_finetune_irtr_f30k_base_plus_image384|[VLMo-base_plus](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_224.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_base_plus_patch16_384_f30k.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|93.2|81.8|
|task_finetune_irtr_f30k_large_image384|[VLMo-large_coco_finetuned](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_384_coco.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|[weight](https://conversationhub.blob.core.windows.net/beit-share-public/vlmo/vlmo_large_patch16_384_f30k.pt?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D)|95.3|84.5|
## Evaluation
To evaluate a fine-tuned model, append `test_only=True` and set `load_path` to the fine-tuned VLMo weight, as follows:
```bash
python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=1 "<CONFIG_NAME>" per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="<Finetuned_VLMo_WEIGHT>" test_only=True
```
- For retrieval tasks, also set `get_recall_metric=True` in the command.
## Acknowledgement
This repository is built using the [ViLT](https://github.com/dandelin/ViLT) repository, [BEiT](https://github.com/microsoft/unilm/tree/master/beit) repository, [ALBEF](https://github.com/salesforce/ALBEF) and the [timm](https://github.com/rwightman/pytorch-image-models) library.
## Citation
If you find this repository useful, please consider citing our work:
```
@inproceedings{vlmo,
title={{VLMo}: Unified Vision-Language Pre-Training with Mixture-of-Modality-Experts},
author={Hangbo Bao and Wenhui Wang and Li Dong and Qiang Liu and Owais Khan Mohammed and Kriti Aggarwal and Subhojit Som and Songhao Piao and Furu Wei},
booktitle={Advances in Neural Information Processing Systems},
year={2022},
url={https://openreview.net/forum?id=bydKs84JEyw}
}
```
## License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
### Contact Information
For help or issues using VLMo models, please submit a GitHub issue.
# gcc
# from vlmo.utils.write_conceptual_caption import make_arrow
# root = "/parastor/home/chenych/MML/data/gcc"
# arrows_root = "/parastor/home/chenych/MML/data/gcc/arrows_root"
# nlvr2
# from vlmo.utils.write_nlvr2 import make_arrow
# root = "/home/data/nlvr-master/"
# arrows_root = "/home/data/nlvr-master/images/arrows_root"
# coco
from vlmo.utils.write_coco_karpathy import make_arrow
root = "/home/data/coco2014/"
arrows_root = "/home/data/coco2014/"
if __name__ == "__main__":
make_arrow(root, arrows_root)
pytorch_lightning==1.5.5
transformers==4.8.1
Pillow==8.3.1
tqdm==4.53.0
ipdb==0.13.7
einops==0.3.0
pyarrow==2.0.0
sacred==0.8.2
pandas==1.1.5
timm==0.4.12
torchmetrics==0.7.3
torch==1.8.0
torchvision==0.9.0
fairscale==0.4.0
numpy
scipy
opencv-python
import os
import copy
import pytorch_lightning as pl
from vlmo.config import ex
from vlmo.modules import VLMo
from vlmo.datamodules.multitask_datamodule import MTDataModule
from pytorch_lightning.plugins import environments as pl_env
from pytorch_lightning.utilities.distributed import rank_zero_info
class OMPIClusterEnvironment(pl_env.ClusterEnvironment):
def __init__(self):
super().__init__()
# def creates_children(self) -> bool:
# # return True if the cluster is managed (you don't launch processes yourself)
# assert (
# "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ
# ) # this cluster is managed
# return True
@property
def creates_processes_externally(self):
return True
def world_size(self) -> int:
return int(os.environ["OMPI_COMM_WORLD_SIZE"])
def set_world_size(self, size: int):
pass
def global_rank(self) -> int:
return int(os.environ["OMPI_COMM_WORLD_RANK"])
def set_global_rank(self, rank: int):
pass
def local_rank(self) -> int:
return int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
def node_rank(self) -> int:
if "NODE_RANK" in os.environ:
return int(os.environ["NODE_RANK"])
else:
return 0
def master_address(self) -> str:
return os.environ["MASTER_ADDR"]
def master_port(self) -> int:
return int(os.environ["MASTER_PORT"])
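# Choose the cluster environment: use the OpenMPI-managed environment when launched via
# mpirun (multi-node, or OMPI_* variables present); otherwise fall back to Lightning's own.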
def get_cluster_plugin(num_gpus=1, num_nodes=1):
if num_nodes > 1 or (
num_nodes == 1 and "OMPI_COMM_WORLD_SIZE" in os.environ
):
rank_zero_info("ClusterPlugin: using OMPI Cluster Environment")
return OMPIClusterEnvironment()
if num_gpus >= 1:
rank_zero_info("ClusterPlugin: using Lightning Cluster Environment")
return pl_env.LightningEnvironment()
return None
@ex.automain
def main(_config):
_config = copy.deepcopy(_config)
pl.seed_everything(_config["seed"])
dm = MTDataModule(_config, dist=True)
model = VLMo(_config)
exp_name = f'{_config["exp_name"]}'
os.makedirs(_config["log_dir"], exist_ok=True)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
save_top_k=-1,
verbose=True,
monitor="val/the_metric",
mode="max",
save_last=True,
)
logger = pl.loggers.TensorBoardLogger(
_config["log_dir"],
name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
)
lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
callbacks = [checkpoint_callback, lr_callback]
num_gpus = (
_config["num_gpus"]
if isinstance(_config["num_gpus"], int)
else len(_config["num_gpus"])
)
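    # Gradient accumulation steps so that the effective batch size matches _config["batch_size"].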
grad_steps = _config["batch_size"] // (
_config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
)
rank_zero_info("grad_steps: {}".format(grad_steps))
max_steps = _config["max_steps"] if _config["max_steps"] is not None else None
resume_ckpt = None
if _config["resume_during_training"]:
for index in range(100):
ckpt_path = os.path.join(_config["log_dir"], f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}', "version_{}/checkpoints/last.ckpt".format(index))
if os.path.exists(ckpt_path):
resume_ckpt = ckpt_path
rank_zero_info("resume_ckpt: {}".format(resume_ckpt))
cluster_plugin = get_cluster_plugin(
_config["num_gpus"], _config["num_nodes"]
)
plugin_list = [cluster_plugin]
rank_zero_info("plugin_list: {}".format(plugin_list))
if _config["use_sharded_training"]:
rank_zero_info("Using ddp sharded")
distributed_strategy = "ddp_sharded"
else:
distributed_strategy = "ddp"
trainer = pl.Trainer(
gpus=_config["num_gpus"],
num_nodes=_config["num_nodes"],
precision=_config["precision"],
accelerator="gpu",
strategy=distributed_strategy,
benchmark=True,
deterministic=True,
max_epochs=_config["max_epoch"] if max_steps is None else 1000,
max_steps=max_steps,
callbacks=callbacks,
logger=logger,
# prepare_data_per_node=False,
replace_sampler_ddp=False,
accumulate_grad_batches=grad_steps,
log_every_n_steps=10,
flush_logs_every_n_steps=10,
resume_from_checkpoint=resume_ckpt,
weights_summary="top",
fast_dev_run=_config["fast_dev_run"],
val_check_interval=_config["val_check_interval"],
plugins=plugin_list,
)
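    # For the text-only MLM stage (language pre-training, Step 2 in the README), freeze all
    # parameters and then unfreeze only the text-related modules listed below.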
if _config["loss_names"]["textmlm"] > 0:
for param in model.parameters():
param.requires_grad = False
for name, param in model.named_parameters():
for key in ["text_embeddings", "token_type_embeddings", "mlp_text", "norm2_text", "mlm_score", "relative_position_bias_table", "transformer.norm"]:
if key in name:
param.requires_grad = True
for name, param in model.named_parameters():
rank_zero_info("{}\t{}".format(name, param.requires_grad))
if not _config["test_only"]:
trainer.fit(model, datamodule=dm)
else:
trainer.test(model, datamodule=dm)
from setuptools import setup, find_packages
setup(
name="vlmo",
packages=find_packages(
exclude=[".dfc", ".vscode", "dataset", "notebooks", "result", "scripts"]
),
version="1.0.0",
license="MIT",
description="VLMo: Unified Vision-Language Pre-Training with Mixture-of-Modality-Experts",
author="Wenhui Wang",
author_email="wenwan@microsoft.com",
url="https://github.com/microsoft/unilm/tree/master/vlmo",
keywords=["vision and language pretraining"],
install_requires=["torch", "pytorch_lightning"],
)
#!/bin/bash
export HSA_FORCE_FINE_GRAIN_PCIE=1
export USE_MIOPEN_BATCHNORM=1
export INIT_CKPT=./pretrained_models/vlmo_base_patch16_224.pt
python run.py with data_root=/public/home/chenych/datasets/coco2014/ num_gpus=4 num_nodes=1 "task_finetune_irtr_coco_base_image384" per_gpu_batchsize=64 load_path=$INIT_CKPT log_dir=./logs/
# python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> "<TASK>" per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="<VLMo_WEIGHT>" log_dir=<YOUR_OUTPUT_PATH>
#!/bin/bash
export HSA_FORCE_FINE_GRAIN_PCIE=1
export USE_MIOPEN_BATCHNORM=1
python run.py with data_root=/public/home/chenych/datasets/coco2014/ num_gpus=1 num_nodes=1 "task_finetune_irtr_coco_base_image384" per_gpu_batchsize=64 load_path="./pretrained_models/vlmo_base_patch16_384_coco.pt" test_only=True
# python run.py with data_root=<ARROW_ROOT> num_gpus=1 num_nodes=1 "<CONFIG_NAME>" per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="<Finetuned_VLMo_WEIGHT>" test_only=True
from sacred import Experiment
ex = Experiment("VLMo")
def _loss_names(d):
ret = {
"itm": 0, # image-text matching loss
"itc": 0, # image-text contrastive loss
"mlm": 0, # masked language modeling loss
"textmlm": 0, # text-only masked language modeling
"vqa": 0,
"nlvr2": 0,
"irtr": 0, # retrieval task ft
}
ret.update(d)
return ret
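# Example: _loss_names({"itm": 1, "itc": 1}) enables the ITM and ITC losses and leaves
# every other loss at 0 (disabled).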
@ex.config
def config():
exp_name = "vlmo"
seed = 1
datasets = ["coco", "vg", "sbu", "gcc"]
loss_names = _loss_names({"itm": 1, "itc": 1, "mlm": 1})
    batch_size = 1024  # desired effective batch size; the PL trainer accumulates gradients when the per-step batch is smaller
# Image setting
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
image_size = 224
draw_false_image = 0
image_only = False
text_only = False
# Text Setting
vqav2_label_size = 3129
max_text_len = 40
max_text_len_of_initckpt = 196
tokenizer = "bert-base-uncased"
vocab_size = 30522
whole_word_masking = False
mlm_prob = 0.15
draw_false_text = 0
# Transformer Setting
model_arch = "vlmo_base_patch16"
drop_path_rate = 0.1
# Optimizer Setting
optim_type = "adamw"
learning_rate = 1e-4
weight_decay = 0.01
decay_power = 1
max_epoch = 100
max_steps = 200000
warmup_steps = 0.1
end_lr = 0
lr_mult = 1 # multiply lr for downstream heads
# Downstream Setting
get_recall_metric = False
get_recall_rerank_metric = False
k_test = 32
# PL Trainer Setting
resume_from = None
fast_dev_run = False
val_check_interval = 1.0
test_only = False
use_sharded_training = False
resume_during_training = False
    # the params below vary with the environment
data_root = ""
log_dir = "result"
    per_gpu_batchsize = 4  # set this manually, e.g. per_gpu_batchsize=<BS_FITS_YOUR_GPU>
num_gpus = 1
num_nodes = 1
load_path = ""
num_workers = 8
precision = 16
# ----------------------- language pretraining config -----------------------
@ex.named_config
def task_textmlm_base():
exp_name = "textmlm_base"
datasets = ["wikibk"]
loss_names = _loss_names({"textmlm": 1})
batch_size = 1024
max_text_len = 196
learning_rate = 2e-4
whole_word_masking = True
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
model_arch = "vlmo_base_patch16"
@ex.named_config
def task_textmlm_base_plus():
exp_name = "textmlm_base_plus"
datasets = ["wikibk"]
loss_names = _loss_names({"textmlm": 1})
batch_size = 1024
max_text_len = 196
learning_rate = 2e-4
whole_word_masking = True
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
model_arch = "vlmo_base_plus_patch16"
# ----------------------- vision-language pretraining config -----------------------
# Named configs for "task" which define datasets, loss_names and desired batch_size, warmup_steps, epochs, and exp_name
@ex.named_config
def task_mlm_itm_itc_base():
exp_name = "mlm_itm_itc_base"
datasets = ["coco", "vg", "sbu", "gcc"]
loss_names = _loss_names({"itm": 1, "mlm": 1, "itc": 1})
batch_size = 1024
whole_word_masking = True
learning_rate = 2e-4
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
model_arch = "vlmo_base_patch16"
@ex.named_config
def task_mlm_itm_itc_base_plus():
exp_name = "mlm_itm_itc_base_plus"
datasets = ["coco", "vg", "sbu", "gcc"]
loss_names = _loss_names({"itm": 1, "mlm": 1, "itc": 1})
batch_size = 1024
whole_word_masking = True
learning_rate = 1e-4
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
model_arch = "vlmo_base_plus_patch16"
@ex.named_config
def task_mlm_itm_itc_large():
exp_name = "mlm_itm_itc_large"
datasets = ["coco", "vg", "sbu", "gcc"]
loss_names = _loss_names({"itm": 1, "mlm": 1, "itc": 1})
batch_size = 1024
whole_word_masking = True
learning_rate = 5e-5
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
model_arch = "vit_large_patch16_224"
# ----------------------- NLVR2 fine-tuning configs -----------------------
@ex.named_config
def task_finetune_nlvr2_base():
exp_name = "finetune_nlvr2_base"
datasets = ["nlvr2"]
train_transform_keys = ["square_transform_randaug"]
loss_names = _loss_names({"nlvr2": 1})
batch_size = 128
max_epoch = 10
max_steps = None
warmup_steps = 0.1
learning_rate = 5e-5
val_transform_keys = ["square_transform"]
use_sharded_training=False
model_arch = "vlmo_base_patch16"
@ex.named_config
def task_finetune_nlvr2_base_plus():
exp_name = "finetune_nlvr2_base_plus"
datasets = ["nlvr2"]
train_transform_keys = ["square_transform_randaug"]
loss_names = _loss_names({"nlvr2": 1})
batch_size = 128
max_epoch = 10
max_steps = None
warmup_steps = 0.1
learning_rate = 3e-5
drop_path_rate = 0.2
val_transform_keys = ["square_transform"]
use_sharded_training=False
model_arch = "vlmo_base_plus_patch16"
@ex.named_config
def task_finetune_nlvr2_base_image384():
exp_name = "finetune_nlvr2_base_image384"
datasets = ["nlvr2"]
train_transform_keys = ["square_transform_randaug"]
loss_names = _loss_names({"nlvr2": 1})
batch_size = 128
max_epoch = 10
max_steps = None
warmup_steps = 0.1
learning_rate = 5e-5
val_transform_keys = ["square_transform"]
image_size = 384
use_sharded_training=False
model_arch = "vlmo_base_patch16"
@ex.named_config
def task_finetune_nlvr2_base_plus_image384():
exp_name = "finetune_nlvr2_base_plus_image384"
datasets = ["nlvr2"]
train_transform_keys = ["square_transform_randaug"]
loss_names = _loss_names({"nlvr2": 1})
batch_size = 128
max_epoch = 10
max_steps = None
warmup_steps = 0.1
learning_rate = 3e-5
drop_path_rate = 0.2
val_transform_keys = ["square_transform"]
image_size = 384
use_sharded_training=False
model_arch = "vlmo_base_plus_patch16"
@ex.named_config
def task_finetune_nlvr2_large():
exp_name = "finetune_nlvr2_large"
datasets = ["nlvr2"]
train_transform_keys = ["square_transform_randaug"]
loss_names = _loss_names({"nlvr2": 1})
batch_size = 128
max_epoch = 10
max_steps = None
warmup_steps = 0.1
learning_rate = 3e-5
drop_path_rate = 0.15
val_transform_keys = ["square_transform"]
use_sharded_training=False
model_arch = "vlmo_large_patch16"
@ex.named_config
def task_finetune_nlvr2_large_image384():
exp_name = "finetune_nlvr2_large_image384"
datasets = ["nlvr2"]
train_transform_keys = ["square_transform_randaug"]
loss_names = _loss_names({"nlvr2": 1})
batch_size = 128
max_epoch = 10
max_steps = None
warmup_steps = 0.1
learning_rate = 3e-5
drop_path_rate = 0.15
val_transform_keys = ["square_transform"]
image_size = 384
use_sharded_training=False
model_arch = "vlmo_large_patch16"
# ----------------------- VQAv2 Fine-tuning configs -----------------------
@ex.named_config
def task_finetune_vqa_base_image480():
exp_name = "finetune_vqa_base_image480"
datasets = ["vqa"]
train_transform_keys = ["square_transform_randaug"]
loss_names = _loss_names({"vqa": 1})
batch_size = 128
max_epoch = 10
max_steps = None
warmup_steps = 0.1
learning_rate = 3e-5
drop_path_rate = 0.15
val_transform_keys = ["square_transform"]
lr_mult = 20
image_size = 480
use_sharded_training=False
model_arch = "vlmo_base_patch16"
@ex.named_config
def task_finetune_vqa_base_plus_image480():
exp_name = "finetune_vqa_base_plus_image480"
datasets = ["vqa"]
train_transform_keys = ["square_transform_randaug"]
loss_names = _loss_names({"vqa": 1})
batch_size = 128
max_epoch = 10
max_steps = None
warmup_steps = 0.1
learning_rate = 3e-5
drop_path_rate = 0.15
val_transform_keys = ["square_transform"]
lr_mult = 20
image_size = 480
use_sharded_training=False
model_arch = "vlmo_base_plus_patch16"
@ex.named_config
def task_finetune_vqa_large_image480():
exp_name = "finetune_vqa_large_image480"
datasets = ["vqa"]
train_transform_keys = ["square_transform_randaug"]
loss_names = _loss_names({"vqa": 1})
batch_size = 128
max_epoch = 10
max_steps = None
warmup_steps = 0.1
learning_rate = 1.5e-5
drop_path_rate = 0.15
val_transform_keys = ["square_transform"]
lr_mult = 20
image_size = 480
use_sharded_training=False
model_arch = "vlmo_large_patch16"
# ----------------------- F30K IR/TR Fine-tuning configs -----------------------
@ex.named_config
def task_finetune_irtr_f30k_base():
exp_name = "finetune_irtr_f30k_base"
datasets = ["f30k"]
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
loss_names = _loss_names({"irtr": 1.0})
batch_size = 3072
max_epoch = 50
max_steps = 1500
warmup_steps = 150
get_recall_metric = True
learning_rate = 3e-5
drop_path_rate = 0.15
use_sharded_training=False
model_arch = "vlmo_base_patch16"
@ex.named_config
def task_finetune_irtr_f30k_base_image384():
exp_name = "finetune_irtr_f30k_base_image384"
datasets = ["f30k"]
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
loss_names = _loss_names({"irtr": 1.0})
batch_size = 3072
max_epoch = 50
max_steps = 1500
warmup_steps = 150
get_recall_metric = True
learning_rate = 3e-5
drop_path_rate = 0.15
image_size = 384
use_sharded_training=False
model_arch = "vlmo_base_patch16"
@ex.named_config
def task_finetune_irtr_f30k_base_plus_image384():
exp_name = "finetune_irtr_f30k_base_plus_image384"
datasets = ["f30k"]
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
loss_names = _loss_names({"irtr": 1.0})
batch_size = 3072
max_epoch = 50
max_steps = 1500
warmup_steps = 150
get_recall_metric = True
learning_rate = 3e-5
drop_path_rate = 0.2
image_size = 384
use_sharded_training=False
model_arch = "vlmo_base_plus_patch16"
@ex.named_config
def task_finetune_irtr_f30k_large_image384():
exp_name = "finetune_irtr_f30k_large_image384"
datasets = ["f30k"]
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
loss_names = _loss_names({"irtr": 1.0})
batch_size = 3072
max_epoch = 50
max_steps = 1500
warmup_steps = 150
get_recall_metric = True
learning_rate = 2e-5
drop_path_rate = 0.2
image_size = 384
use_sharded_training=False
model_arch = "vlmo_large_patch16"
# ----------------------- COCO IR/TR Fine-tuning configs -----------------------
@ex.named_config
def task_finetune_irtr_coco_base_image384():
exp_name = "finetune_irtr_coco_base_image384"
datasets = ["coco"]
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
loss_names = _loss_names({"irtr": 1.0})
batch_size = 3072
max_epoch = 50
max_steps = 3000
warmup_steps = 300
get_recall_metric = True
learning_rate = 3e-5
drop_path_rate = 0.2
image_size = 384
use_sharded_training=False
model_arch = "vlmo_base_patch16"
@ex.named_config
def task_finetune_irtr_coco_base_plus_image384():
exp_name = "finetune_irtr_coco_base_plus_image384"
datasets = ["coco"]
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
loss_names = _loss_names({"irtr": 1.0})
batch_size = 3072
max_epoch = 50
max_steps = 3000
warmup_steps = 300
get_recall_metric = True
learning_rate = 3e-5
drop_path_rate = 0.2
image_size = 384
use_sharded_training=False
model_arch = "vlmo_base_plus_patch16"
@ex.named_config
def task_finetune_irtr_coco_large_image384():
exp_name = "finetune_irtr_coco_large_image384"
datasets = ["coco"]
train_transform_keys = ["square_transform_randaug"]
val_transform_keys = ["square_transform"]
loss_names = _loss_names({"irtr": 1.0})
batch_size = 3072
max_epoch = 50
max_steps = 3000
warmup_steps = 300
get_recall_metric = True
learning_rate = 2e-5
drop_path_rate = 0.2
image_size = 384
use_sharded_training=False
model_arch = "vlmo_large_patch16"
# ----------------------- Other configs -----------------------
# Named configs for "etc" which are orthogonal to "env" and "task", need to be added at the end
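# For example, the README's pre-training command appends `step200k` after the task config
# so that its max_epoch/warmup_steps/max_steps settings take effect.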
@ex.named_config
def step1_5k():
max_epoch = 100
warmup_steps = 150
max_steps = 1500
@ex.named_config
def step3k():
max_epoch = 100
warmup_steps = 300
max_steps = 3000
@ex.named_config
def step200k():
max_epoch = 200
warmup_steps = 2500
max_steps = 200000
@ex.named_config
def step500k():
max_epoch = 500
warmup_steps = 2500
max_steps = 500000
from .vg_caption_datamodule import VisualGenomeCaptionDataModule
from .f30k_caption_karpathy_datamodule import F30KCaptionKarpathyDataModule
from .coco_caption_karpathy_datamodule import CocoCaptionKarpathyDataModule
from .conceptual_caption_datamodule import ConceptualCaptionDataModule
from .sbu_datamodule import SBUCaptionDataModule
from .wikibk_datamodule import WikibkDataModule
from .vqav2_datamodule import VQAv2DataModule
from .nlvr2_datamodule import NLVR2DataModule
_datamodules = {
"vg": VisualGenomeCaptionDataModule,
"f30k": F30KCaptionKarpathyDataModule,
"coco": CocoCaptionKarpathyDataModule,
"gcc": ConceptualCaptionDataModule,
"sbu": SBUCaptionDataModule,
"wikibk": WikibkDataModule,
"vqa": VQAv2DataModule,
"nlvr2": NLVR2DataModule,
}
from vlmo.datasets import CocoCaptionKarpathyDataset
from .datamodule_base import BaseDataModule
class CocoCaptionKarpathyDataModule(BaseDataModule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def dataset_cls(self):
return CocoCaptionKarpathyDataset
@property
def dataset_cls_no_false(self):
return CocoCaptionKarpathyDataset
@property
def dataset_name(self):
return "coco"
from vlmo.datasets import ConceptualCaptionDataset
from .datamodule_base import BaseDataModule
class ConceptualCaptionDataModule(BaseDataModule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def dataset_cls(self):
return ConceptualCaptionDataset
@property
def dataset_name(self):
return "gcc"
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from transformers import (
DataCollatorForLanguageModeling,
DataCollatorForWholeWordMask,
BertTokenizer,
)
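# Let rank 0 download the tokenizer first; the other ranks wait at the barrier and then
# load it from the local Hugging Face cache instead of racing to download it.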
def get_pretrained_tokenizer(from_pretrained):
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
BertTokenizer.from_pretrained(
from_pretrained, do_lower_case="uncased" in from_pretrained
)
torch.distributed.barrier()
return BertTokenizer.from_pretrained(
from_pretrained, do_lower_case="uncased" in from_pretrained
)
class BaseDataModule(LightningDataModule):
def __init__(self, _config):
super().__init__()
self.data_dir = _config["data_root"]
self.num_workers = _config["num_workers"]
self.batch_size = _config["per_gpu_batchsize"]
self.eval_batch_size = self.batch_size
self.image_size = _config["image_size"]
self.max_text_len = _config["max_text_len"]
self.draw_false_image = _config["draw_false_image"]
self.draw_false_text = _config["draw_false_text"]
self.image_only = _config["image_only"]
self.text_only = _config["text_only"]
self.train_transform_keys = (
["default_train"]
if len(_config["train_transform_keys"]) == 0
else _config["train_transform_keys"]
)
self.val_transform_keys = (
["default_val"]
if len(_config["val_transform_keys"]) == 0
else _config["val_transform_keys"]
)
tokenizer = _config["tokenizer"]
self.tokenizer = get_pretrained_tokenizer(tokenizer)
self.vocab_size = self.tokenizer.vocab_size
collator = (
DataCollatorForWholeWordMask
if _config["whole_word_masking"]
else DataCollatorForLanguageModeling
)
self.mlm_collator = collator(
tokenizer=self.tokenizer, mlm=True, mlm_probability=_config["mlm_prob"]
)
self.setup_flag = False
@property
def dataset_cls(self):
raise NotImplementedError("return tuple of dataset class")
@property
def dataset_name(self):
raise NotImplementedError("return name of dataset")
def set_train_dataset(self):
self.train_dataset = self.dataset_cls(
self.data_dir,
self.train_transform_keys,
split="train",
image_size=self.image_size,
max_text_len=self.max_text_len,
draw_false_image=self.draw_false_image,
draw_false_text=self.draw_false_text,
image_only=self.image_only,
)
def set_val_dataset(self):
self.val_dataset = self.dataset_cls(
self.data_dir,
self.val_transform_keys,
split="val",
image_size=self.image_size,
max_text_len=self.max_text_len,
draw_false_image=self.draw_false_image,
draw_false_text=self.draw_false_text,
image_only=self.image_only,
)
if hasattr(self, "dataset_cls_no_false"):
self.val_dataset_no_false = self.dataset_cls_no_false(
self.data_dir,
self.val_transform_keys,
split="val",
image_size=self.image_size,
max_text_len=self.max_text_len,
draw_false_image=0,
draw_false_text=0,
image_only=self.image_only,
)
def make_no_false_val_dset(self, image_only=False):
return self.dataset_cls_no_false(
self.data_dir,
self.val_transform_keys,
split="val",
image_size=self.image_size,
max_text_len=self.max_text_len,
draw_false_image=0,
draw_false_text=0,
image_only=image_only,
)
def make_no_false_test_dset(self, image_only=False):
return self.dataset_cls_no_false(
self.data_dir,
self.val_transform_keys,
split="test",
image_size=self.image_size,
max_text_len=self.max_text_len,
draw_false_image=0,
draw_false_text=0,
image_only=image_only,
)
def set_test_dataset(self):
self.test_dataset = self.dataset_cls(
self.data_dir,
self.val_transform_keys,
split="test",
image_size=self.image_size,
max_text_len=self.max_text_len,
draw_false_image=self.draw_false_image,
draw_false_text=self.draw_false_text,
image_only=self.image_only,
)
def setup(self, stage):
if not self.setup_flag:
self.set_train_dataset()
self.set_val_dataset()
self.set_test_dataset()
self.train_dataset.tokenizer = self.tokenizer
self.val_dataset.tokenizer = self.tokenizer
self.test_dataset.tokenizer = self.tokenizer
self.setup_flag = True
def train_dataloader(self):
loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self.train_dataset.collate,
)
return loader
def val_dataloader(self):
loader = DataLoader(
self.val_dataset,
batch_size=self.eval_batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self.val_dataset.collate,
)
return loader
def test_dataloader(self):
loader = DataLoader(
self.test_dataset,
batch_size=self.eval_batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self.test_dataset.collate,
)
return loader
from vlmo.datasets import F30KCaptionKarpathyDataset
from .datamodule_base import BaseDataModule
class F30KCaptionKarpathyDataModule(BaseDataModule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def dataset_cls(self):
return F30KCaptionKarpathyDataset
@property
def dataset_cls_no_false(self):
return F30KCaptionKarpathyDataset
@property
def dataset_name(self):
return "f30k"
import functools
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data.distributed import DistributedSampler
from . import _datamodules
class MTDataModule(LightningDataModule):
def __init__(self, _config, dist=False):
datamodule_keys = _config["datasets"]
assert len(datamodule_keys) > 0
super().__init__()
self.dm_keys = datamodule_keys
self.dm_dicts = {key: _datamodules[key](_config) for key in datamodule_keys}
self.dms = [v for k, v in self.dm_dicts.items()]
self.batch_size = self.dms[0].batch_size
self.vocab_size = self.dms[0].vocab_size
self.num_workers = self.dms[0].num_workers
self.dist = dist
def prepare_data(self):
for dm in self.dms:
dm.prepare_data()
def setup(self, stage):
for dm in self.dms:
dm.setup(stage)
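        # Concatenate the per-dataset train/val/test splits into single datasets;
        # the tokenizer and collate function are taken from the first datamodule.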
self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms])
self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms])
self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms])
self.tokenizer = self.dms[0].tokenizer
self.collate = functools.partial(
self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator,
)
if self.dist:
self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
self.val_sampler = DistributedSampler(self.val_dataset, shuffle=True)
self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False)
else:
self.train_sampler = None
self.val_sampler = None
self.test_sampler = None
def train_dataloader(self):
loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size,
sampler=self.train_sampler,
num_workers=self.num_workers,
collate_fn=self.collate,
)
return loader
def val_dataloader(self, batch_size=None):
loader = DataLoader(
self.val_dataset,
batch_size=batch_size if batch_size is not None else self.batch_size,
sampler=self.val_sampler,
num_workers=self.num_workers,
collate_fn=self.collate,
)
return loader
def test_dataloader(self):
loader = DataLoader(
self.test_dataset,
batch_size=self.batch_size,
sampler=self.test_sampler,
num_workers=self.num_workers,
collate_fn=self.collate,
)
return loader
from vlmo.datasets import NLVR2Dataset
from .datamodule_base import BaseDataModule
class NLVR2DataModule(BaseDataModule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def dataset_cls(self):
return NLVR2Dataset
@property
def dataset_name(self):
return "nlvr2"
from vlmo.datasets import SBUCaptionDataset
from .datamodule_base import BaseDataModule
class SBUCaptionDataModule(BaseDataModule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def dataset_cls(self):
return SBUCaptionDataset
@property
def dataset_name(self):
return "sbu"