# [(BEiT-3) Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks](https://arxiv.org/abs/2208.10442)
Official PyTorch implementation and pretrained models of BEiT-3.
The code and pretrained models of **BEiT** can be found [here](https://github.com/microsoft/unilm/tree/master/beit).
The code and pretrained models of **BEiT v2** can be found [here](https://github.com/microsoft/unilm/tree/master/beit2).
- March 2023: release [the code and pretrained models of **BEiT-3**](https://github.com/microsoft/unilm/tree/master/beit3)
- March 2023: [**BEiT-3**](https://arxiv.org/abs/2208.10442) was accepted by **CVPR 2023**.
- Sept 2022: release [the code and pretrained models of **BEiT v2**](https://github.com/microsoft/unilm/tree/master/beit2)
- Aug 2022: release preprint [Image as a Foreign Language: BEiT Pretraining for All Vision and Vision-Language Tasks](https://arxiv.org/abs/2208.10442)
- Aug 2022: release preprint [BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers](https://arxiv.org/abs/2208.06366)
- June 2022: release preprint [VL-BEiT: Generative Vision-Language Pretraining](https://arxiv.org/abs/2206.01127)
- March 2022: add [linear probe examples](https://github.com/microsoft/unilm/blob/master/beit/get_started_for_image_classification.md#example-linear-probe-on-imagenet)
- January 2022: [**BEiT**](https://openreview.net/forum?id=p-BhZSz59o4) was accepted by **ICLR 2022 as Oral presentation** (54 out of 3391).
- August 2021: [**BEiT**](https://huggingface.co/transformers/master/model_doc/beit.html) is on [HuggingFace](https://github.com/huggingface/transformers)
- July 2021: BEiT-large achieves **[state-of-the-art results on ADE20K](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k) (a big jump to 57.0 mIoU) for semantic segmentation**.
- July 2021: BEiT-large achieves **state-of-the-art ImageNet top-1 accuracy (88.6%) under the setting without extra data other than ImageNet-22k**.
- July 2021: release [the code and pretrained models of **BEiT**](https://github.com/microsoft/unilm/tree/master/beit)
- June 2021: release preprint [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254)
## Pretrained models
We provide BEiT-3 weights pretrained on monomodal and multimodal data. Our large-size model outperforms previous large-size models across various vision-language and vision downstream tasks. All models were pretrained at a resolution of 224x224.
### Tips
- For vision-language tasks that require deep fusion, we recommend using `BEiT3-base` and `BEiT3-large`.
- For image-text retrieval and vision tasks, `BEiT3-base-itc` and `BEiT3-large-itc` usually achieve better performance.
### Download Checkpoints
1. Models pretrained on ImageNet-21k images, 160 GB text documents, and web-scale image-text pairs (collected from [LAION-400M](https://laion.ai/blog/laion-400-open-dataset/), [English LAION-2B](https://laion.ai/blog/laion-5b/), [COYO-700M](https://github.com/kakaobrain/coyo-dataset), and CC15M).
- [`BEiT3-base`](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16; #parameters: 276M
- [`BEiT3-large`](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=24; hidden=1024; FFN factor=4x; #head=16; patch=16x16; #parameters: 746M
2. Models obtained by performing image-text contrastive intermediate tuning on `BEiT3-base` and `BEiT3-large`.
- [`BEiT3-base-itc`](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_itc_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16; #parameters: 222M
- [`BEiT3-large-itc`](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_itc_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=24; hidden=1024; FFN factor=4x; #head=16; patch=16x16; #parameters: 674M
3. Models obtained by continuing masked data modeling training of `BEiT3-base` and `BEiT3-large` with in-domain image-text pairs (COCO and VG) added. These in-domain models achieve better performance on the VQAv2 and NLVR2 tasks.
- [`BEiT3-base-indomain`](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16; #parameters: 276M
- [`BEiT3-large-indomain`](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=24; hidden=1024; FFN factor=4x; #head=16; patch=16x16; #parameters: 746M
### Text Tokenizer
[beit3.spm](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) is the SentencePiece model used to tokenize text.
```
from transformers import XLMRobertaTokenizer
tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
```
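The tokenizer is used in the same way by the dataset-index builders below, e.g. (a minimal sketch; the sample sentence is arbitrary):
```
tokens = tokenizer.tokenize("a dog playing with a ball")
token_ids = tokenizer.convert_tokens_to_ids(tokens)
```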
### Architecture
We use [Magneto](https://arxiv.org/abs/2210.06423) with a decoupled Multiway Transformer as the backbone architecture. Magneto provides better training stability and improved performance across modalities (such as vision and language). The implementation is based on the [torchscale](https://github.com/microsoft/torchscale/blob/main/torchscale/model/BEiT3.py) package.
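A minimal sketch of instantiating a base-size backbone through torchscale is shown below. The configuration fields and values here are assumptions based on the base model described above; the configuration actually used for the released checkpoints is defined in `modeling_utils.py` of this repository.
```
from torchscale.architecture.config import EncoderConfig
from torchscale.model.BEiT3 import BEiT3

# Assumed base-size configuration: 12 layers, hidden size 768, 12 heads, 16x16 patches.
# The vocabulary size and the Multiway flag are assumptions for illustration only.
config = EncoderConfig(
    img_size=224, patch_size=16, vocab_size=64010, multiway=True,
    encoder_layers=12, encoder_embed_dim=768,
    encoder_attention_heads=12, encoder_ffn_embed_dim=4 * 768,
)
model = BEiT3(config)
```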
## Setup
Set up the environment, for example inside a PyTorch Docker container:
```
alias=`whoami | cut -d'.' -f2`; docker run -it --rm --runtime=nvidia --ipc=host --privileged -v /home/${alias}:/home/${alias} pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel bash
```
Clone the repo and install required packages:
```
git clone https://github.com/microsoft/unilm.git
cd unilm/beit3
pip install -r requirements.txt
```
## Fine-tuning on ImageNet-1k (Image Classification)
The detailed instructions can be found at [`get_started_for_image_classification.md`](get_started/get_started_for_image_classification.md). We only use vision-related parameters for image classification fine-tuning.
| initialized checkpoint | resolution | acc@1 | acc@5 | #params | weight |
|:----------------------------------------|:----------:|:-----:|:-----:|:-------:|-------------------|
| [beit3_base_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 224x224 | 85.4 | 97.6 | 87M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/in1k/beit3_base_patch16_224_in1k.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_base_indomain_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 224x224 | 85.4 | 97.6 | 87M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/in1k/beit3_base_indomain_patch16_224_in1k.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 224x224 | 87.6 | 98.3 | 305M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/in1k/beit3_large_patch16_224_in1k.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_indomain_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 224x224 | 87.5 | 98.3 | 305M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/in1k/beit3_large_indomain_patch16_224_in1k.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
## Fine-tuning on VQAv2 (Visual Question Answering)
The detailed instructions can be found at [`get_started_for_vqav2.md`](get_started/get_started_for_vqav2.md).
| initialized checkpoint | resolution | augmented data | test-dev | test-std | #params | weight |
|:----------------------------------------|:----------:|:-----:|:-----:|:-----:|:-------:|-------------------|
| [beit3_base_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | - | 77.65 | - | 228M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/vqa/beit3_base_patch16_480_vqa.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_base_indomain_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | - | 78.46 | - | 228M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/vqa/beit3_base_indomain_patch16_480_vqa.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | - | 81.85 | - | 683M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/vqa/beit3_large_patch16_480_vqa.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_indomain_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | - | 82.53 | - | 683M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/vqa/beit3_large_indomain_patch16_480_vqa.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_indomain_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 768x768 | VGQA | 82.97 | 83.03 | 684M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/vqa/beit3_large_indomain_patch16_768_vgqaaug_vqa.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
## Fine-tuning on NLVR2 (Visual Reasoning)
The detailed instructions can be found at [`get_started_for_nlvr2.md`](get_started/get_started_for_nlvr2.md).
| initialized checkpoint | resolution | dev | test-P | #params | weight |
|:----------------------------------------|:----------:|:-----:|:-----:|:-------:|-------------------|
| [beit3_base_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 224x224 | 83.6 | 84.4 | 226M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/nlvr2/beit3_base_patch16_224_nlvr2.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_base_indomain_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 224x224 | 84.6 | 85.3 | 226M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/nlvr2/beit3_base_indomain_patch16_224_nlvr2.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 224x224 | 88.5 | 89.4 | 681M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/nlvr2/beit3_large_patch16_224_nlvr2.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_indomain_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 224x224 | 89.2 | 90.0 | 681M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/nlvr2/beit3_large_indomain_patch16_224_nlvr2.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
## Fine-tuning on COCO Captioning and NoCaps (Image Captioning)
The detailed instructions can be found at [`get_started_for_captioning.md`](get_started/get_started_for_captioning.md).
### COCO Captioning
| initialized checkpoint | resolution | test CIDEr | #params | weight |
|:----------------------------------------|:----------:|:-----:|:-------:|-------------------|
| [beit3_base_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | 133.6 | 271M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/coco_captioning/beit3_base_patch16_480_coco_captioning.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_base_indomain_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | 135.0 | 271M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/coco_captioning/beit3_base_indomain_patch16_480_coco_captioning.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | 143.2 | 739M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/coco_captioning/beit3_large_patch16_480_coco_captioning.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
### NoCaps
| initialized checkpoint | resolution | val CIDEr | #params | weight |
|:----------------------------------------|:----------:|:-----:|:-------:|-------------------|
| [beit3_base_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | 104.4 | 271M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/nocaps/beit3_base_patch16_480_nocaps.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_base_indomain_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_indomain_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | 105.6 | 271M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/nocaps/beit3_base_indomain_patch16_480_nocaps.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 480x480 | 120.2 | 739M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/nocaps/beit3_large_patch16_480_nocaps.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
## Fine-tuning on COCO and Flickr30k Retrieval (Image-Text Retrieval)
The detailed instructions can be found at [`get_started_for_retrieval.md`](get_started/get_started_for_retrieval.md).
### COCO Retrieval
| initialized checkpoint | resolution | IR@1 | TR@1 | #params | weight |
|:----------------------------------------|:----------:|:-----:|:-----:|:-------:|-------------------|
| [beit3_base_itc_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_itc_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 384x384 | 61.4 | 79.1 | 222M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/coco_retrieval/beit3_base_patch16_384_coco_retrieval.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_itc_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_itc_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 384x384 | 63.4 | 82.1 | 675M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/coco_retrieval/beit3_large_patch16_384_coco_retrieval.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
### Flickr30k Retrieval
| initialized checkpoint | resolution | IR@1 | TR@1 | #params | weight |
|:----------------------------------------|:----------:|:-----:|:-----:|:-------:|-------------------|
| [beit3_base_itc_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_base_itc_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 384x384 | 86.2 | 96.3 | 222M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/f30k_retrieval/beit3_base_patch16_384_f30k_retrieval.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
| [beit3_large_itc_patch16_224](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/pretraining/beit3_large_itc_patch16_224.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) | 384x384 | 88.1 | 97.2 | 675M | [link](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/f30k_retrieval/beit3_large_patch16_384_f30k_retrieval.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) |
## Citation
If you find this repository useful, please consider citing our work:
```
@inproceedings{beit3,
title={Image as a foreign language: {BEiT} pretraining for vision and vision-language tasks},
author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023}
}
@article{beitv2,
title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
year={2022},
eprint={2208.06366},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@inproceedings{beit,
title={{BEiT}: {BERT} Pre-Training of Image Transformers},
author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=p-BhZSz59o4}
}
```
## Acknowledgement
This repository is built using the [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [BEiT v2](https://github.com/microsoft/unilm/tree/master/beit2), [CLIP](https://github.com/openai/CLIP), [open_clip](https://github.com/mlfoundations/open_clip), [Oscar](https://github.com/microsoft/Oscar), [DeiT](https://github.com/facebookresearch/deit), and [DINO](https://github.com/facebookresearch/dino) repositories, as well as the [timm](https://github.com/rwightman/pytorch-image-models) library.
## License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
### Contact Information
For help or issues using BEiT-3 models, please submit a GitHub issue.
# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
import json
import random
import torch
import glob
from collections import defaultdict, Counter
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data.transforms import RandomResizedCropAndInterpolation
from timm.data import create_transform
import utils
from glossary import normalize_word
from randaug import RandomAugment
class BaseDataset(torch.utils.data.Dataset):
def __init__(
self, data_path, split, transform,
tokenizer, num_max_bpe_tokens, task=None,
):
index_files = self.get_index_files(split, task=task)
self.tokenizer = tokenizer
self.num_max_bpe_tokens = num_max_bpe_tokens
self.data_path = data_path
items = []
self.index_files = index_files
offset = 0
for _index_file in index_files:
index_file = os.path.join(data_path, _index_file)
with open(index_file, mode="r", encoding="utf-8") as reader:
for line in reader:
data = json.loads(line)
items.append(data)
print("Load %d image-text pairs from %s. " % (len(items) - offset, index_file))
offset = len(items)
self.items = items
self.bos_token_id = tokenizer.bos_token_id
self.eos_token_id = tokenizer.eos_token_id
self.pad_token_id = tokenizer.pad_token_id
self.loader = default_loader
self.transform = transform
self.split = split
    @staticmethod
    def get_index_files(split, task=None):
        raise NotImplementedError()
def _get_image(self, image_path: str):
image_path = os.path.join(self.data_path, image_path)
image = self.loader(image_path)
return self.transform(image)
def _get_text_segment(self, text_segment, max_len=None):
if isinstance(text_segment, str):
tokens = self.tokenizer.tokenize(text_segment)
else:
tokens = text_segment[:]
if len(tokens) == 0:
raise RuntimeError("The text segment should contains at least one tokens!")
if max_len is None:
max_len = self.num_max_bpe_tokens
if len(tokens) > max_len - 2:
tokens = tokens[:max_len - 2]
tokens = [self.bos_token_id] + tokens[:] + [self.eos_token_id]
num_tokens = len(tokens)
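        # padding_mask: 0 for real tokens (including BOS/EOS), 1 for right-padding up to max_len.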
padding_mask = [0] * num_tokens + [1] * (max_len - num_tokens)
return tokens + [self.pad_token_id] * (max_len - num_tokens), padding_mask, num_tokens
def _get_image_text_example(self, index: int, data: dict):
item = self.items[index]
img_path = item["image_path"]
img = self._get_image(img_path)
data["image"] = img
text_segment = item["text_segment"]
language_tokens, padding_mask, _ = self._get_text_segment(text_segment)
data["language_tokens"] = language_tokens
data["padding_mask"] = padding_mask
def __getitem__(self, index: int):
data = dict()
self._get_image_text_example(index, data)
return data
def __len__(self) -> int:
return len(self.items)
def __repr__(self) -> str:
head = "Dataset " + self.__class__.__name__
body = '{' + "\n Number of items: %s," % self.__len__()
body += "\n data root = %s," % self.data_path
body += "\n split = %s," % self.split
body += "\n dataset index files = %s" % str(self.index_files)
body += "\n num max bpe tokens = %s" % self.num_max_bpe_tokens
body += "\n transforms = ["
for t in self.transform.transforms:
body += "\n %s" % str(t)
body += "\n ]"
body += "\n}"
return head + body
def _write_data_into_jsonl(items, jsonl_file):
with open(jsonl_file, mode="w", encoding="utf-8") as writer:
for data in items:
writer.write(json.dumps(data, indent=None))
writer.write('\n')
print("Write %s with %d items !" % (jsonl_file, len(items)))
def _make_retrieval_coco_karpathy_dataset_index(
data_path,
tokenizer,
split=("train", "restval"),
split_name="train",
):
coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json")
items = []
image_counter = set()
print("read %s" % coco_karpathy_split_json_file)
with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader:
data = json.loads(reader.read())
for item in data["images"]:
if item["split"] in split:
image_path = os.path.join(item["filepath"], item["filename"])
for sent in item["sentences"]:
tokens = tokenizer.tokenize(sent["raw"])
token_ids = tokenizer.convert_tokens_to_ids(tokens)
items.append({
"image_path": image_path,
"text_segment": token_ids,
"image_id": len(image_counter),
})
if image_path not in image_counter:
image_counter.add(image_path)
print("Find %d images and %d image-text pairs for karpathy dataset %s split !" % \
(len(image_counter), len(items), split_name))
index_file = os.path.join(data_path, "coco_retrieval.%s.jsonl" % split_name)
_write_data_into_jsonl(items, index_file)
def _make_captioning_coco_karpathy_dataset_index(
data_path,
tokenizer,
split=("train", "restval"),
split_name="train",
):
coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json")
items = []
image_counter = set()
print("read %s" % coco_karpathy_split_json_file)
with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader:
data = json.loads(reader.read())
for item in data["images"]:
if item["split"] in split:
image_path = os.path.join(item["filepath"], item["filename"])
if item["split"] in ["train", "restval"]:
for sent in item["sentences"]:
tokens = tokenizer.tokenize(sent["raw"])
token_ids = tokenizer.convert_tokens_to_ids(tokens)
items.append({
"image_path": image_path,
"text_segment": token_ids,
"image_id": item["cocoid"],
})
else:
items.append({
"image_path": image_path,
"text_segment": None,
"image_id": item["cocoid"],
})
if image_path not in image_counter:
image_counter.add(image_path)
print("Find %d images and %d image-text pairs for karpathy dataset %s split !" % \
(len(image_counter), len(items), split_name))
index_file = os.path.join(data_path, "coco_captioning.%s.jsonl" % split_name)
_write_data_into_jsonl(items, index_file)
def _make_nocaps_dataset_index(
data_path,
split="val",
):
if split == "val":
json_file = "nocaps_val_4500_captions.json"
elif split == "test":
json_file = "nocaps_test_image_info.json"
nocaps_split_json_file = os.path.join(data_path, json_file)
items = []
image_counter = set()
print("read %s" % nocaps_split_json_file)
with open(nocaps_split_json_file, mode="r", encoding="utf-8") as reader:
data = json.loads(reader.read())
for item in data["images"]:
image_path = os.path.join(split, item["file_name"])
items.append({
"image_path": image_path,
"text_segment": None,
"image_id": item["id"],
})
if image_path not in image_counter:
image_counter.add(image_path)
print("Find %d images and %d image-text pairs for nocaps dataset %s split !" % \
(len(image_counter), len(items), split))
index_file = os.path.join(data_path, "nocaps.%s.jsonl" % split)
_write_data_into_jsonl(items, index_file)
class NLVR2Dataset(BaseDataset):
@staticmethod
def get_index_files(split, task=None):
if split == "train":
return ("nlvr2.train.index.jsonl", )
elif split == "val":
return ("nlvr2.dev.index.jsonl", )
elif split == "test":
return ("nlvr2.test-P.index.jsonl", )
else:
raise RuntimeError("split %s is not found!" % split)
def __getitem__(self, index: int):
data = super().__getitem__(index)
item = self.items[index]
img_path = item["image2_path"]
img = self._get_image(img_path)
data["image2"] = img
data["label"] = self.items[index]["label"]
return data
@staticmethod
    def __preprocess_json(prefix, json_file, tokenizer, index_file):
items = []
with open(json_file, mode="r", encoding="utf-8") as reader:
for line in reader:
data = json.loads(line)
                path = os.path.join(prefix, str(data["directory"])) if "directory" in data else prefix
path = os.path.join(path, "-".join(data["identifier"].split("-")[:-1]))
tokens = tokenizer.tokenize(data["sentence"])
token_ids = tokenizer.convert_tokens_to_ids(tokens)
items.append({
"image_path": path + "-img0.png",
"image2_path": path + "-img1.png",
"text_segment": token_ids,
"label": 1 if data["label"] == "True" else 0,
"identifier": data["identifier"],
})
_write_data_into_jsonl(items, index_file)
@classmethod
def make_dataset_index(cls, data_path, tokenizer, nlvr_repo_path):
cls.__preprocess_json(
preifx="images/train", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/train.json"),
tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("train")[0]),
)
cls.__preprocess_json(
preifx="dev", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/dev.json"),
tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("val")[0]),
)
cls.__preprocess_json(
preifx="test1", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/test1.json"),
tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("test")[0]),
)
class ImageNetDataset(BaseDataset):
@staticmethod
def get_index_files(split, task=None):
if split == "train":
return ("imagenet.train.index.jsonl", )
elif split == "val":
return ("imagenet.val.index.jsonl", )
elif split == "test":
return ("imagenet.val.index.jsonl", )
else:
raise RuntimeError("split %s is not found!" % split)
def __getitem__(self, index: int):
data = dict()
item = self.items[index]
img_path = item["image_path"]
img = self._get_image(img_path)
data["image"] = img
data["label"] = item["label"]
return data
@staticmethod
def _find_classes(dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
@staticmethod
def _make_imagenet_index(data_path, index_path, data_path_prefix, class_to_idx, split):
items = []
index_file = os.path.join(index_path, f"imagenet.{split}.index.jsonl")
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(data_path, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
path = path.replace(data_path_prefix, "")
items.append({
"image_path": path,
"label": class_index,
})
_write_data_into_jsonl(items, index_file)
@classmethod
def make_dataset_index(cls, train_data_path, val_data_path, index_path):
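        # Strip the common character prefix of the train/val paths (the shared data root)
        # so that the image paths stored in the index files are relative to it.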
data_path_prefix = train_data_path[:[x[0]==x[1] for x in zip(train_data_path, val_data_path)].index(0)]
classes, class_to_idx = cls._find_classes(train_data_path)
cls._make_imagenet_index(
data_path=train_data_path, index_path=index_path, data_path_prefix=data_path_prefix,
class_to_idx=class_to_idx, split="train",
)
cls._make_imagenet_index(
data_path=val_data_path, index_path=index_path, data_path_prefix=data_path_prefix,
class_to_idx=class_to_idx, split="val",
)
class VQAv2Dataset(BaseDataset):
def __init__(self, data_path, **kwargs):
super().__init__(data_path=data_path, **kwargs)
ans2label_file = os.path.join(data_path, "answer2label.txt")
ans2label = {}
label2ans = []
with open(ans2label_file, mode="r", encoding="utf-8") as reader:
for i, line in enumerate(reader):
data = json.loads(line)
ans = data["answer"]
label = data["label"]
label = int(label)
assert label == i
ans2label[ans] = i
label2ans.append(ans)
self.ans2label = ans2label
self.label2ans = label2ans
@staticmethod
def get_index_files(split, task=None):
if split == "train":
return ("vqa.train.jsonl", "vqa.trainable_val.jsonl")
elif split == "val":
return ("vqa.rest_val.jsonl", )
elif split == "test":
return ("vqa.test.jsonl", )
elif split == "test-dev":
return ("vqa.test-dev.jsonl", )
else:
raise RuntimeError("split %s is not found!" % split)
def __getitem__(self, index: int):
data = super().__getitem__(index)
if "labels" in self.items[index] and len(self.items[index]["labels"]) > 0:
labels = [0.] * len(self.label2ans)
for l, s in zip(self.items[index]["labels"], self.items[index]["scores"]):
labels[l] = s
data["labels"] = torch.FloatTensor(labels)
else:
data["qid"] = self.items[index]["qid"]
return data
@staticmethod
    def get_score(occurrences):
        # VQA-style soft accuracy: the score grows with the number of annotators who gave
        # this answer and saturates at 1.0 (0.3/0.6/0.9 approximate the official min(count/3, 1)).
        if occurrences == 0:
            return 0.0
        elif occurrences == 1:
            return 0.3
        elif occurrences == 2:
            return 0.6
        elif occurrences == 3:
            return 0.9
        else:
            return 1.0
@classmethod
def make_dataset_index(cls, data_path, tokenizer, annotation_data_path):
with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_train2014_questions.json"), "r") as fp:
questions_train2014 = json.load(fp)["questions"]
with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_val2014_questions.json"), "r") as fp:
questions_val2014 = json.load(fp)["questions"]
with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test2015_questions.json"), "r") as fp:
questions_test2015 = json.load(fp)["questions"]
with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test-dev2015_questions.json"), "r") as fp:
questions_test_dev2015 = json.load(fp)["questions"]
with open(os.path.join(annotation_data_path, "v2_mscoco_train2014_annotations.json"), "r") as fp:
annotations_train2014 = json.load(fp)["annotations"]
with open(os.path.join(annotation_data_path, "v2_mscoco_val2014_annotations.json"), "r") as fp:
annotations_val2014 = json.load(fp)["annotations"]
annotations = dict()
for split, questions in zip(
["train", "val", "test", "test-dev"],
[questions_train2014, questions_val2014, questions_test2015, questions_test_dev2015],
):
_annot = defaultdict(dict)
for q in questions:
question_text = q["question"]
tokens = tokenizer.tokenize(question_text)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
assert q["question_id"] not in _annot[q["image_id"]]
_annot[q["image_id"]][q["question_id"]] = {
"question": question_text,
"token_ids": token_ids,
}
annotations[split] = _annot
all_major_answers = list()
for split, annots in zip(
["train", "val"], [annotations_train2014, annotations_val2014],
):
# _annot = annotations[split]
for q in annots:
all_major_answers.append(q["multiple_choice_answer"])
all_major_answers = [normalize_word(word) for word in all_major_answers]
counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9}
ans2label = {k: i for i, k in enumerate(counter.keys())}
label2ans = list(counter.keys())
for split, annots in zip(
["train", "val"], [annotations_train2014, annotations_val2014],
):
_annot = annotations[split]
for q in annots:
answers = q["answers"]
answer_count = {}
for answer in answers:
answer_ = answer["answer"]
answer_count[answer_] = answer_count.get(answer_, 0) + 1
labels = []
scores = []
for answer in answer_count:
if answer not in ans2label:
continue
labels.append(ans2label[answer])
score = cls.get_score(answer_count[answer])
scores.append(score)
assert "labels" not in _annot[q["image_id"]][q["question_id"]]
assert "question" in _annot[q["image_id"]][q["question_id"]]
_annot[q["image_id"]][q["question_id"]]["labels"] = labels
_annot[q["image_id"]][q["question_id"]]["scores"] = scores
for split in ["train", "val"]:
filtered_annot = dict()
for ik, iv in annotations[split].items():
new_q = dict()
for qk, qv in iv.items():
if len(qv["labels"]) != 0:
new_q[qk] = qv
if len(new_q) != 0:
filtered_annot[ik] = new_q
annotations[split] = filtered_annot
split2items = {}
for split in ["train", "val", "test", "test-dev"]:
annot = annotations[split]
split_name = {
"train": "train2014",
"val": "val2014",
"test": "test2015",
"test-dev": "test2015",
}[split]
paths = list(glob.glob(f"{data_path}/{split_name}/*.jpg"))
random.shuffle(paths)
annot_paths = [path for path in paths \
if int(path.split("/")[-1].split("_")[-1][:-4]) in annot]
if len(paths) == len(annot_paths):
print("all images have caption annotations")
else:
print("not all images have caption annotations")
print(len(paths), len(annot_paths), len(annot))
items = []
for path in annot_paths:
iid = int(path.split("/")[-1].split("_")[-1][:-4])
_annot = annotations[split][iid]
for qid in _annot:
q = _annot[qid]
if split in ["train", "val"]:
labels = q["labels"]
scores = q["scores"]
else:
labels, scores = [], []
items.append({
"image_path": os.path.join(split_name, path.split('/')[-1]),
"text_segment": q["token_ids"],
"labels": labels,
"scores": scores,
"qid": qid,
})
split2items[split] = items
_write_data_into_jsonl(items=items, jsonl_file=os.path.join(data_path, "vqa.%s.jsonl" % split))
# Following ViLT, we use 1000 images of the original val set as the final val set
val_image2items = defaultdict(list)
for item in split2items["val"]:
val_image2items[item["image_path"]].append(item)
print("Contains %d image and %d pairs for val set!" % (len(val_image2items), len(split2items["val"])))
val_images = list(val_image2items.keys())
random.shuffle(val_images)
trainable_val = []
rest_val = []
for i, image_id in enumerate(val_images):
if i < 1000:
rest_val += val_image2items[image_id]
else:
trainable_val += val_image2items[image_id]
_write_data_into_jsonl(items=trainable_val, jsonl_file=os.path.join(data_path, "vqa.trainable_val.jsonl"))
_write_data_into_jsonl(items=rest_val, jsonl_file=os.path.join(data_path, "vqa.rest_val.jsonl"))
with open(os.path.join(data_path, "answer2label.txt"), mode="w", encoding="utf-8") as writer:
for ans in ans2label:
to_json = {
"answer": ans,
"label": ans2label[ans]
}
writer.write("%s\n" % json.dumps(to_json))
class RetrievalDataset(BaseDataset):
@staticmethod
def get_index_files(split, task=None):
if split == "train":
return (f"{task}.train.jsonl", )
elif split == "val":
return (f"{task}.val.jsonl", )
elif split == "test":
return (f"{task}.test.jsonl", )
else:
raise RuntimeError("split %s is not found!" % split)
def __getitem__(self, index: int):
data = super().__getitem__(index)
data["image_id"] = self.items[index]["image_id"]
return data
@staticmethod
def make_flickr30k_dataset_index(data_path, tokenizer, karpathy_path):
with open(os.path.join(karpathy_path, "dataset_flickr30k.json"), "r") as reader:
captions = json.loads(reader.read())
captions = captions["images"]
split2items = defaultdict(list)
split2images = defaultdict(set)
for each_item in captions:
image_path = os.path.join("flickr30k-images", each_item["filename"])
split = each_item["split"]
for text_segment in each_item["sentences"]:
tokens = tokenizer.tokenize(text_segment["raw"])
token_ids = tokenizer.convert_tokens_to_ids(tokens)
split2items[split].append({
"image_path": image_path,
"text_segment": token_ids,
"image_id": len(split2images[split]),
})
assert each_item["filename"] not in split2images[split]
split2images[split].add(each_item["filename"])
for split in split2items:
print("%d images and %d image-text pairs!" % (len(split2images[split]), len(split2items[split])))
_write_data_into_jsonl(split2items[split], os.path.join(data_path, "flickr30k.%s.jsonl" % split))
@staticmethod
def make_coco_dataset_index(data_path, tokenizer):
_make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train")
_make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val")
_make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test")
class CaptioningDataset(BaseDataset):
def __init__(self, data_path, split, transform,
tokenizer, num_max_bpe_tokens, task, mask_prob):
super().__init__(
data_path=data_path, split=split,
transform=transform, tokenizer=tokenizer,
num_max_bpe_tokens=num_max_bpe_tokens, task=task,
)
self.mask_token_id = tokenizer.mask_token_id
self.language_vocab_size = tokenizer.vocab_size
self.mask_prob = mask_prob
@staticmethod
def get_index_files(split, task=None):
if split == "train":
return ("coco_captioning.train.jsonl", )
elif split == "val":
return (f"{task}.val.jsonl", )
elif split == "test":
return (f"{task}.test.jsonl", )
else:
raise RuntimeError("split %s is not found!" % split)
def _get_mask_token(self, token):
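        # BERT-style replacement: 80% -> [MASK], 10% -> keep the original token,
        # 10% -> a random token (ids 0-2, the special tokens, are skipped).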
p = random.random()
if p < 0.8:
return self.mask_token_id
elif p < 0.9:
return token
else:
return random.randint(3, self.language_vocab_size - 1)
def _masking_on_text_tokens(self, tokens, num_tokens, mask_prob):
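        # Mask roughly mask_prob of the caption tokens (at least one, never all of them,
        # and never position 0, which holds the BOS token) for captioning training.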
bool_masked_pos = [0] * len(tokens)
to_mask = min(int(num_tokens * mask_prob + 0.5), num_tokens - 1)
to_mask = max(to_mask, 1)
num_masked_tokens = 0
while num_masked_tokens < to_mask:
i = random.randint(1, num_tokens - 1)
if bool_masked_pos[i] == 0:
bool_masked_pos[i] = 1
tokens[i] = self._get_mask_token(tokens[i])
num_masked_tokens += 1
return tokens, bool_masked_pos
def __getitem__(self, index: int):
data = dict()
item = self.items[index]
img_path = item["image_path"]
img = self._get_image(img_path)
data["image"] = img
data["image_id"] = item["image_id"]
text_segment = item["text_segment"]
if text_segment is not None:
language_tokens, padding_mask, num_tokens = self._get_text_segment(text_segment)
masked_tokens = language_tokens[:]
masked_tokens, language_masked_pos = \
self._masking_on_text_tokens(masked_tokens, num_tokens, self.mask_prob)
data["language_tokens"] = language_tokens
data["masked_tokens"] = masked_tokens
data["language_masked_pos"] = language_masked_pos
data["padding_mask"] = padding_mask
return data
@staticmethod
def make_coco_captioning_dataset_index(data_path, tokenizer):
_make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train")
_make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val")
_make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test")
@staticmethod
def make_nocaps_captioning_dataset_index(data_path):
_make_nocaps_dataset_index(data_path, split="val")
_make_nocaps_dataset_index(data_path, split="test")
task2dataset = {
"nlvr2": NLVR2Dataset,
"vqav2": VQAv2Dataset,
"flickr30k": RetrievalDataset,
"coco_retrieval": RetrievalDataset,
"coco_captioning": CaptioningDataset,
"nocaps": CaptioningDataset,
"imagenet": ImageNetDataset,
}
def create_dataloader(dataset, is_train, batch_size, num_workers, pin_mem, dist_eval=False):
if is_train or dist_eval:
num_tasks = utils.get_world_size()
global_rank = utils.get_rank()
if not is_train and dist_eval and len(dataset) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler = torch.utils.data.DistributedSampler(
dataset, num_replicas=num_tasks, rank=global_rank, shuffle=is_train
)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
return torch.utils.data.DataLoader(
dataset, sampler=sampler,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_mem,
drop_last=is_train,
collate_fn=utils.merge_batch_tensors_by_dict_key,
)
def build_transform(is_train, args):
if args.task in ["imagenet"]:
return build_imagenet_transform(is_train, args)
if is_train:
t = [
RandomResizedCropAndInterpolation(args.input_size, scale=(0.5, 1.0), interpolation=args.train_interpolation),
transforms.RandomHorizontalFlip(),
]
if args.randaug:
t.append(
RandomAugment(
2, 7, isPIL=True,
augs=[
'Identity','AutoContrast','Equalize','Brightness','Sharpness',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
]))
t += [
transforms.ToTensor(),
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
]
t = transforms.Compose(t)
else:
t = transforms.Compose([
transforms.Resize((args.input_size, args.input_size), interpolation=3),
transforms.ToTensor(),
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
])
return t
def build_imagenet_transform(is_train, args):
resize_im = args.input_size > 32
if is_train:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=args.input_size,
is_training=True,
color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation=args.train_interpolation,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
)
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
transform.transforms[0] = transforms.RandomCrop(
args.input_size, padding=4)
return transform
t = []
if resize_im:
if args.crop_pct is None:
args.crop_pct = 1.0
size = int(args.input_size / args.crop_pct)
t.append(
transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(args.input_size))
t.append(transforms.ToTensor())
t.append(transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD))
return transforms.Compose(t)
def get_sentencepiece_model_for_beit3(args):
from transformers import XLMRobertaTokenizer
return XLMRobertaTokenizer(args.sentencepiece_model)
def create_dataset_by_split(args, split, is_train=True):
transform = build_transform(is_train=is_train, args=args)
dataset_class = task2dataset[args.task]
tokenizer = get_sentencepiece_model_for_beit3(args)
opt_kwargs = {}
if args.task in ["coco_captioning", "nocaps"]:
opt_kwargs["mask_prob"] = args.captioning_mask_prob
dataset = dataset_class(
data_path=args.data_path, split=split,
transform=transform, tokenizer=tokenizer,
num_max_bpe_tokens=args.num_max_bpe_tokens,
task=args.task, **opt_kwargs,
)
if is_train:
batch_size = args.batch_size
elif hasattr(args, "eval_batch_size") and args.eval_batch_size is not None:
batch_size = args.eval_batch_size
else:
batch_size = int(args.batch_size * 1.5)
return create_dataloader(
dataset, is_train=is_train, batch_size=batch_size,
num_workers=args.num_workers, pin_mem=args.pin_mem, dist_eval=args.dist_eval,
)
def create_downstream_dataset(args, is_eval=False):
if is_eval:
return create_dataset_by_split(args, split="test", is_train=False)
else:
return \
create_dataset_by_split(args, split="train", is_train=True), \
create_dataset_by_split(args, split="val", is_train=True)
# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import math
import sys
import json
from typing import Iterable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.utils import accuracy, ModelEma
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from datasets import get_sentencepiece_model_for_beit3
import utils
class TaskHandler(object):
def __init__(self) -> None:
self.metric_logger = None
self.split = None
def train_batch(self, model, **kwargs):
raise NotImplementedError()
def eval_batch(self, model, **kwargs):
raise NotImplementedError()
def before_eval(self, metric_logger, data_loader, **kwargs):
self.metric_logger = metric_logger
self.split = data_loader.dataset.split
def after_eval(self, **kwargs):
raise NotImplementedError()
class NLVR2Handler(TaskHandler):
def __init__(self) -> None:
super().__init__()
self.criterion = torch.nn.CrossEntropyLoss()
def train_batch(self, model, image, image2, language_tokens, padding_mask, label):
logits = model(
image_a=image, image_b=image2,
text_description=language_tokens,
padding_mask=padding_mask)
acc = (logits.max(-1)[-1] == label).float().mean()
return {
"loss": self.criterion(input=logits, target=label),
"acc": acc,
}
def eval_batch(self, model, image, image2, language_tokens, padding_mask, label):
logits = model(
image_a=image, image_b=image2,
text_description=language_tokens,
padding_mask=padding_mask)
batch_size = language_tokens.shape[0]
acc = (logits.max(-1)[-1] == label).float().sum(0) * 100.0 / batch_size
self.metric_logger.meters['acc'].update(acc.item(), n=batch_size)
def after_eval(self, **kwargs):
print('* Acc {acc.global_avg:.3f}'.format(acc=self.metric_logger.acc))
return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "acc"
class ImageNetHandler(TaskHandler):
def __init__(self, args) -> None:
super().__init__()
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
# smoothing is handled with mixup label transform
self.criterion = SoftTargetCrossEntropy()
elif args.label_smoothing > 0.:
self.criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
else:
self.criterion = torch.nn.CrossEntropyLoss()
def train_batch(self, model, image, label):
logits = model(image=image)
return {
"loss": self.criterion(logits, label),
}
def eval_batch(self, model, image, label):
logits = model(image=image)
batch_size = image.shape[0]
acc1, acc5 = accuracy(logits, label, topk=(1, 5))
self.metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
self.metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
def after_eval(self, **kwargs):
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
.format(top1=self.metric_logger.acc1, top5=self.metric_logger.acc5))
return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "acc1"
class RetrievalHandler(TaskHandler):
def __init__(self) -> None:
super().__init__()
self.image_feats = []
self.text_feats = []
self.image_ids = []
self.metric_logger = None
def train_batch(self, model, image, language_tokens, padding_mask, image_id):
loss, vision_cls, language_cls = model(
image=image, text_description=language_tokens, padding_mask=padding_mask)
return {
"loss": loss,
}
def before_eval(self, metric_logger, **kwargs):
self.image_feats.clear()
self.text_feats.clear()
self.image_ids.clear()
self.metric_logger = metric_logger
def eval_batch(self, model, image, language_tokens, padding_mask, image_id):
vision_cls, _ = model(image=image, only_infer=True)
_, language_cls = model(
text_description=language_tokens, padding_mask=padding_mask, only_infer=True)
self.image_feats.append(vision_cls.clone())
self.text_feats.append(language_cls.clone())
self.image_ids.append(image_id.clone())
def after_eval(self, **kwargs):
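        # The same image id appears once per paired caption; keep one feature vector per unique image.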
image_feats = {}
for feats, ids in zip(self.image_feats, self.image_ids):
for i, _idx in enumerate(ids):
idx = _idx.item()
if idx not in image_feats:
image_feats[idx] = feats[i]
tiids = torch.cat(self.image_ids, dim=0)
iids = []
sorted_tensors = []
for key in sorted(image_feats.keys()):
sorted_tensors.append(image_feats[key].view(1, -1))
iids.append(key)
image_cls_feats = torch.cat(sorted_tensors, dim=0)
text_cls_feats = torch.cat(self.text_feats, dim=0)
scores = image_cls_feats @ text_cls_feats.t()
iids = torch.LongTensor(iids).to(scores.device)
print("scores: {}".format(scores.size()))
print("iids: {}".format(iids.size()))
print("tiids: {}".format(tiids.size()))
topk10 = scores.topk(10, dim=1)
topk5 = scores.topk(5, dim=1)
topk1 = scores.topk(1, dim=1)
topk10_iids = tiids[topk10.indices]
topk5_iids = tiids[topk5.indices]
topk1_iids = tiids[topk1.indices]
tr_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
tr_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
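        # Image retrieval (IR@k): for each text (column), check whether its source image is among the top-k images.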
topk10 = scores.topk(10, dim=0)
topk5 = scores.topk(5, dim=0)
topk1 = scores.topk(1, dim=0)
topk10_iids = iids[topk10.indices]
topk5_iids = iids[topk5.indices]
topk1_iids = iids[topk1.indices]
ir_r10 = (tiids.unsqueeze(0) == topk10_iids).float().max(dim=0)[0].mean()
ir_r5 = (tiids.unsqueeze(0) == topk5_iids).float().max(dim=0)[0].mean()
ir_r1 = (tiids.unsqueeze(0) == topk1_iids).float().max(dim=0)[0].mean()
eval_result = {
"tr_r10": tr_r10.item() * 100.0,
"tr_r5": tr_r5.item() * 100.0,
"tr_r1": tr_r1.item() * 100.0,
"ir_r10": ir_r10.item() * 100.0,
"ir_r5": ir_r5.item() * 100.0,
"ir_r1": ir_r1.item() * 100.0,
"average_score": 100.0 * (tr_r1 + tr_r5 + tr_r10 + ir_r1 + ir_r5 + ir_r10).item() / 6.0,
}
print('* Eval result = %s' % json.dumps(eval_result))
return eval_result, "average_score"
class VQAHandler(TaskHandler):
def __init__(self) -> None:
super().__init__()
self.predictions = []
self.criterion = nn.BCEWithLogitsLoss(reduction='mean')
self.label2ans = None
def train_batch(self, model, image, language_tokens, padding_mask, labels):
logits = model(
image=image, question=language_tokens,
padding_mask=padding_mask)
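        # BCEWithLogitsLoss with mean reduction averages over both batch and answer classes;
        # multiplying by the number of classes turns it into a per-example sum over classes.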
return {
"loss": self.criterion(input=logits.float(), target=labels.float()) * labels.shape[1],
}
def before_eval(self, metric_logger, data_loader, **kwargs):
self.predictions.clear()
self.metric_logger = metric_logger
self.label2ans = data_loader.dataset.label2ans
def eval_batch(self, model, image, language_tokens, padding_mask, labels=None, qid=None):
logits = model(
image=image, question=language_tokens,
padding_mask=padding_mask)
batch_size = language_tokens.shape[0]
if labels is not None:
scores = utils.VQAScore()(logits, labels) * 100.0
self.metric_logger.meters['score'].update(scores.item(), n=batch_size)
else:
_, preds = logits.max(-1)
for image_id, pred in zip(qid, preds):
self.predictions.append({
"question_id": image_id.item(),
"answer": self.label2ans[pred.item()],
})
def after_eval(self, **kwargs):
if len(self.predictions) == 0:
print('* Score {score.global_avg:.3f}'.format(score=self.metric_logger.score))
return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "score"
else:
return self.predictions, "prediction"
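# CaptioningHandler: trains captioning as masked language modeling over caption
# tokens (BertCaptioningLoss with label smoothing and worst-sample dropping) and
# reports masked-token accuracy. At eval time it decodes captions with a
# mask-predict beam search that feeds the image only at the first step and then
# reuses the incremental decoding state, collecting {"image_id", "caption"} records.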
class CaptioningHandler(TaskHandler):
def __init__(self, args) -> None:
super().__init__()
self.predictions = []
self.criterion = utils.BertCaptioningLoss(args.label_smoothing, args.drop_worst_ratio, args.drop_worst_after)
self.tokenizer = get_sentencepiece_model_for_beit3(args)
self.num_beams = args.num_beams
self.max_len = args.num_max_bpe_tokens
self.length_penalty = args.length_penalty
self.vocab_size = args.vocab_size
def train_batch(self, model, image, language_tokens, masked_tokens, language_masked_pos, padding_mask, image_id, global_step):
logits, _ = model(
image=image, text_ids=masked_tokens, padding_mask=padding_mask, language_masked_pos=language_masked_pos, image_id=image_id)
masked_labels = language_tokens[language_masked_pos.bool()]
score = torch.max(logits, -1)[1].data == masked_labels
acc = torch.sum(score.float()) / torch.sum(language_masked_pos)
return {
"loss": self.criterion(logits, masked_labels, global_step),
"acc": acc
}
def before_eval(self, metric_logger, data_loader, **kwargs):
self.predictions.clear()
self.metric_logger = metric_logger
def eval_batch(self, model, image, image_id=None):
cur_len = 2
num_keep_best = 1
TOPN_PER_BEAM = 3
batch_size = image.size(0)
mask_id = self.tokenizer.mask_token_id
cls_id = self.tokenizer.cls_token_id
pad_id = self.tokenizer.pad_token_id
sep_id = self.tokenizer.sep_token_id
eos_token_ids = [sep_id]
cls_ids = torch.full(
(batch_size, 1), cls_id, dtype=torch.long, device=image.device
)
mask_ids = torch.full(
(batch_size, 1), mask_id, dtype=torch.long, device=image.device
)
cur_input_ids = torch.cat([cls_ids, mask_ids], dim=1)
tmp_ids = torch.full(
(batch_size, self.max_len-1), mask_id, dtype=torch.long, device=image.device
)
decoding_results = torch.cat([cls_ids, tmp_ids], dim=1)
# Expand input to num beams
cur_input_ids = cur_input_ids.unsqueeze(1).expand(batch_size, self.num_beams, cur_len)
cur_input_ids = cur_input_ids.contiguous().view(batch_size * self.num_beams, cur_len) # (batch_size * num_beams, cur_len)
decoding_results = decoding_results.unsqueeze(1).expand(batch_size, self.num_beams, self.max_len)
decoding_results = decoding_results.contiguous().view(batch_size * self.num_beams, self.max_len) # (batch_size * num_beams, cur_len)
image = image.unsqueeze(1).expand(batch_size, self.num_beams, image.size(-3), image.size(-2), image.size(-1))
image = image.contiguous().view(batch_size * self.num_beams, image.size(-3), image.size(-2), image.size(-1))
generated_hyps = [
utils.BeamHypotheses(
num_keep_best, self.max_len, length_penalty=self.length_penalty, early_stopping=False
) for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, self.num_beams), dtype=torch.float, device=cur_input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# done sentences
done = [False for _ in range(batch_size)]
incremental_state = {}
while cur_len <= self.max_len:
next_token_idx = 1
padding_masks = torch.full(
cur_input_ids.shape, 0, dtype=torch.long, device=image.device
)
input_image = image
if cur_len != 2:
input_image = None
outputs, incremental_state_next = model(
image=input_image, text_ids=cur_input_ids, language_masked_pos=None,
padding_mask=padding_masks, text_len=cur_len, incremental_state=incremental_state)
incremental_state = incremental_state_next
# assert outputs.shape[1] == token_len
scores = outputs[:, next_token_idx, :] # (batch_size * num_beams, vocab_size)
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
assert scores.size() == (batch_size * self.num_beams, self.vocab_size)
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
_scores = _scores.view(batch_size, self.num_beams * self.vocab_size) # (batch_size, num_beams * vocab_size)
next_scores, next_words = torch.topk(_scores, TOPN_PER_BEAM * self.num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (batch_size, TOPN_PER_BEAM * self.num_beams)
# next batch beam content
# list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
next_batch_beam = []
# for each sentence
for batch_ex in range(batch_size):
# if we are done with this sentence
done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
if done[batch_ex]:
next_batch_beam.extend([(0, pad_id, 0)] * self.num_beams) # pad the batch
continue
# next sentence beam content
next_sent_beam = []
for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):
# get beam and word IDs
beam_id = idx // self.vocab_size
word_id = idx % self.vocab_size
# end of sentence, or next word
# if word_id.item() in eos_token_ids or cur_len + 1 == max_len:
if (word_id.item() in eos_token_ids and cur_len + 1 <= self.max_len) or (cur_len + 1 == self.max_len):
generated_hyps[batch_ex].add(
decoding_results[batch_ex * self.num_beams + beam_id, :cur_len].clone(), score.item()
)
else:
next_sent_beam.append((score, word_id, batch_ex * self.num_beams + beam_id))
# the beam for next step is full
if len(next_sent_beam) == self.num_beams:
break
# update next beam content
if cur_len + 1 == self.max_len:
assert len(next_sent_beam) == 0
else:
assert len(next_sent_beam) == self.num_beams
if len(next_sent_beam) == 0:
next_sent_beam = [(0, pad_id, 0)] * self.num_beams # pad the batch
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == self.num_beams * (batch_ex + 1)
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * self.num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = cur_input_ids.new([x[1] for x in next_batch_beam])
beam_idx = cur_input_ids.new([x[2] for x in next_batch_beam])
# re-order batch
cur_input_ids = cur_input_ids[beam_idx, :]
decoding_results = decoding_results[beam_idx, :]
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, beam_idx)
incremental_state[module][key] = result[:,:,:-1,:]
next_ids = torch.full(
(batch_size * self.num_beams, 1), mask_id, dtype=torch.long, device=image.device
)
cur_input_ids = torch.cat([beam_words.unsqueeze(1), next_ids], dim=1)
decoding_results[:, cur_len-1] = beam_words
# update current length
cur_len = cur_len + 1
# stop when we are done with each sentence
if all(done):
break
# select the best hypotheses
tgt_len = torch.ones(batch_size, num_keep_best, dtype=torch.long)
logprobs = torch.zeros(batch_size, num_keep_best,
dtype=torch.float).fill_(-1e5).to(cur_input_ids.device)
all_best = []
for i, hypotheses in enumerate(generated_hyps):
best = []
hyp_scores = torch.tensor([x[0] for x in hypotheses.hyp])
_, best_indices = torch.topk(hyp_scores,
min(num_keep_best, len(hyp_scores)), largest=True)
for best_idx, hyp_idx in enumerate(best_indices):
conf, best_hyp = hypotheses.hyp[hyp_idx]
best.append(best_hyp)
logprobs[i, best_idx] = conf
tgt_len[i, best_idx] = len(best_hyp) + 1 # +1 for the <EOS> symbol
all_best.append(best)
# generate target batch, pad to the same length
decoded = cur_input_ids.new(batch_size, num_keep_best, self.max_len).fill_(pad_id)
for batch_idx, best in enumerate(all_best):
for best_idx, hypo in enumerate(best):
decoded[batch_idx, best_idx, : tgt_len[batch_idx, best_idx] - 1] = hypo
decoded[batch_idx, best_idx, tgt_len[batch_idx, best_idx] - 1] = eos_token_ids[0]
captions = self.tokenizer.batch_decode(decoded.squeeze(1), skip_special_tokens=True)
for qid, pred in zip(image_id, captions):
self.predictions.append({
"image_id": qid.item(),
"caption": pred,
})
def after_eval(self, **kwargs):
return self.predictions, "prediction"
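# Map the --task argument to its TaskHandler.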
def get_handler(args):
if args.task == "nlvr2":
return NLVR2Handler()
elif args.task == "vqav2":
return VQAHandler()
elif args.task in ("flickr30k", "coco_retrieval"):
return RetrievalHandler()
elif args.task in ("coco_captioning", "nocaps"):
return CaptioningHandler(args)
elif args.task in ("imagenet"):
return ImageNetHandler(args)
else:
raise NotImplementedError("Sorry, %s is not support." % args.task)
def train_one_epoch(
model: torch.nn.Module, data_loader: Iterable,
optimizer: torch.optim.Optimizer, device: torch.device,
handler: TaskHandler, epoch: int, start_steps: int,
lr_schedule_values: list, loss_scaler, max_norm: float = 0,
update_freq: int = 1, model_ema: Optional[ModelEma] = None,
log_writer: Optional[utils.TensorboardLogger] = None,
task = None, mixup_fn=None,
):
model.train(True)
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 10
if loss_scaler is None:
model.zero_grad()
model.micro_steps = 0
else:
optimizer.zero_grad()
for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
step = data_iter_step // update_freq
global_step = start_steps + step # global training iteration
# Update LR & WD for the first acc
if lr_schedule_values is not None and data_iter_step % update_freq == 0:
for i, param_group in enumerate(optimizer.param_groups):
if lr_schedule_values is not None:
param_group["lr"] = lr_schedule_values[global_step] * param_group["lr_scale"]
# put input data into cuda
for tensor_key in data.keys():
data[tensor_key] = data[tensor_key].to(device, non_blocking=True)
# print("input %s = %s" % (tensor_key, data[tensor_key]))
if loss_scaler is None and tensor_key.startswith("image"):
data[tensor_key] = data[tensor_key].half()
# mixup for imagenet finetuning
if mixup_fn is not None:
data["image"], data["label"] = mixup_fn(data["image"], data["label"])
if task in ["coco_captioning", "nocaps"]:
data["global_step"] = global_step
if loss_scaler is None:
results = handler.train_batch(model, **data)
else:
with torch.cuda.amp.autocast():
results = handler.train_batch(model, **data)
loss = results.pop("loss")
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
if loss_scaler is None:
loss /= update_freq
model.backward(loss)
model.step()
if (data_iter_step + 1) % update_freq == 0:
# model.zero_grad()
                # Deepspeed will call step() & model.zero_grad() automatically
if model_ema is not None:
model_ema.update(model)
grad_norm = None
loss_scale_value = utils.get_loss_scale_for_deepspeed(model)
else:
# this attribute is added by timm on one optimizer (adahessian)
is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
loss /= update_freq
grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
parameters=model.parameters(), create_graph=is_second_order,
update_grad=(data_iter_step + 1) % update_freq == 0)
if (data_iter_step + 1) % update_freq == 0:
optimizer.zero_grad()
if model_ema is not None:
model_ema.update(model)
loss_scale_value = loss_scaler.state_dict()["scale"]
torch.cuda.synchronize()
metric_logger.update(loss=loss_value)
metric_logger.update(loss_scale=loss_scale_value)
min_lr = 10.
max_lr = 0.
for group in optimizer.param_groups:
min_lr = min(min_lr, group["lr"])
max_lr = max(max_lr, group["lr"])
metric_logger.update(lr=max_lr)
metric_logger.update(min_lr=min_lr)
weight_decay_value = None
for group in optimizer.param_groups:
if group["weight_decay"] > 0:
weight_decay_value = group["weight_decay"]
metric_logger.update(weight_decay=weight_decay_value)
metric_logger.update(grad_norm=grad_norm)
if log_writer is not None:
kwargs = {
"loss": loss_value,
}
for key in results:
kwargs[key] = results[key]
log_writer.update(head="train", **kwargs)
kwargs = {
"loss_scale": loss_scale_value,
"lr": max_lr,
"min_lr": min_lr,
"weight_decay": weight_decay_value,
"grad_norm": grad_norm,
}
log_writer.update(head="opt", **kwargs)
log_writer.set_step()
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
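# Evaluation loop: per-batch work is delegated to the task handler and the final
# metrics (or prediction list) are assembled in handler.after_eval().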
@torch.no_grad()
def evaluate(data_loader, model, device, handler):
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
# switch to evaluation mode
model.eval()
handler.before_eval(metric_logger=metric_logger, data_loader=data_loader)
for data in metric_logger.log_every(data_loader, 10, header):
for tensor_key in data.keys():
data[tensor_key] = data[tensor_key].to(device, non_blocking=True)
with torch.cuda.amp.autocast():
handler.eval_batch(model=model, **data)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
return handler.after_eval()
# Fine-tuning BEiT-3 on Image Captioning
## COCO Captioning Setup
1. [Setup environment](../README.md#setup).
2. Download [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip) and [karpathy split](https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip), then organize the dataset into the following structure:
```
/path/to/your_data/
train2014/
COCO_train2014_000000000009.jpg
...
val2014/
COCO_val2014_000000000042.jpg
...
dataset_coco.json
```
We then generate the index json files using the following command. [beit3.spm](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) is the sentencepiece model used for tokenizing texts.
```
from datasets import CaptioningDataset
from transformers import XLMRobertaTokenizer
tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
CaptioningDataset.make_coco_captioning_dataset_index(
data_path="/path/to/your_data",
tokenizer=tokenizer,
)
```
## NoCaps Setup
1. [Setup environment](../README.md#setup).
2. Download the [NoCaps val set](https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json) and [NoCaps test set](https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json), download the images using the URLs in the val and test json files, then organize the dataset into the following structure:
```
/path/to/your_data/
val/
09c863d76bcf6b00.jpg
...
test/
19dc6913830a0a21.jpg
...
nocaps_val_4500_captions.json
nocaps_test_image_info.json
```
We then generate the index json files using the following command. [beit3.spm](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) is the sentencepiece model used for tokenizing texts.
```
from datasets import CaptioningDataset
from transformers import XLMRobertaTokenizer
tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
CaptioningDataset.make_nocaps_captioning_dataset_index(
data_path="/path/to/your_data",
)
```
We use the COCO captioning training set as the training data for NoCaps.
## Example: Fine-tuning BEiT-3 on Captioning
The BEiT-3 **base** model can be fine-tuned on captioning tasks using 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_base_patch16_480 \
--input_size 480 \
--task coco_captioning \
--batch_size 32 \
--layer_decay 1.0 \
--lr 4e-5 \
--randaug \
--epochs 10 \
--warmup_epochs 1 \
--drop_path 0.1 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.05 \
--seed 42 \
--save_ckpt_freq 5 \
--num_max_bpe_tokens 32 \
--captioning_mask_prob 0.7 \
--drop_worst_after 12000 \
--dist_eval \
--checkpoint_activations \
--enable_deepspeed
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*32 = 256`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models).
- `--task`: **coco_captioning** for COCO captioning and **nocaps** for NoCaps dataset.
- `--lr`: 4e-5 for COCO captioning and 1e-5 for NoCaps.
- `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
- `--checkpoint_activations`: use gradient checkpointing to save GPU memory.
The BEiT-3 **large** model can be fine-tuned on captioning tasks using 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_large_patch16_480 \
--input_size 480 \
--task coco_captioning \
--batch_size 32 \
--layer_decay 1.0 \
--lr 8e-6 \
--randaug \
--epochs 10 \
--warmup_epochs 1 \
--drop_path 0.1 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.05 \
--seed 42 \
--save_ckpt_freq 5 \
--num_max_bpe_tokens 32 \
--captioning_mask_prob 0.7 \
--drop_worst_after 12000 \
--dist_eval \
--checkpoint_activations \
--enable_deepspeed
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*32 = 256`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models).
- `--task`: **coco_captioning** for COCO captioning and **nocaps** for NoCaps dataset.
- `--lr`: 8e-6 for COCO captioning and NoCaps.
- `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
- `--checkpoint_activations`: use gradient checkpointing to save GPU memory.
## Example: Evaluate BEiT-3 Fine-tuned model on Captioning
- Get the prediction file of the fine-tuned BEiT3-base model on captioning with 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_base_patch16_480 \
--input_size 480 \
--task coco_captioning \
--batch_size 16 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_patch16_480_coco_captioning.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_prediction \
--eval \
--dist_eval
```
- `--task`: **coco_captioning** for COCO captioning and **nocaps** for NoCaps dataset.
- `--finetune`: **beit3_base_patch16_480_coco_captioning.pth** for COCO captioning and **beit3_base_patch16_480_nocaps.pth** for NoCaps dataset.
- Get the prediction file of the fine-tuned BEiT3-large model on captioning with 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_large_patch16_480 \
--input_size 480 \
--task coco_captioning \
--batch_size 16 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_patch16_480_coco_captioning.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_prediction \
--eval \
--dist_eval
```
- `--task`: **coco_captioning** for COCO captioning and **nocaps** for NoCaps dataset.
- `--finetune`: **beit3_large_patch16_480_coco_captioning.pth** for COCO captioning and **beit3_large_patch16_480_nocaps.pth** for NoCaps dataset.
Please then submit the prediction file in the `output_dir` to the [evaluation server](https://eval.ai/web/challenges/challenge-page/355/overview) to obtain the NoCaps val and test results.
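Before uploading, you can sanity-check that the file matches the format produced by the captioning handler, i.e. a JSON list of `{"image_id", "caption"}` records. A minimal sketch (the filename below is hypothetical; use whichever prediction JSON appears in your `output_dir`):
```
import json

# Hypothetical path: replace with the prediction json written to your --output_dir.
with open("/path/to/save/your_prediction/coco_captioning_predictions.json") as f:
    predictions = json.load(f)

# Each record should carry an image id and its generated caption.
assert all({"image_id", "caption"} <= set(p.keys()) for p in predictions)
print(len(predictions), "captions, e.g.:", predictions[0])
```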
# Fine-tuning BEiT-3 on ImageNet-1k (Image Classification)
## Setup
1. [Setup environment](../README.md#setup).
2. Download and extract ImageNet-1k from http://image-net.org/.
The directory structure is the standard layout of torchvision's [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder). The training and validation data are expected to be in the `train/` folder and `val/` folder, respectively:
```
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```
We then generate the index json files using the following command. [beit3.spm](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) is the sentencepiece model used for tokenizing texts.
```
from datasets import ImageNetDataset
ImageNetDataset.make_dataset_index(
train_data_path = "/path/to/your_data/train",
val_data_path = "/path/to/your_data/val",
index_path = "/path/to/your_data"
)
```
## Example: Fine-tuning BEiT-3 on ImageNet-1k (Image Classification)
The BEiT-3 **base** model can be finetuned on ImageNet-1k using 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_base_patch16_224 \
--task imagenet \
--batch_size 128 \
--layer_decay 0.65 \
--lr 7e-4 \
--update_freq 1 \
--epochs 50 \
--warmup_epochs 5 \
--drop_path 0.15 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.05 \
--seed 42 \
--save_ckpt_freq 5 \
--dist_eval \
--mixup 0.8 \
--cutmix 1.0 \
--enable_deepspeed
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*128*1 = 1024`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
- `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
The BEiT-3 **large** model can be finetuned on ImageNet-1k using a DGX box (8 V100-32GB):
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_large_patch16_224 \
--task imagenet \
--batch_size 128 \
--layer_decay 0.8 \
--lr 2e-4 \
--update_freq 1 \
--epochs 50 \
--warmup_epochs 5 \
--drop_path 0.25 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.05 \
--seed 42 \
--save_ckpt_freq 5 \
--dist_eval \
--mixup 0.8 \
--cutmix 1.0 \
--enable_deepspeed \
--checkpoint_activations
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*128 = 1024`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
- `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
- `--checkpoint_activations`: use gradient checkpointing to save GPU memory.
## Example: Evaluate BEiT-3 Finetuned model on ImageNet-1k (Image Classification)
- Evaluate our fine-tuned BEiT3-base model on ImageNet val with a single GPU:
```bash
python -m torch.distributed.launch --nproc_per_node=1 run_beit3_finetuning.py \
--model beit3_base_patch16_224 \
--task imagenet \
--batch_size 128 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_patch16_224_in1k.pth \
--data_path /path/to/your_data \
--eval \
--dist_eval
```
Expected results:
```
* Acc@1 85.400 Acc@5 97.630
```
- Evaluate our fine-tuned BEiT3-large model on ImageNet val with a single GPU:
```bash
python -m torch.distributed.launch --nproc_per_node=1 run_beit3_finetuning.py \
--model beit3_large_patch16_224 \
--task imagenet \
--batch_size 128 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_patch16_224_in1k.pth \
--data_path /path/to/your_data \
--eval \
--dist_eval
```
Expected results:
```
* Acc@1 87.580 Acc@5 98.326
```
# Fine-tuning BEiT-3 on NLVR2 (Visual Reasoning)
## Setup
1. [Setup environment](../README.md#setup).
2. Clone the [repository](https://github.com/lil-lab/nlvr) and sign the [request form](https://goo.gl/forms/yS29stWnFWzrDBFH3) to download the images, then organize the dataset into the following structure:
```
/path/to/your_data/
images/train/
0/train-11670-0-img0.png
...
dev/
dev-269-0-img0.png
...
test1/
test1-261-0-img0.png
...
nlvr/ (nlvr repo)
nlvr/
nlvr2/
```
We then generate the index json files using the following command. [beit3.spm](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) is the sentencepiece model used for tokenizing texts.
```
from datasets import NLVR2Dataset
from transformers import XLMRobertaTokenizer
tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
NLVR2Dataset.make_dataset_index(
data_path="/path/to/your_data",
tokenizer=tokenizer,
nlvr_repo_path="/path/to/your_data/nlvr"
)
```
## Example: Fine-tuning BEiT-3 on NLVR2 (Visual Reasoning)
The BEiT-3 **base** model can be finetuned on NLVR2 using 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_base_patch16_224 \
--task nlvr2 \
--batch_size 32 \
--layer_decay 0.65 \
--lr 7e-4 \
--epochs 20 \
--warmup_epochs 5 \
--drop_path 0.2 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.2 \
--seed 42 \
--save_ckpt_freq 5 \
--enable_deepspeed
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*32 = 256`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models).
- `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
- `--lr`: 7e-4 for `BEiT3-base`, 5e-4 for `BEiT3-base-indomain`.
The BEiT-3 **large** model can be finetuned on NLVR2 using 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_large_patch16_224 \
--task nlvr2 \
--batch_size 32 \
--layer_decay 0.85 \
--lr 3e-4 \
--epochs 20 \
--warmup_epochs 5 \
--drop_path 0.2 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.2 \
--seed 42 \
--save_ckpt_freq 5 \
--enable_deepspeed \
--checkpoint_activations
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*32 = 256`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models).
- `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
- `--lr`: 3e-4 for `BEiT3-large`, 1e-4 for `BEiT3-large-indomain`.
- `--checkpoint_activations`: use gradient checkpointing to save GPU memory.
## Example: Evaluate BEiT-3 Finetuned model on NLVR2 (Visual Reasoning)
- Get the result of our fine-tuned BEiT3-base model on NLVR2 test with 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_base_patch16_224 \
--task nlvr2 \
--batch_size 32 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_patch16_224_nlvr2.pth \
--data_path /path/to/your_data \
--eval \
--dist_eval
```
Expected results:
```
* Acc 84.386
```
- Get the result of our fine-tuned BEiT3-large model on NLVR2 test with 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_large_patch16_224 \
--task nlvr2 \
--batch_size 32 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_patch16_224_nlvr2.pth \
--data_path /path/to/your_data \
--eval \
--dist_eval
```
Expected results:
```
* Acc 89.437
```
# Fine-tuning BEiT-3 on Image-text Retrieval
## COCO Retrieval Setup
1. [Setup environment](../README.md#setup).
2. Download [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip) and [karpathy split](https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip), then organize the dataset into the following structure:
```
/path/to/your_data/
train2014/
COCO_train2014_000000000009.jpg
...
val2014/
COCO_val2014_000000000042.jpg
...
dataset_coco.json
```
We then generate the index json files using the following command. [beit3.spm](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) is the sentencepiece model used for tokenizing texts.
```
from datasets import RetrievalDataset
from transformers import XLMRobertaTokenizer
tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
RetrievalDataset.make_coco_dataset_index(
data_path="/path/to/your_data",
tokenizer=tokenizer,
)
```
## Flickr30k Retrieval Setup
1. [Setup environment](../README.md#setup).
2. Sign the [Flickr images request form](https://forms.illinois.edu/sec/229675) and download the [karpathy split](https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip), then organize the dataset into the following structure:
```
/path/to/your_data/
flickr30k-images/
2923475135.jpg
...
dataset_flickr30k.json
```
We then generate the index json files using the following command. [beit3.spm](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) is the sentencepiece model used for tokenizing texts.
```
from datasets import RetrievalDataset
from transformers import XLMRobertaTokenizer
tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
RetrievalDataset.make_flickr30k_dataset_index(
data_path="/path/to/your_data",
tokenizer=tokenizer,
karpathy_path="/path/to/your_data",
)
```
## Example: Fine-tuning BEiT-3 on Retrieval
The BEiT-3 **base** model can be finetuned on retrieval tasks using 16 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=16 run_beit3_finetuning.py \
--model beit3_base_patch16_384 \
--input_size 384 \
--task coco_retrieval \
--batch_size 192 \
--layer_decay 0.65 \
--lr 2e-4 \
--epochs 15 \
--warmup_epochs 3 \
--drop_path 0.2 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_itc_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.05 \
--seed 42 \
--save_ckpt_freq 5 \
--enable_deepspeed \
--checkpoint_activations
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `192*16 = 3072`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
- `--task`: **coco_retrieval** for COCO retrieval, **flickr30k** for Flickr30k retrieval
- `--lr`: 2e-4 for COCO retrieval, 1e-4 for Flickr30k retrieval
- `--epochs`: 15 for COCO retrieval, 20 for Flickr30k retrieval
- `--warmup_epochs`: 3 for COCO retrieval, 5 for Flickr30k retrieval
- `--checkpoint_activations`: use gradient checkpointing to save GPU memory.
The BEiT-3 **large** model can be finetuned on retrieval tasks using 2x16 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=16 --nnodes=2 --node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR --master_port=$MASTER_PORT run_beit3_finetuning.py \
--model beit3_large_patch16_384 \
--input_size 384 \
--task coco_retrieval \
--batch_size 96 \
--layer_decay 0.85 \
--lr 5e-5 \
--epochs 15 \
--warmup_epochs 3 \
--drop_path 0.2 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_itc_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.05 \
--seed 42 \
--save_ckpt_freq 5 \
--enable_deepspeed \
--checkpoint_activations
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `96*32 = 3072`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
- `--task`: **coco_retrieval** for COCO retrieval, **flickr30k** for Flickr30k retrieval
- `--epochs`: 15 for COCO retrieval, 20 for Flickr30k retrieval
- `--warmup_epochs`: 3 for COCO retrieval, 5 for Flickr30k retrieval
- `--checkpoint_activations`: use gradient checkpointing to save GPU memory.
## Example: Evaluate BEiT-3 Fine-tuned model on COCO Retrieval and Flickr30k Retrieval
- Get the results of our fine-tuned BEiT3-base model on retrieval tasks using a single GPU:
```bash
python -m torch.distributed.launch --nproc_per_node=1 run_beit3_finetuning.py \
--model beit3_base_patch16_384 \
--input_size 384 \
--task coco_retrieval \
--batch_size 16 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_patch16_384_coco_retrieval.pth \
--data_path /path/to/your_data \
--eval \
--dist_eval
```
- `--task`: **coco_retrieval** for COCO retrieval, **flickr30k** for Flickr30k retrieval
- `--finetune`: **beit3_base_patch16_384_coco_retrieval.pth** for COCO retrieval, **beit3_base_patch16_384_f30k_retrieval.pth** for Flickr30k retrieval
- Get the results of our fine-tuned BEiT3-large model on retrieval tasks using a single GPU:
```bash
python -m torch.distributed.launch --nproc_per_node=1 run_beit3_finetuning.py \
--model beit3_large_patch16_384 \
--input_size 384 \
--task coco_retrieval \
--batch_size 16 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_patch16_384_coco_retrieval.pth \
--data_path /path/to/your_data \
--eval \
--dist_eval
```
- `--task`: **coco_retrieval** for COCO retrieval, **flickr30k** for Flickr30k retrieval
- `--finetune`: **beit3_large_patch16_384_coco_retrieval.pth** for COCO retrieval, **beit3_large_patch16_384_f30k_retrieval.pth** for Flickr30k retrieval
# Fine-tuning BEiT-3 on VQAv2 (Visual Question Answering)
## Setup
1. [Setup environment](../README.md#setup).
2. Download COCO [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip), [2015 test images](http://images.cocodataset.org/zips/test2015.zip), annotations ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip)), and questions ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip), [test](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip)), then organize the dataset into the following structure:
```
/path/to/your_data/
train2014/
COCO_train2014_000000000009.jpg
...
val2014/
COCO_val2014_000000000042.jpg
...
test2015/
COCO_test2015_000000000001.jpg
...
vqa/
v2_OpenEnded_mscoco_train2014_questions.json
v2_OpenEnded_mscoco_val2014_questions.json
v2_OpenEnded_mscoco_test2015_questions.json
v2_OpenEnded_mscoco_test-dev2015_questions.json
v2_mscoco_train2014_annotations.json
v2_mscoco_val2014_annotations.json
```
We then generate the index json files using the following command. [beit3.spm](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) is the sentencepiece model used for tokenizing texts.
```
from datasets import VQAv2Dataset
from transformers import XLMRobertaTokenizer
tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
VQAv2Dataset.make_dataset_index(
data_path="/path/to/your_data",
tokenizer=tokenizer,
annotation_data_path="/path/to/your_data/vqa",
)
```
## Example: Fine-tuning BEiT-3 on VQAv2 (Visual Question Answering)
The BEiT-3 **base** model can be finetuned on VQAv2 using 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_base_patch16_480 \
--input_size 480 \
--task vqav2 \
--batch_size 16 \
--layer_decay 1.0 \
--lr 3e-5 \
--update_freq 1 \
--randaug \
--epochs 10 \
--warmup_epochs 1 \
--drop_path 0.1 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.01 \
--seed 42 \
--save_ckpt_freq 5 \
--task_head_lr_weight 20 \
--opt_betas 0.9 0.98 \
--enable_deepspeed
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*16 = 128`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
- `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
The BEiT-3 **large** model can be finetuned on VQAv2 using 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_large_patch16_480 \
--input_size 480 \
--task vqav2 \
--batch_size 16 \
--layer_decay 1.0 \
--lr 2e-5 \
--update_freq 1 \
--randaug \
--epochs 10 \
--warmup_epochs 1 \
--drop_path 0.15 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_patch16_224.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_model \
--log_dir /path/to/save/your_model/log \
--weight_decay 0.01 \
--seed 42 \
--save_ckpt_freq 5 \
--task_head_lr_weight 20 \
--opt_betas 0.9 0.98 \
--enable_deepspeed \
--checkpoint_activations
```
- `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*16 = 128`.
- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
- `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
- `--checkpoint_activations`: use gradient checkpointing to save GPU memory.
## Example: Evaluate BEiT-3 Finetuned model on VQAv2 (Visual Question Answering)
- Get the prediction file of the fine-tuned BEiT3-base model on VQAv2 test with 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_base_patch16_480 \
--input_size 480 \
--task vqav2 \
--batch_size 16 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_base_patch16_480_vqa.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_prediction \
--eval \
--dist_eval
```
- Get the prediction file of the fine-tuned BEiT3-large model on VQAv2 test with 8 V100-32GB:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
--model beit3_large_patch16_480 \
--input_size 480 \
--task vqav2 \
--batch_size 16 \
--sentencepiece_model /your_beit3_model_path/beit3.spm \
--finetune /your_beit3_model_path/beit3_large_patch16_480_vqa.pth \
--data_path /path/to/your_data \
--output_dir /path/to/save/your_prediction \
--eval \
--dist_eval
```
Please then submit the prediction file in the `output_dir` to the [evaluation server](https://eval.ai/web/challenges/challenge-page/830/overview) to obtain the VQAv2 test-dev and test-std results.
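The prediction file is a JSON list of `{"question_id", "answer"}` records (see `VQAHandler` above). If you want to sanity-check it before uploading, a minimal sketch (the filename below is hypothetical; use whichever prediction JSON appears in your `output_dir`):
```
import json

# Hypothetical path: replace with the prediction json written to your --output_dir.
with open("/path/to/save/your_prediction/vqav2_predictions.json") as f:
    predictions = json.load(f)

# Each record should carry a question id and its predicted answer string.
assert all({"question_id", "answer"} <= set(p.keys()) for p in predictions)
print(len(predictions), "answers, e.g.:", predictions[0])
```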
import re
contractions = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
}
manual_map = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
articles = ["a", "an", "the"]
period_strip = re.compile(r"(?!<=\d)(\.)(?!\d)")
comma_strip = re.compile(r"(\d)(\,)(\d)")
punct = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
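# VQA-style answer normalization: strip punctuation, remove stray periods, map
# number words ("zero".."ten") to digits, drop articles (a, an, the), and expand
# contractions that are missing apostrophes.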
def normalize_word(token):
_token = token
for p in punct:
if (p + " " in token or " " + p in token) or (
re.search(comma_strip, token) != None
):
_token = _token.replace(p, "")
else:
_token = _token.replace(p, " ")
token = period_strip.sub("", _token, re.UNICODE)
_token = []
temp = token.lower().split()
for word in temp:
word = manual_map.setdefault(word, word)
if word not in articles:
_token.append(word)
for i, word in enumerate(_token):
if word in contractions:
_token[i] = contractions[word]
token = " ".join(_token)
token = token.replace(",", "")
return token
# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.registry import register_model
import numpy as np
import utils
from modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config
class TwoLayerMLP(nn.Module):
def __init__(
self,
in_features,
hidden_features,
out_features,
norm_layer,
norm_input=True,
):
super().__init__()
self.norm1 = norm_layer(in_features) if norm_input else nn.Identity()
self.dense1 = nn.Linear(in_features, hidden_features)
self.norm2 = norm_layer(hidden_features)
self.act = nn.GELU()
self.dense2 = nn.Linear(hidden_features, out_features)
def forward(self, x):
x = self.norm1(x)
x = self.dense1(x)
x = self.norm2(x)
x = self.act(x)
return self.dense2(x)
class Pooler(nn.Module):
def __init__(self, input_features, output_features, norm_layer):
super().__init__()
self.norm = norm_layer(input_features)
self.dense = nn.Linear(input_features, output_features)
self.activation = nn.Tanh()
def forward(self, x):
cls_rep = x[:, 0, :]
cls_rep = self.norm(cls_rep)
pooled_output = self.dense(cls_rep)
pooled_output = self.activation(pooled_output)
return pooled_output
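# NLVR2 head: the two images are stacked along the batch dimension and the text
# is duplicated, so one forward pass encodes both (image, text) pairs; the vision
# and language CLS features of both pairs are concatenated into a 4 x embed_dim
# vector that a two-layer MLP maps to the two NLVR2 classes.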
class BEiT3ForVisualReasoning(BEiT3Wrapper):
def __init__(
self,
args,
num_classes,
norm_layer=nn.LayerNorm,
**kwargs
):
super(BEiT3ForVisualReasoning, self).__init__(args=args)
embed_dim = args.encoder_embed_dim
self.head = TwoLayerMLP(
in_features=embed_dim * 4,
hidden_features=embed_dim * 2,
out_features=num_classes,
norm_layer=norm_layer,
)
init_scale = 0.001
self.head.apply(self._init_weights)
if isinstance(self.head.dense1, nn.Linear):
self.head.dense1.weight.data.mul_(init_scale)
self.head.dense1.bias.data.mul_(init_scale)
if isinstance(self.head.dense2, nn.Linear):
self.head.dense2.weight.data.mul_(init_scale)
self.head.dense2.bias.data.mul_(init_scale)
def forward(self, image_a, image_b, text_description, padding_mask, **kwargs):
bsz, _ = text_description.size()
vision_input = torch.cat((image_a, image_b), dim=0)
language_input = torch.cat((text_description, text_description), dim=0)
padding_mask = torch.cat((padding_mask, padding_mask), dim=0)
outputs = self.beit3(
textual_tokens=language_input,
visual_tokens=vision_input,
text_padding_position=padding_mask,
)
x = outputs["encoder_out"]
multiway_split_position = outputs["multiway_split_position"]
vision_cls = x[:, 0, :]
language_cls = x[:, multiway_split_position, :]
cls_rep = torch.cat((vision_cls, language_cls), dim=-1)
a, b = torch.split(cls_rep, split_size_or_sections=[bsz, bsz], dim=0)
cls_rep = torch.cat((a, b), dim=-1)
return self.head(cls_rep)
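# Image classification head: mean-pools the patch tokens (excluding CLS), applies
# LayerNorm, and feeds a linear classifier whose weights are down-scaled at init.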
class BEiT3ForImageClassification(BEiT3Wrapper):
def __init__(
self,
args,
num_classes,
norm_layer=nn.LayerNorm,
**kwargs
):
super(BEiT3ForImageClassification, self).__init__(args=args)
embed_dim = args.encoder_embed_dim
self.fc_norm = norm_layer(embed_dim)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.fc_norm.apply(self._init_weights)
self.head.apply(self._init_weights)
init_scale = 0.001
if isinstance(self.head, nn.Linear):
self.head.weight.data.mul_(init_scale)
self.head.bias.data.mul_(init_scale)
def forward(self, image, **kwargs):
x = self.beit3(textual_tokens=None, visual_tokens=image)["encoder_out"]
t = x[:, 1:, :]
cls_x = self.fc_norm(t.mean(1))
return self.head(cls_x)
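# Captioning model: an MLM head over the text vocabulary on top of BEiT-3. The
# attention mask lets caption tokens attend causally to earlier caption tokens and
# fully to the image tokens, while image tokens attend only to other image tokens;
# an incremental_state dict enables step-by-step decoding at inference, where the
# image is fed only at the first step.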
class BEiT3ForCaptioning(BEiT3Wrapper):
def __init__(
self,
args,
**kwargs
):
super(BEiT3ForCaptioning, self).__init__(args=args)
embed_dim = args.encoder_embed_dim
self.mlm_head = nn.Linear(embed_dim, args.vocab_size)
self.mlm_head.apply(self._init_weights)
def forward(self, image, text_ids, padding_mask, language_masked_pos, text_len=None, incremental_state=None, **kwargs):
text_len = text_len if text_len is not None else text_ids.size(1)
image_len = self.beit3.vision_embed.num_position_embeddings()
max_len = text_len + image_len
uni_mask = torch.zeros((max_len, max_len), dtype=torch.long, device=text_ids.device)
i_start, i_end = 0, image_len
t_start, t_end = image_len, max_len
# triangle mask for caption to caption
uni_mask[t_start:t_end, t_start:t_end] = torch.tril(torch.ones(text_len, text_len, dtype=torch.long, device=text_ids.device))
# full attention for caption to image
uni_mask[t_start:t_end, i_start:i_end] = 1
# full attention for image to image
uni_mask[i_start:i_end, i_start:i_end] = 1
uni_mask = 1-uni_mask
if incremental_state is not None:
for idx in range(self.get_num_layers()):
if idx not in incremental_state:
incremental_state[idx] = {}
# for incremental decoding
positions = None
if image is None:
uni_mask = uni_mask[-2:]
padding_mask = None
# start position (2 (fairseq starts at 2) + cur_position) is equal to text_len
positions = torch.arange(text_len, text_ids.size(1) + text_len, device=text_ids.device).long().unsqueeze(0)
outputs = self.beit3(
textual_tokens=text_ids,
visual_tokens=image,
text_padding_position=padding_mask,
attn_mask=uni_mask,
incremental_state=incremental_state,
positions=positions,
)
if image is not None:
text_feats = outputs["encoder_out"][:, image_len:]
else:
text_feats = outputs["encoder_out"]
if language_masked_pos is not None:
text_feats = text_feats[language_masked_pos.bool()]
return self.mlm_head(text_feats), incremental_state
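# VQA head: a tanh Pooler over the fused CLS representation followed by a
# two-layer classifier over the 3,129 most frequent VQAv2 answers.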
class BEiT3ForVisualQuestionAnswering(BEiT3Wrapper):
def __init__(
self,
args,
num_classes,
norm_layer=nn.LayerNorm,
**kwargs
):
super(BEiT3ForVisualQuestionAnswering, self).__init__(args=args)
embed_dim = args.encoder_embed_dim
self.pooler = Pooler(
input_features=embed_dim,
output_features=embed_dim,
norm_layer=norm_layer,
)
self.pooler.apply(self._init_weights)
self.head = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 2),
norm_layer(embed_dim * 2),
nn.GELU(),
nn.Linear(embed_dim * 2, num_classes),
)
self.head.apply(self._init_weights)
def forward(self, image, question, padding_mask, **kwargs):
outputs = self.beit3(
textual_tokens=question,
visual_tokens=image,
text_padding_position=padding_mask,
)
x = outputs["encoder_out"]
cls_rep = self.pooler(x)
return self.head(cls_rep)
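# Dual-encoder retrieval model: images and texts are encoded in separate forward
# passes, projected by linear heads, and L2-normalized; training minimizes a
# CLIP-style contrastive loss with a learnable temperature (logit_scale), while
# only_infer=True simply returns the normalized features.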
class BEiT3ForRetrieval(BEiT3Wrapper):
def __init__(
self,
args,
**kwargs
):
super(BEiT3ForRetrieval, self).__init__(args=args)
embed_dim = args.encoder_embed_dim
self.language_head = nn.Linear(embed_dim, embed_dim, bias=False)
self.vision_head = nn.Linear(embed_dim, embed_dim, bias=False)
self.language_head.apply(self._init_weights)
self.vision_head.apply(self._init_weights)
self.criterion = utils.ClipLoss(
rank=utils.get_rank(),
world_size=utils.get_world_size(),
)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def forward(self, image=None, text_description=None, padding_mask=None, only_infer=False, **kwargs):
if image is not None:
outputs = self.beit3(
textual_tokens=None,
visual_tokens=image,
text_padding_position=None,
)
x = outputs["encoder_out"]
vision_cls = self.vision_head(x[:, 0, :])
vision_cls = F.normalize(vision_cls, dim=-1)
else:
vision_cls = None
if text_description is not None:
outputs = self.beit3(
textual_tokens=text_description,
visual_tokens=None,
text_padding_position=padding_mask,
)
x = outputs["encoder_out"]
language_cls = self.language_head(x[:, 0, :])
language_cls = F.normalize(language_cls, dim=-1)
else:
language_cls = None
if only_infer:
return vision_cls, language_cls
else:
loss, logits_per_image, logits_per_text = self.criterion(
vision_cls, language_cls, self.logit_scale.exp())
return loss, vision_cls, language_cls
@register_model
def beit3_base_patch16_224_imageclassification(pretrained=False, **kwargs):
args = _get_base_config(**kwargs)
args.normalize_output = False
model = BEiT3ForImageClassification(args, num_classes=1000, **kwargs)
return model
@register_model
def beit3_large_patch16_224_imageclassification(pretrained=False, **kwargs):
args = _get_large_config(**kwargs)
args.normalize_output = False
model = BEiT3ForImageClassification(args, num_classes=1000, **kwargs)
return model
@register_model
def beit3_base_patch16_224_nlvr2(pretrained=False, **kwargs):
args = _get_base_config(**kwargs)
model = BEiT3ForVisualReasoning(args, num_classes=2, **kwargs)
return model
@register_model
def beit3_large_patch16_224_nlvr2(pretrained=False, **kwargs):
args = _get_large_config(**kwargs)
model = BEiT3ForVisualReasoning(args, num_classes=2, **kwargs)
return model
@register_model
def beit3_base_patch16_384_vqav2(pretrained=False, **kwargs):
args = _get_base_config(img_size=384, **kwargs)
args.normalize_output = False
model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
return model
@register_model
def beit3_base_patch16_480_vqav2(pretrained=False, **kwargs):
args = _get_base_config(img_size=480, **kwargs)
args.normalize_output = False
model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
return model
@register_model
def beit3_large_patch16_384_vqav2(pretrained=False, **kwargs):
args = _get_large_config(img_size=384, **kwargs)
args.normalize_output = False
model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
return model
@register_model
def beit3_large_patch16_480_vqav2(pretrained=False, **kwargs):
args = _get_large_config(img_size=480, **kwargs)
args.normalize_output = False
model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
return model
@register_model
def beit3_large_patch16_768_vqav2(pretrained=False, **kwargs):
args = _get_large_config(img_size=768, **kwargs)
args.normalize_output = False
model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
return model
@register_model
def beit3_base_patch16_224_captioning(pretrained=False, **kwargs):
args = _get_base_config(**kwargs)
model = BEiT3ForCaptioning(args, **kwargs)
return model
@register_model
def beit3_base_patch16_480_captioning(pretrained=False, **kwargs):
args = _get_base_config(img_size=480, **kwargs)
model = BEiT3ForCaptioning(args, **kwargs)
return model
@register_model
def beit3_large_patch16_480_captioning(pretrained=False, **kwargs):
args = _get_large_config(img_size=480, **kwargs)
model = BEiT3ForCaptioning(args, **kwargs)
return model
@register_model
def beit3_base_patch16_224_retrieval(pretrained=False, **kwargs):
args = _get_base_config(**kwargs)
model = BEiT3ForRetrieval(args, **kwargs)
return model
@register_model
def beit3_base_patch16_384_retrieval(pretrained=False, **kwargs):
args = _get_base_config(img_size=384, **kwargs)
model = BEiT3ForRetrieval(args, **kwargs)
return model
@register_model
def beit3_large_patch16_384_retrieval(pretrained=False, **kwargs):
args = _get_large_config(img_size=384, **kwargs)
model = BEiT3ForRetrieval(args, **kwargs)
return model
# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from torchscale.model.BEiT3 import BEiT3
from torchscale.architecture.config import EncoderConfig
def trunc_normal_(tensor, mean=0., std=1.):
__call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
def _get_base_config(
img_size=224, patch_size=16, drop_path_rate=0,
checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
):
return EncoderConfig(
img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
layernorm_embedding=False, normalize_output=True, no_output_layer=True,
drop_path_rate=drop_path_rate, encoder_embed_dim=768, encoder_attention_heads=12,
encoder_ffn_embed_dim=int(768 * mlp_ratio), encoder_layers=12,
checkpoint_activations=checkpoint_activations,
)
def _get_large_config(
img_size=224, patch_size=16, drop_path_rate=0,
checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
):
return EncoderConfig(
img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
layernorm_embedding=False, normalize_output=True, no_output_layer=True,
drop_path_rate=drop_path_rate, encoder_embed_dim=1024, encoder_attention_heads=16,
encoder_ffn_embed_dim=int(1024 * mlp_ratio), encoder_layers=24,
checkpoint_activations=checkpoint_activations,
)
class BEiT3Wrapper(nn.Module):
def __init__(self, args, **kwargs):
super().__init__()
self.args = args
self.beit3 = BEiT3(args)
self.apply(self._init_weights)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def get_num_layers(self):
return self.beit3.encoder.num_layers
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'}
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
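# --- Illustrative usage (not part of the original file): a minimal sketch showing how
# the config builders and BEiT3Wrapper above fit together. It assumes torchscale is
# installed and simply instantiates the base configuration.
if __name__ == "__main__":
    cfg = _get_base_config(img_size=224, drop_path_rate=0.1)
    wrapper = BEiT3Wrapper(cfg)
    print("encoder layers:", wrapper.get_num_layers())           # 12 for the base config
    print("no-weight-decay params:", wrapper.no_weight_decay())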
# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from torch import optim as optim
from timm.optim.lookahead import Lookahead
import json
def get_num_layer_for_vit(var_name, num_max_layer):
if "embed" in var_name:
return 0
elif var_name in (
"cls_token", "mask_token", "pos_embed", "language_pos_embed",
"word_embeddings.weight", "vision_cls_token", "vision_pos_embed"
):
return 0
elif var_name.startswith("patch_embed"):
return 0
elif var_name.startswith("rel_pos_bias"):
return num_max_layer - 1
elif "layers." in var_name:
layer_id = int(var_name.split('layers.')[1].split('.')[0])
return layer_id + 1
else:
return num_max_layer - 1
def get_is_head_flag_for_vit(var_name, num_max_layer):
if var_name.startswith("head"):
return 1
# elif var_name.startswith("pooler"):
# return 1
else:
return 0
class LayerDecayValueAssigner(object):
def __init__(self, values, scale_handler=None):
self.scale_handler = scale_handler or get_num_layer_for_vit
self.values = values
def get_scale(self, layer_id):
return self.values[layer_id]
def get_layer_id(self, var_name):
return self.scale_handler(var_name, len(self.values))
# The implementation code is modified from Timm (https://github.com/huggingface/pytorch-image-models/tree/main/timm)
def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
parameter_group_names = {}
parameter_group_vars = {}
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
group_name = "no_decay"
this_weight_decay = 0.
else:
group_name = "decay"
this_weight_decay = weight_decay
if get_num_layer is not None:
layer_id = get_num_layer(name)
group_name = "layer_%d_%s" % (layer_id, group_name)
else:
layer_id = None
if group_name not in parameter_group_names:
if get_layer_scale is not None:
scale = get_layer_scale(layer_id)
else:
scale = 1.
parameter_group_names[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr_scale": scale
}
parameter_group_vars[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr_scale": scale
}
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(name)
print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
return list(parameter_group_vars.values())
def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
opt_lower = args.opt.lower()
weight_decay = args.weight_decay
if weight_decay and filter_bias_and_bn:
skip = {}
if skip_list is not None:
skip = skip_list
elif hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
weight_decay = 0.
else:
parameters = model.parameters()
opt_args = dict(lr=args.lr, weight_decay=weight_decay)
if hasattr(args, 'opt_eps') and args.opt_eps is not None:
opt_args['eps'] = args.opt_eps
if hasattr(args, 'opt_betas') and args.opt_betas is not None:
opt_args['betas'] = args.opt_betas
opt_split = opt_lower.split('_')
opt_lower = opt_split[-1]
if opt_lower == 'adamw':
optimizer = optim.AdamW(parameters, **opt_args)
else:
raise ValueError("Invalid optimizer")
if len(opt_split) > 1:
if opt_split[0] == 'lookahead':
optimizer = Lookahead(optimizer)
return optimizer
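# --- Illustrative usage (not part of the original file): a minimal sketch of how
# layer-wise lr decay is assembled. The toy two-layer module and the decay value 0.9
# are assumptions for demonstration only; the fine-tuning script builds the per-depth
# scales from the real encoder depth in exactly the same way.
if __name__ == "__main__":
    import torch.nn as nn

    toy_model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    num_layers, layer_decay = 2, 0.9
    # one lr scale per depth bucket; buckets closer to the head get larger scales
    values = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
    assigner = LayerDecayValueAssigner(values)
    groups = get_parameter_groups(
        toy_model, weight_decay=0.05,
        get_num_layer=assigner.get_layer_id, get_layer_scale=assigner.get_scale,
    )
    print("parameter groups:", len(groups))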
import cv2
import numpy as np
## aug functions
def identity_func(img):
return img
def autocontrast_func(img, cutoff=0):
'''
same output as PIL.ImageOps.autocontrast
'''
n_bins = 256
def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
offset = -low * scale
table = np.arange(n_bins) * scale + offset
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def equalize_func(img):
'''
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalizeHist
'''
n_bins = 256
def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0: return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def rotate_func(img, degree, fill=(0, 0, 0)):
'''
like PIL, rotate by degrees, not radians
'''
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out
def solarize_func(img, thresh=128):
'''
same output as PIL.ImageOps.solarize
'''
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out
def color_func(img, factor):
'''
same output as PIL.ImageEnhance.Color
'''
## implementation according to PIL definition, quite slow
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
# out = blend(degenerate, img, factor)
# M = (
# np.eye(3) * factor
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
# )[np.newaxis, np.newaxis, :]
M = (
np.float32([
[0.886, -0.114, -0.114],
[-0.587, 0.413, -0.587],
[-0.299, -0.299, 0.701]]) * factor
+ np.float32([[0.114], [0.587], [0.299]])
)
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out
def contrast_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = np.array([(
el - mean) * factor + mean
for el in range(256)
]).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def brightness_func(img, factor):
'''
same output as PIL.ImageEnhance.Brightness
'''
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def sharpness_func(img, factor):
'''
The differences between this result and PIL's are all on the 4 boundaries; the center
areas are the same
'''
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_x_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_y_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def posterize_func(img, bits):
'''
same output as PIL.ImageOps.posterize
'''
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def cutout_func(img, pad_size, replace=(0, 0, 0)):
replace = np.array(replace, dtype=np.uint8)
H, W = img.shape[0], img.shape[1]
rh, rw = np.random.random(2)
pad_size = pad_size // 2
ch, cw = int(rh * H), int(rw * W)
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
out = img.copy()
out[x1:x2, y1:y2, :] = replace
return out
### level to args
def enhance_level_to_args(MAX_LEVEL):
def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)
return level_to_args
def solarize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level, )
return level_to_args
def none_level_to_args(level):
return ()
def posterize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level, )
return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 30
if np.random.random() < 0.5:
level = -level
return (level, replace_value)
return level_to_args
func_dict = {
'Identity': identity_func,
'AutoContrast': autocontrast_func,
'Equalize': equalize_func,
'Rotate': rotate_func,
'Solarize': solarize_func,
'Color': color_func,
'Contrast': contrast_func,
'Brightness': brightness_func,
'Sharpness': sharpness_func,
'ShearX': shear_x_func,
'TranslateX': translate_x_func,
'TranslateY': translate_y_func,
'Posterize': posterize_func,
'ShearY': shear_y_func,
}
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
'Identity': none_level_to_args,
'AutoContrast': none_level_to_args,
'Equalize': none_level_to_args,
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
'Solarize': solarize_level_to_args(MAX_LEVEL),
'Color': enhance_level_to_args(MAX_LEVEL),
'Contrast': enhance_level_to_args(MAX_LEVEL),
'Brightness': enhance_level_to_args(MAX_LEVEL),
'Sharpness': enhance_level_to_args(MAX_LEVEL),
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
'TranslateX': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'TranslateY': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'Posterize': posterize_level_to_args(MAX_LEVEL),
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}
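# Note (added for clarity): each entry in arg_dict turns an integer magnitude "level"
# (0..MAX_LEVEL) into the positional arguments of the matching entry in func_dict,
# e.g. 'Rotate' maps level 10 to a rotation of up to +/-30 degrees with a gray
# (128, 128, 128) fill. RandomAugment below samples N of these ops per image and
# applies each one with probability 0.5 at magnitude M.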
class RandomAugment(object):
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
self.N = N
self.M = M
self.isPIL = isPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]
def __call__(self, img):
if self.isPIL:
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return img
if __name__ == '__main__':
a = RandomAugment()
img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # aug ops expect an 8-bit HxWxC image
a(img)
timm==0.4.12
Pillow
blobfile
mypy
numpy
pytest
requests
einops
tensorboardX
scipy
ftfy
opencv-python
sentencepiece
pyarrow
torchmetrics==0.7.3
transformers
pycocotools
pycocoevalcap
torchscale==0.2.0
# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os
from pathlib import Path
from timm.data.mixup import Mixup
from timm.models import create_model
from timm.utils import ModelEma
from optim_factory import create_optimizer, get_parameter_groups, \
LayerDecayValueAssigner, get_is_head_flag_for_vit
from engine_for_finetuning import train_one_epoch, get_handler, evaluate
from datasets import create_downstream_dataset
from utils import NativeScalerWithGradNormCount as NativeScaler
import utils
import modeling_finetune
def get_args():
parser = argparse.ArgumentParser('BEiT fine-tuning and evaluation script for image classification', add_help=False)
# Model parameters
parser.add_argument('--model', default='beit_base_patch16_224', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--task', type=str, required=True,
choices=['nlvr2', 'vqav2', 'flickr30k', 'coco_retrieval', 'coco_captioning', 'nocaps', 'imagenet'],
help='Name of the task to fine-tune on')
parser.add_argument('--input_size', default=224, type=int,
help='images input size')
parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
help='Drop path rate (default: 0.1)')
parser.add_argument('--checkpoint_activations', action='store_true', default=None,
help='Enable activation checkpointing to save memory.')
parser.add_argument('--sentencepiece_model', type=str, required=True,
help='Sentencepiece model path for the pretrained model.')
parser.add_argument('--vocab_size', type=int, default=64010)
parser.add_argument('--num_max_bpe_tokens', type=int, default=64)
parser.add_argument('--model_ema', action='store_true', default=False)
parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw")')
parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt_betas', default=[0.9, 0.999], type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: 0.9, 0.999, use opt default)')
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
help='learning rate (default: 5e-4)')
parser.add_argument('--layer_decay', type=float, default=0.9)
parser.add_argument('--task_head_lr_weight', type=float, default=0)
parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--eval_batch_size', default=None, type=int)
parser.add_argument('--epochs', default=20, type=int)
parser.add_argument('--update_freq', default=1, type=int)
parser.add_argument('--save_ckpt_freq', default=5, type=int)
# Augmentation parameters
parser.add_argument('--randaug', action='store_true', default=False)
parser.add_argument('--train_interpolation', type=str, default='bicubic',
help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')
# Finetuning params
parser.add_argument('--finetune', default='',
help='finetune from checkpoint')
parser.add_argument('--model_key', default='model|module', type=str)
parser.add_argument('--model_prefix', default='', type=str)
# Dataset parameters
parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
help='dataset path')
parser.add_argument('--output_dir', default='',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default=None,
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--auto_resume', action='store_true')
parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
parser.set_defaults(auto_resume=True)
parser.add_argument('--save_ckpt', action='store_true')
parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
parser.set_defaults(save_ckpt=True)
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False,
help='Enabling distributed evaluation')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
# parameter for dump predictions (VQA, COCO captioning, NoCaps)
parser.add_argument('--task_cache_path', default=None, type=str)
# parameter for imagenet finetuning
parser.add_argument('--nb_classes', default=1000, type=int,
help='number of the classification types')
parser.add_argument('--mixup', type=float, default=0,
help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=0,
help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# augmentation parameters for imagenet finetuning
parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
# evaluation parameters for imagenet
parser.add_argument('--crop_pct', type=float, default=None)
# random Erase params for imagenet finetuning
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# parameter for captioning finetuning
parser.add_argument('--captioning_mask_prob', type=float, default=0.6)
parser.add_argument('--drop_worst_ratio', type=float, default=0.2)
parser.add_argument('--drop_worst_after', type=int, default=12000)
parser.add_argument('--num_beams', type=int, default=3)
parser.add_argument('--length_penalty', type=float, default=0.6)
# label smoothing for imagenet and captioning
parser.add_argument('--label_smoothing', type=float, default=0.1)
# deepspeed parameters
parser.add_argument('--enable_deepspeed', action='store_true', default=False)
parser.add_argument('--initial_scale_power', type=int, default=16)
parser.add_argument('--zero_stage', default=0, type=int,
help='ZeRO optimizer stage (default: 0)')
known_args, _ = parser.parse_known_args()
if known_args.enable_deepspeed:
try:
import deepspeed
from deepspeed import DeepSpeedConfig
parser = deepspeed.add_config_arguments(parser)
ds_init = deepspeed.initialize
except ImportError:
print("Please run 'pip install deepspeed==0.4.0'")
exit(0)
else:
ds_init = None
return parser.parse_args(), ds_init
def main(args, ds_init):
utils.init_distributed_mode(args)
if ds_init is not None:
utils.create_ds_config(args)
if args.task_cache_path is None:
args.task_cache_path = args.output_dir
print(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
# random.seed(seed)
cudnn.benchmark = True
if utils.get_rank() == 0 and args.log_dir is not None:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
else:
log_writer = None
data_loader_train, data_loader_val = create_downstream_dataset(args)
if not args.model.endswith(args.task):
if args.task in ("flickr30k", "coco_retrieval"):
model_config = "%s_retrieval" % args.model
elif args.task in ("coco_captioning", "nocaps"):
model_config = "%s_captioning" % args.model
elif args.task in ("imagenet"):
model_config = "%s_imageclassification" % args.model
else:
model_config = "%s_%s" % (args.model, args.task)
else:
model_config = args.model
print("model_config = %s" % model_config)
model = create_model(
model_config,
pretrained=False,
drop_path_rate=args.drop_path,
vocab_size=args.vocab_size,
checkpoint_activations=args.checkpoint_activations,
)
if args.finetune:
utils.load_model_and_may_interpolate(args.finetune, model, args.model_key, args.model_prefix)
model.to(device)
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume='')
print("Using EMA with decay = %.8f" % args.model_ema_decay)
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model = %s" % str(model_without_ddp))
print('number of params:', n_parameters)
total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
num_training_steps_per_epoch = len(data_loader_train.dataset) // total_batch_size
print("LR = %.8f" % args.lr)
print("Batch size = %d" % total_batch_size)
print("Update frequent = %d" % args.update_freq)
print("Number of training examples = %d" % len(data_loader_train.dataset))
print("Number of training training per epoch = %d" % num_training_steps_per_epoch)
num_layers = model_without_ddp.get_num_layers()
if args.layer_decay < 1.0:
lrs = list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))
assigner = LayerDecayValueAssigner(lrs)
elif args.task_head_lr_weight > 1:
assigner = LayerDecayValueAssigner([1.0, args.task_head_lr_weight], scale_handler=get_is_head_flag_for_vit)
else:
assigner = None
if assigner is not None:
print("Assigned values = %s" % str(assigner.values))
skip_weight_decay_list = model.no_weight_decay()
if args.distributed:
torch.distributed.barrier()
if args.enable_deepspeed:
loss_scaler = None
optimizer_params = get_parameter_groups(
model, args.weight_decay, skip_weight_decay_list,
assigner.get_layer_id if assigner is not None else None,
assigner.get_scale if assigner is not None else None)
model, optimizer, _, _ = ds_init(
args=args, model=model, model_parameters=optimizer_params,
dist_init_required=not args.distributed,
)
print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
assert model.gradient_accumulation_steps() == args.update_freq
else:
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
model_without_ddp = model.module
optimizer = create_optimizer(
args, model_without_ddp, skip_list=skip_weight_decay_list,
get_num_layer=assigner.get_layer_id if assigner is not None else None,
get_layer_scale=assigner.get_scale if assigner is not None else None)
loss_scaler = NativeScaler()
lr_schedule_values = utils.cosine_scheduler(
args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
)
utils.auto_load_model(
args=args, model=model, model_without_ddp=model_without_ddp,
optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
task_handler = get_handler(args)
# mixup for imagenet
mixup_fn = None
if args.task in ["imagenet", "in1k"]:
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
print("Mixup is activated!")
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.label_smoothing, num_classes=args.nb_classes)
if args.eval:
data_loader_test = create_downstream_dataset(args, is_eval=True)
if args.task in ["nlvr2", "flickr30k", "coco_retrieval", "imagenet"]:
ext_test_stats, task_key = evaluate(data_loader_test, model, device, task_handler)
print(f"Accuracy of the network on the {len(data_loader_test.dataset)} test images: {ext_test_stats[task_key]:.3f}%")
exit(0)
elif args.task == "vqav2":
result, _ = evaluate(data_loader_test, model, device, task_handler)
utils.dump_predictions(args, result, "vqav2_test")
exit(0)
elif args.task in ["coco_captioning", "nocaps"]:
predictions, _ = evaluate(data_loader_test, model, device, task_handler)
prediction_file = utils.dump_predictions(args, predictions, "{}_test".format(args.task))
if utils.is_main_process() and args.task == "coco_captioning":
captioning_result = utils.coco_caption_eval(args.output_dir, prediction_file, "{}_test".format(args.task))
result_file = os.path.join(args.output_dir, f"{args.task}_result.json")
print(json.dumps(captioning_result))
utils.write_result_to_jsonl(captioning_result, result_file)
exit(0)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
if log_writer is not None:
log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
train_stats = train_one_epoch(
model, data_loader_train, optimizer, device, task_handler, epoch,
epoch * num_training_steps_per_epoch, lr_schedule_values, loss_scaler,
args.clip_grad, args.update_freq, model_ema, log_writer, args.task, mixup_fn,
)
if args.output_dir and args.save_ckpt:
if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
utils.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
if data_loader_val is not None:
if args.task not in ["coco_captioning", "nocaps"]:
test_stats, task_key = evaluate(data_loader_val, model, device, task_handler)
else:
predictions, _ = evaluate(data_loader_val, model, device, task_handler)
prediction_file = utils.dump_predictions(args, predictions, f"{args.task}_val_e{epoch}")
result_file = os.path.join(args.output_dir, f"{args.task}_result_val_e{epoch}.json")
task_key = "CIDEr"
if utils.is_main_process():
test_stats = utils.coco_caption_eval(args.output_dir, prediction_file, "{}_val".format(args.task))
utils.write_result_to_jsonl(test_stats, result_file)
torch.distributed.barrier()
if not utils.is_main_process():
test_stats = utils.read_result_from_jsonl(result_file)
print(f"Performance of the network on the {len(data_loader_val.dataset)} val images: {test_stats[task_key]:.1f}%")
if max_accuracy < test_stats[task_key]:
max_accuracy = test_stats[task_key]
if args.output_dir and args.save_ckpt:
utils.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
print(f'Max performance: {max_accuracy:.2f}%')
if log_writer is not None:
log_writer.update(acc=test_stats[task_key], head="perf", step=epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'val_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
else:
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
# **{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and utils.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
opts, ds_init = get_args()
if opts.output_dir:
Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
main(opts, ds_init)
#!/bin/bash
export HIP_VISIBLE_DEVICES=0,1,2,3 # change to the GPU indices and count used for training
export HSA_FORCE_FINE_GRAIN_PCIE=1
export USE_MIOPEN_BATCHNORM=1
python -m torch.distributed.launch --nproc_per_node=4 run_beit3_finetuning.py \
--model beit3_base_patch16_480 \
--input_size 480 \
--task coco_captioning \
--batch_size 32 \
--layer_decay 1.0 \
--lr 4e-5 \
--randaug \
--epochs 10 \
--warmup_epochs 1 \
--drop_path 0.1 \
--sentencepiece_model ./pretrained_models/beit3.spm \
--finetune ./pretrained_models/beit3_base_patch16_224.pth \
--data_path /home/data/coco2014 \
--output_dir ./save_models/ \
--log_dir ./logs \
--weight_decay 0.05 \
--seed 42 \
--save_ckpt_freq 5 \
--num_max_bpe_tokens 32 \
--captioning_mask_prob 0.7 \
--drop_worst_after 12000 \
--dist_eval \
--checkpoint_activations \
--enable_deepspeed
# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import datetime
import io
import os
import math
import time
import json
import argparse
import numpy as np
from pathlib import Path
from collections import defaultdict, deque
from timm.utils import get_state_dict
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from math import inf  # torch._six was removed in recent PyTorch; math.inf behaves identically here
from torchmetrics import Metric
from tensorboardX import SummaryWriter
def bool_flag(s):
"""
Parse boolean arguments from the command line.
"""
FALSY_STRINGS = {"off", "false", "0"}
TRUTHY_STRINGS = {"on", "true", "1"}
if s.lower() in FALSY_STRINGS:
return False
elif s.lower() in TRUTHY_STRINGS:
return True
else:
raise argparse.ArgumentTypeError("invalid value for a boolean flag")
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
class TensorboardLogger(object):
def __init__(self, log_dir):
self.writer = SummaryWriter(logdir=log_dir)
self.step = 0
def set_step(self, step=None):
if step is not None:
self.step = step
else:
self.step += 1
def update(self, head='scalar', step=None, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
def flush(self):
self.writer.flush()
def _load_checkpoint_for_ema(model_ema, checkpoint):
"""
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
"""
mem_file = io.BytesIO()
torch.save(checkpoint, mem_file)
mem_file.seek(0)
model_ema._load_checkpoint(mem_file)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def _get_rank_env():
if "RANK" in os.environ:
return int(os.environ["RANK"])
else:
return int(os.environ['OMPI_COMM_WORLD_RANK'])
def _get_local_rank_env():
if "LOCAL_RANK" in os.environ:
return int(os.environ["LOCAL_RANK"])
else:
return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
def _get_world_size_env():
if "WORLD_SIZE" in os.environ:
return int(os.environ["WORLD_SIZE"])
else:
return int(os.environ['OMPI_COMM_WORLD_SIZE'])
# The implementation code is modified from DeiT (https://github.com/facebookresearch/deit.git)
def init_distributed_mode(args):
if args.dist_on_itp:
args.rank = _get_rank_env()
args.world_size = _get_world_size_env() # int(os.environ['OMPI_COMM_WORLD_SIZE'])
args.gpu = _get_local_rank_env()
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
os.environ['LOCAL_RANK'] = str(args.gpu)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
args.rank, args.dist_url, args.gpu), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank,
timeout=datetime.timedelta(0, 7200)
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix=prefix)
warn_missing_keys = []
ignore_missing_keys = []
for key in missing_keys:
keep_flag = True
for ignore_key in ignore_missing.split('|'):
if ignore_key in key:
keep_flag = False
break
if keep_flag:
warn_missing_keys.append(key)
else:
ignore_missing_keys.append(key)
missing_keys = warn_missing_keys
if len(missing_keys) > 0:
print("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
print("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(ignore_missing_keys) > 0:
print("Ignored weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, ignore_missing_keys))
if len(error_msgs) > 0:
print('\n'.join(error_msgs))
class NativeScalerWithGradNormCount:
state_dict_key = "amp_scaler"
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
self._scaler.scale(loss).backward(create_graph=create_graph)
if update_grad:
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
else:
self._scaler.unscale_(optimizer)
norm = get_grad_norm_(parameters)
self._scaler.step(optimizer)
self._scaler.update()
else:
norm = None
return norm
def state_dict(self):
return self._scaler.state_dict()
def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
return total_norm
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
start_warmup_value=0, warmup_steps=-1, sched_type="cos"):
warmup_schedule = np.array([])
warmup_iters = warmup_epochs * niter_per_ep
if warmup_steps > 0:
warmup_iters = warmup_steps
print("Set warmup steps = %d" % warmup_iters)
if warmup_epochs > 0:
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
if sched_type == "cos":
iters = np.arange(epochs * niter_per_ep - warmup_iters)
schedule = np.array([
final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
elif sched_type == "linear":
schedule = np.linspace(base_value, final_value, epochs * niter_per_ep - warmup_iters)
else:
raise NotImplementedError()
schedule = np.concatenate((warmup_schedule, schedule))
assert len(schedule) == epochs * niter_per_ep
return schedule
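# Note (added for clarity): cosine_scheduler precomputes one value per optimizer step:
# a linear warmup from start_warmup_value to base_value over warmup_iters steps, then
# a cosine (or linear, via sched_type) decay from base_value down to final_value, so
# the returned array has length epochs * niter_per_ep and is indexed by global step.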
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
output_dir = Path(args.output_dir)
if loss_scaler is not None:
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch)]
for checkpoint_path in checkpoint_paths:
to_save = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'scaler': loss_scaler.state_dict(),
'args': args,
}
if model_ema is not None:
to_save['model_ema'] = get_state_dict(model_ema)
save_on_master(to_save, checkpoint_path)
else:
client_state = {'epoch': epoch, "args": args}
if model_ema is not None:
client_state['model_ema'] = get_state_dict(model_ema)
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch, client_state=client_state)
def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
output_dir = Path(args.output_dir)
if loss_scaler is not None:
# torch.amp
if args.auto_resume and len(args.resume) == 0:
import glob
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
latest_ckpt = -1
for ckpt in all_checkpoints:
t = ckpt.split('-')[-1].split('.')[0]
if t.isdigit():
latest_ckpt = max(int(t), latest_ckpt)
if latest_ckpt >= 0:
args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
print("Auto resume checkpoint: %s" % args.resume)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
print("Resume checkpoint %s" % args.resume)
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
args.start_epoch = checkpoint['epoch'] + 1
if hasattr(args, 'model_ema') and args.model_ema:
_load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
print("With optim & sched!")
else:
# deepspeed only supports '--auto_resume'.
if args.auto_resume:
import glob
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
latest_ckpt = -1
for ckpt in all_checkpoints:
t = ckpt.split('-')[-1].split('.')[0]
if t.isdigit():
latest_ckpt = max(int(t), latest_ckpt)
if latest_ckpt >= 0:
args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
print("Auto resume checkpoint: %d" % latest_ckpt)
_, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
args.start_epoch = client_states['epoch'] + 1
if model_ema is not None:
if args.model_ema:
_load_checkpoint_for_ema(model_ema, client_states['model_ema'])
# The implementation code is modified from DeiT (https://github.com/facebookresearch/deit.git)
def load_model_and_may_interpolate(ckpt_path, model, model_key, model_prefix):
if ckpt_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
ckpt_path, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(ckpt_path, map_location='cpu')
print("Load ckpt from %s" % ckpt_path)
checkpoint_model = None
for model_key in model_key.split('|'):
if model_key in checkpoint:
checkpoint_model = checkpoint[model_key]
print("Load state_dict by model_key = %s" % model_key)
break
if checkpoint_model is None:
checkpoint_model = checkpoint
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
for pos_embed_key in ("vision_pos_embed", "pos_embed", "beit3.encoder.embed_positions.A.weight"):
if pos_embed_key in checkpoint_model:
pos_embed_checkpoint = checkpoint_model[pos_embed_key]
embedding_size = pos_embed_checkpoint.shape[-1]
if pos_embed_key == "beit3.encoder.embed_positions.A.weight":
# to be consistent with Fairseq, whose position embeddings start from index 2
torchscale_model = True
num_patches = model.beit3.vision_embed.num_patches
num_extra_tokens = model.beit3.vision_embed.num_position_embeddings() + 2 - num_patches
else:
torchscale_model = False
num_patches = model.patch_embed.num_patches
num_extra_tokens = getattr(model, pos_embed_key).shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
if torchscale_model:
extra_tokens = pos_embed_checkpoint[:num_extra_tokens].unsqueeze(0)
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[num_extra_tokens:]
else:
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
if torchscale_model:
new_pos_embed = new_pos_embed.squeeze(0)
checkpoint_model[pos_embed_key] = new_pos_embed
load_state_dict(model, checkpoint_model, prefix=model_prefix)
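# Note (added for clarity): load_model_and_may_interpolate above supports fine-tuning
# at a resolution different from pretraining: the grid part of the position embedding
# is reshaped back to its 2D layout (orig_size x orig_size), resized to
# new_size x new_size with bicubic interpolation, flattened again, and concatenated
# with the unchanged extra (cls / padding) tokens before loading the state dict.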
def create_ds_config(args):
args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
with open(args.deepspeed_config, mode="w") as writer:
ds_config = {
"train_batch_size": args.batch_size * args.update_freq * get_world_size(),
"train_micro_batch_size_per_gpu": args.batch_size,
"steps_per_print": 1000,
"optimizer": {
"type": "Adam",
"adam_w_mode": True,
"params": {
"lr": args.lr,
"weight_decay": args.weight_decay,
"bias_correction": True,
"betas": [
args.opt_betas[0],
args.opt_betas[1]
],
"eps": args.opt_eps
}
},
"fp16": {
"enabled": True,
"loss_scale": 0,
"initial_scale_power": getattr(args, "initial_scale_power", 12),
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"amp": {
"enabled": False,
"opt_level": "O2"
}
}
if args.clip_grad is not None:
ds_config.update({'gradient_clipping': args.clip_grad})
if args.zero_stage == 1:
ds_config.update({"zero_optimization": {"stage": args.zero_stage, "reduce_bucket_size": 5e8}})
elif args.zero_stage > 1:
raise NotImplementedError()
writer.write(json.dumps(ds_config, indent=2))
def merge_batch_tensors_by_dict_key(batch):
batch_tensors = {}
for tensor_key in batch[0]:
if isinstance(batch[0][tensor_key], torch.Tensor):
batch_tensors[tensor_key] = torch.stack([d[tensor_key] for d in batch])
else:
batch_tensors[tensor_key] = torch.tensor([d[tensor_key] for d in batch], dtype=torch.long)
return batch_tensors
def get_loss_scale_for_deepspeed(model):
optimizer = model.optimizer
loss_scale = None
if hasattr(optimizer, 'loss_scale'):
loss_scale = optimizer.loss_scale
elif hasattr(optimizer, 'cur_scale'):
loss_scale = optimizer.cur_scale
return loss_scale
class GatherLayer(torch.autograd.Function):
"""
Gather tensors from all workers with support for backward propagation:
This implementation does not cut the gradients as torch.distributed.all_gather does.
"""
@staticmethod
def forward(ctx, x):
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(output, x)
return tuple(output)
@staticmethod
def backward(ctx, *grads):
all_gradients = torch.stack(grads)
dist.all_reduce(all_gradients)
return all_gradients[dist.get_rank()]
def gather_features(
image_features,
text_features,
):
gathered_image_features = GatherLayer.apply(image_features)
gathered_text_features = GatherLayer.apply(text_features)
all_image_features = torch.cat(gathered_image_features)
all_text_features = torch.cat(gathered_text_features)
return all_image_features, all_text_features
# The implementation code is modified from open_clip (https://github.com/mlfoundations/open_clip.git)
class ClipLoss(nn.Module):
def __init__(
self,
cache_labels=False,
rank=0,
world_size=1,
):
super().__init__()
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
# cache state
self.prev_num_logits = 0
self.labels = {}
def forward(self, image_features, text_features, logit_scale):
device = image_features.device
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features, text_features
)
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
else:
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T
# calculate ground-truth labels and cache them if enabled
num_logits = logits_per_image.shape[0]
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1:
labels = labels + num_logits * self.rank
if self.cache_labels:
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
return total_loss, logits_per_image, logits_per_text
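# Note (added for clarity): ClipLoss is the standard in-batch contrastive objective.
# Scaled similarity logits are computed between every local image feature and every
# (optionally gathered) text feature and vice versa; the positive for local row i is
# its own pair (index i, offset by rank * batch_size when features are gathered across
# workers), and the loss averages the image-to-text and text-to-image cross-entropy terms.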
def write_result_to_jsonl(test_stats, result_file):
with open(result_file, mode="w", encoding="utf-8") as writer:
writer.write(json.dumps(test_stats, indent=None))
def read_result_from_jsonl(result_file):
with open(result_file, mode="r", encoding="utf-8") as reader:
return json.load(reader)
# The implementation code is from ViLT (https://github.com/dandelin/ViLT.git)
class VQAScore(Metric):
def __init__(self, dist_sync_on_step=False):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
def update(self, logits, target):
logits, target = (
logits.detach().float().to(self.score.device),
target.detach().float().to(self.score.device),
)
logits = torch.max(logits, 1)[1]
one_hots = torch.zeros(*target.size()).to(target)
one_hots.scatter_(1, logits.view(-1, 1), 1)
scores = one_hots * target
self.score += scores.sum()
self.total += len(logits)
def compute(self):
return self.score / self.total
class BertCaptioningLoss(nn.Module):
def __init__(self, label_smoothing, drop_worst_ratio, drop_worst_after):
super().__init__()
self.label_smoothing = label_smoothing
self.drop_worst_ratio = drop_worst_ratio
self.drop_worst_after = drop_worst_after
self.log_soft = nn.LogSoftmax(dim=1)
self.kl = nn.KLDivLoss(reduction='none')
self.iter = 0
def forward(self, logits, target, iter):
eps = self.label_smoothing
n_class = logits.size(1)
one_hot = torch.zeros_like(logits).scatter(1, target.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = self.log_soft(logits)
loss = self.kl(log_prb, one_hot).sum(1)
if self.drop_worst_ratio > 0 and iter > self.drop_worst_after:
loss, _ = torch.topk(loss,
k=int(loss.shape[0] * (1-self.drop_worst_ratio)),
largest=False)
loss = loss.mean()
return loss
class BeamHypotheses(object):
def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.n_hyp = n_hyp
self.hyp = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.hyp)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.n_hyp or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.n_hyp:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs):
"""
        If there are enough hypotheses and none of the hypotheses still being generated
        can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.n_hyp:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
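
# Hypothetical beam-search sketch: each finished candidate is added together with its summed
# log-probability, and is_done() is queried with the best attainable score to decide whether
# the beam for this sentence can be closed. All token ids and scores are illustrative only.
def _beam_hypotheses_example():
    import torch
    beam = BeamHypotheses(n_hyp=3, max_length=20, length_penalty=1.0, early_stopping=False)
    candidates = [([2, 15, 7, 3], -1.2), ([2, 8, 9, 11, 3], -1.5), ([2, 4, 3], -2.0), ([2, 5, 6, 3], -0.9)]
    for token_ids, sum_logprobs in candidates:
        beam.add(torch.tensor(token_ids), sum_logprobs)   # only the best n_hyp hypotheses are kept
    return beam.is_done(best_sum_logprobs=-0.8)
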
def dump_predictions(args, result, file_suffix):
global_rank = get_rank()
jsons = None
if global_rank >= 0:
output_file = os.path.join(args.task_cache_path, f"submit_{global_rank}_{file_suffix}.json")
with open(output_file, "w") as fp:
json.dump(result, fp, indent=2)
torch.distributed.barrier()
if global_rank == 0:
world_size = get_world_size()
jsons = []
for i in range(world_size):
each_file = os.path.join(args.task_cache_path, f"submit_{i}_{file_suffix}.json")
with open(each_file, "r") as fp:
jsons += json.load(fp)
new_jsons = []
res_dict = dict()
if args.task in ["coco_captioning", "nocaps"]:
qid_key = "image_id"
else:
# for VQAv2
qid_key = "question_id"
for item in jsons:
if item[qid_key] in res_dict:
continue
new_jsons.append(item)
res_dict[item[qid_key]] = item
jsons = new_jsons
torch.distributed.barrier()
os.remove(output_file)
else:
jsons = result
result_file = os.path.join(args.output_dir, f"submit_{file_suffix}.json")
if jsons is not None:
with open(result_file, "w") as fp:
json.dump(jsons, fp, indent=2)
print("Infer %d examples into %s" % (len(jsons), result_file))
return result_file
# The evaluation code is from BLIP (https://github.com/salesforce/BLIP)
# For nocaps, please submit the prediction file to the evaluation server (https://eval.ai/web/challenges/challenge-page/355/overview) to obtain the final results
def coco_caption_eval(gt_dir, results_file, split):
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from torchvision.datasets.utils import download_url
urls = {'coco_captioning_val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
'coco_captioning_test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json',
'nocaps_val': 'https://conversationhub.blob.core.windows.net/beit-share-public/beit3/nocaps/nocaps_val_gt.json?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D'}
filenames = {'coco_captioning_val':'coco_karpathy_val_gt.json',
'coco_captioning_test':'coco_karpathy_test_gt.json',
'nocaps_val':'nocaps_val_gt.json'}
download_url(urls[split], gt_dir)
annotation_file = os.path.join(gt_dir, filenames[split])
# create coco object and coco_result object
coco = COCO(annotation_file)
coco_result = coco.loadRes(results_file)
# create coco_eval object by taking coco and coco_result
coco_eval = COCOEvalCap(coco, coco_result)
# evaluate results
# SPICE will take a few minutes the first time, but speeds up due to caching
coco_eval.evaluate()
res_dict = dict()
for metric, score in coco_eval.eval.items():
res_dict[metric] = score
return res_dict
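
# Hypothetical end-to-end sketch of how the two helpers above are combined after captioning
# inference in a distributed run: per-rank predictions are merged and written by
# dump_predictions, then rank 0 scores them against the Karpathy-split ground truth.
# The args fields, file suffix, and use of args.output_dir as the ground-truth directory
# are illustrative only.
def _caption_eval_example(args, predictions):
    result_file = dump_predictions(args, predictions, "coco_captioning_val_example")
    if get_rank() == 0:
        scores = coco_caption_eval(args.output_dir, result_file, "coco_captioning_val")
        print(scores)   # e.g. {'Bleu_4': ..., 'METEOR': ..., 'CIDEr': ..., 'SPICE': ...}
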
#!/bin/bash
# Evaluation-only distributed run (4 GPUs) of the BEiT3-base checkpoint fine-tuned for
# COCO captioning at 480x480 input resolution.
# ROCm (AMD GPU) environment settings used in this setup.
export HSA_FORCE_FINE_GRAIN_PCIE=1
export USE_MIOPEN_BATCHNORM=1
python -m torch.distributed.launch --nproc_per_node=4 run_beit3_finetuning.py \
--model beit3_base_patch16_480 \
--input_size 480 \
--task coco_captioning \
--batch_size 16 \
--sentencepiece_model ./pretrained_models/beit3.spm \
--finetune ./pretrained_models/beit3_base_patch16_480_coco_captioning.pth \
--data_path ../../data/coco2014/ \
--output_dir ./save_models \
--eval \
--dist_eval