"...tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "a65ba42c50b9c5cd2318342911b30ae4460ecba6"
Commit f75058c7 authored by Rayyyyy's avatar Rayyyyy
Browse files

First add.

parents
Pipeline #1411 canceled with stages
import math
import os
import random
from dataclasses import dataclass
from typing import List, Tuple, Dict
import datasets
import torch
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding
from transformers import PreTrainedTokenizer, BatchEncoding
from .arguments import DataArguments
class TrainDatasetForCE(Dataset):
def __init__(
self,
args: DataArguments,
tokenizer: PreTrainedTokenizer,
):
if os.path.isdir(args.train_data):
train_datasets = []
for file in os.listdir(args.train_data):
temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file),
split='train')
train_datasets.append(temp_dataset)
self.dataset = datasets.concatenate_datasets(train_datasets)
else:
self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train')
self.tokenizer = tokenizer
self.args = args
self.total_len = len(self.dataset)
def create_one_example(self, qry_encoding: str, doc_encoding: str):
item = self.tokenizer.encode_plus(
qry_encoding,
doc_encoding,
truncation=True,
max_length=self.args.max_len,
padding=False,
)
return item
def __len__(self):
return self.total_len
def __getitem__(self, item) -> List[BatchEncoding]:
query = self.dataset[item]['query']
pos = random.choice(self.dataset[item]['pos'])
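# If there are fewer negatives than needed, tile the negative list so that
# train_group_size - 1 negatives can still be sampled for this query.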
if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
else:
negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
batch_data = []
batch_data.append(self.create_one_example(query, pos))
for neg in negs:
batch_data.append(self.create_one_example(query, neg))
return batch_data
@dataclass
class GroupCollator(DataCollatorWithPadding):
def __call__(
self, features
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
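# Each dataset item is a list of encodings for one query group (positive first,
# then negatives); flatten the nested lists into a single batch before padding.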
if isinstance(features[0], list):
features = sum(features, [])
return super().__call__(features)
import logging
import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, PreTrainedModel, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput
from .arguments import ModelArguments, DataArguments
logger = logging.getLogger(__name__)
class CrossEncoder(nn.Module):
def __init__(self, hf_model: PreTrainedModel, model_args: ModelArguments, data_args: DataArguments,
train_args: TrainingArguments):
super().__init__()
self.hf_model = hf_model
self.model_args = model_args
self.train_args = train_args
self.data_args = data_args
self.config = self.hf_model.config
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
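# The positive passage is always the first entry in each group (see TrainDatasetForCE),
# so the cross-entropy target for every query in the batch is index 0.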
self.register_buffer(
'target_label',
torch.zeros(self.train_args.per_device_train_batch_size, dtype=torch.long)
)
def gradient_checkpointing_enable(self, **kwargs):
self.hf_model.gradient_checkpointing_enable(**kwargs)
def forward(self, batch):
ranker_out: SequenceClassifierOutput = self.hf_model(**batch, return_dict=True)
logits = ranker_out.logits
if self.training:
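# Reshape the flat logits into (batch_size, train_group_size): one row per query,
# scoring its positive (column 0) against the sampled negatives.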
scores = logits.view(
self.train_args.per_device_train_batch_size,
self.data_args.train_group_size
)
loss = self.cross_entropy(scores, self.target_label)
return SequenceClassifierOutput(
loss=loss,
**ranker_out,
)
else:
return ranker_out
@classmethod
def from_pretrained(
cls, model_args: ModelArguments, data_args: DataArguments, train_args: TrainingArguments,
*args, **kwargs
):
hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
reranker = cls(hf_model, model_args, data_args, train_args)
return reranker
def save_pretrained(self, output_dir: str):
state_dict = self.hf_model.state_dict()
state_dict = type(state_dict)(
{k: v.clone().cpu()
for k,
v in state_dict.items()})
self.hf_model.save_pretrained(output_dir, state_dict=state_dict)
import logging
import os
from pathlib import Path
from transformers import AutoConfig, AutoTokenizer, TrainingArguments
from transformers import (
HfArgumentParser,
set_seed,
)
from .arguments import ModelArguments, DataArguments
from .data import TrainDatasetForCE, GroupCollator
from .modeling import CrossEncoder
from .trainer import CETrainer
logger = logging.getLogger(__name__)
def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("Model parameters %s", model_args)
logger.info("Data parameters %s", data_args)
set_seed(training_args.seed)
num_labels = 1
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=False,
)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
cache_dir=model_args.cache_dir,
)
_model_class = CrossEncoder
model = _model_class.from_pretrained(
model_args, data_args, training_args,
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
train_dataset = TrainDatasetForCE(data_args, tokenizer=tokenizer)
_trainer_class = CETrainer
trainer = _trainer_class(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=GroupCollator(tokenizer),
tokenizer=tokenizer
)
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
trainer.train()
trainer.save_model()
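# Example launch for this training script (a sketch only: the module path, model
# checkpoint, and data file below are placeholders, not defined by this repository):
#   torchrun --nproc_per_node 8 -m <your_package>.run \
#       --output_dir ./reranker_output \
#       --model_name_or_path BAAI/bge-reranker-base \
#       --train_data ./train_data.jsonl \
#       --max_len 512 --train_group_size 8 \
#       --per_device_train_batch_size 2 --fp16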
if __name__ == "__main__":
main()
import logging
import os
from typing import Optional
import torch
from transformers.trainer import Trainer
from .modeling import CrossEncoder
logger = logging.getLogger(__name__)
class CETrainer(Trainer):
def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving model checkpoint to %s", output_dir)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not hasattr(self.model, 'save_pretrained'):
raise NotImplementedError(f'MODEL {self.model.__class__.__name__} ' f'does not support save_pretrained interface')
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
def compute_loss(self, model: CrossEncoder, inputs):
return model(inputs)['loss']
<h1 align="center">Visualized BGE</h1>
<p align="center">
<a href="https://arxiv.org/abs/2406.04292">
<img alt="Build" src="http://img.shields.io/badge/cs.CV-arXiv%3A2406.04292-B31B1B.svg">
</a>
<a href="https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/visual">
<img alt="Build" src="https://img.shields.io/badge/Github-VISTA Code-blue">
</a>
<a href="https://huggingface.co/BAAI/bge-visualized">
<img alt="Build" src="https://img.shields.io/badge/🤗 Model-VISTA Model-yellow">
</a>
</p>
<p align="center">
<a href="https://huggingface.co/datasets/JUNJIE99/VISTA_S2">
<img alt="Build" src="https://img.shields.io/badge/🤗 Dataset-VISTA S2 Training Dataset-yellow">
</a>
<a href="https://huggingface.co/datasets/JUNJIE99/VISTA_Evaluation">
<img alt="Build" src="https://img.shields.io/badge/🤗 Dataset-Zero_Shot Multimodal Retrieval Dataset-yellow">
</a>
</p>
## 🔔 News
**[2024.3.18] We have released our code and model.**
**[2024.6.7] We have released our paper. [Arxiv Link](https://arxiv.org/abs/2406.04292)**
**[2024.6.13] We have released [VISTA-S2 dataset](https://huggingface.co/datasets/JUNJIE99/VISTA_S2), a hybrid multi-modal dataset consisting of over 500,000 instances for multi-modal training (Stage-2 training in our paper).**
## Introduction
In this project, we introduce Visualized-BGE, a universal multi-modal embedding model. By incorporating image token embedding into the BGE Text Embedding framework, Visualized-BGE gains the flexibility to process multi-modal data that goes beyond just text. Visualized-BGE is mainly used for hybrid modal retrieval tasks, including but not limited to:
- Multi-Modal Knowledge Retrieval (query: text; candidate: image-text pairs, text, or image) e.g. [WebQA](https://github.com/WebQnA/WebQA)
- Composed Image Retrieval (query: image-text pair; candidate: images) e.g. [CIRR](https://github.com/Cuberick-Orion/CIRR), [FashionIQ](https://github.com/XiaoxiaoGuo/fashion-iq)
- Knowledge Retrieval with Multi-Modal Queries (query: image-text pair; candidate: texts) e.g. [ReMuQ](https://github.com/luomancs/ReMuQ)
Moreover, Visualized BGE fully preserves the strong text embedding capabilities of the original BGE model : )
## Specs
### Model
| **Model Name** | **Dimension** | **Text Embedding Model** | **Language** | **Weight** |
| --- | --- | --- | --- | --- |
| BAAI/bge-visualized-base-en-v1.5 | 768 | [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | English | [🤗 HF link](https://huggingface.co/BAAI/bge-visualized/blob/main/Visualized_base_en_v1.5.pth) |
| BAAI/bge-visualized-m3 | 1024 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | Multilingual | [🤗 HF link](https://huggingface.co/BAAI/bge-visualized/blob/main/Visualized_m3.pth) |
### Data
We have generated a hybrid multi-modal dataset consisting of over 500,000 instances for multi-modal training (Stage-2 training in our paper). You can download our dataset from this [🤗 HF Link](https://huggingface.co/datasets/JUNJIE99/VISTA_S2).
Merge and extract the split image archive with the following commands:
```bash
cat images.tar.part* > images.tar
tar -xvf images.tar
```
You should obtain the following directory structure. You can then use the annotation information (JSON files) for your own training:
```
images
|__coco
|__edit_image
```
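As a rough starting point, the annotation files can be inspected with the standard `json` module and joined with the extracted image folders. The file name and field names below are placeholders (check the actual schema of the downloaded JSON files before training):
``` python
import json
import os

ANN_FILE = "annotations.json"   # placeholder: replace with an actual annotation file
IMAGE_ROOT = "images"           # folder produced by extracting images.tar

with open(ANN_FILE, "r", encoding="utf-8") as f:
    records = json.load(f)      # switch to line-by-line parsing if the file is JSON Lines

for rec in records[:3]:
    print(rec.keys())           # inspect the schema first
    # img_path = os.path.join(IMAGE_ROOT, rec["image"])  # "image" is a hypothetical field
```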
## Usage
### Installation:
#### Install FlagEmbedding:
```
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install -e .
```
#### Other Core Packages:
```
pip install torchvision timm einops ftfy
```
You don't need to install `xformers` or `apex`; they are not required for inference and can often cause installation issues.
### Generate Embedding for Multi-Modal Data:
Visualized-BGE provides the versatility to encode multi-modal data in a variety of formats, whether it's purely text, solely image-based, or a combination of both.
> **Note:** Please download the model weight file ([bge-visualized-base-en-v1.5](https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_base_en_v1.5.pth?download=true), [bge-visualized-m3](https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_m3.pth?download=true)) in advance and pass the path to the `model_weight` parameter.
- Composed Image Retrieval
``` python
####### Use Visualized BGE doing composed image retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
model = Visualized_BGE(model_name_bge = "BAAI/bge-base-en-v1.5", model_weight="path: Visualized_base_en_v1.5.pth")
model.eval()
with torch.no_grad():
query_emb = model.encode(image="./imgs/cir_query.png", text="Make the background dark, as if the camera has taken the photo at night")
candi_emb_1 = model.encode(image="./imgs/cir_candi_1.png")
candi_emb_2 = model.encode(image="./imgs/cir_candi_2.png")
sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
print(sim_1, sim_2) # tensor([[0.8750]]) tensor([[0.7816]])
```
- Multi-Modal Knowledge Retrieval
``` python
####### Use Visualized BGE doing multi-modal knowledge retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
model = Visualized_BGE(model_name_bge = "BAAI/bge-base-en-v1.5", model_weight="path: Visualized_base_en_v1.5.pth")
model.eval()
with torch.no_grad():
query_emb = model.encode(text="Are there sidewalks on both sides of the Mid-Hudson Bridge?")
candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")
candi_emb_2 = model.encode(text="Golden_Gate_Bridge", image="./imgs/wiki_candi_2.jpg")
candi_emb_3 = model.encode(text="The Mid-Hudson Bridge was designated as a New York State Historic Civil Engineering Landmark by the American Society of Civil Engineers in 1983. The bridge was renamed the \"Franklin Delano Roosevelt Mid-Hudson Bridge\" in 1994.")
sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
sim_3 = query_emb @ candi_emb_3.T
print(sim_1, sim_2, sim_3) # tensor([[0.6932]]) tensor([[0.4441]]) tensor([[0.6415]])
```
- Multilingual Multi-Modal Retrieval
``` python
##### Use M3 doing Multilingual Multi-Modal Retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
model = Visualized_BGE(model_name_bge = "BAAI/bge-m3", model_weight="path: Visualized_m3.pth")
model.eval()
with torch.no_grad():
query_emb = model.encode(image="./imgs/cir_query.png", text="一匹马牵着这辆车")
candi_emb_1 = model.encode(image="./imgs/cir_candi_1.png")
candi_emb_2 = model.encode(image="./imgs/cir_candi_2.png")
sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
print(sim_1, sim_2) # tensor([[0.7026]]) tensor([[0.8075]])
```
## Evaluation Result
Visualized BGE delivers outstanding zero-shot performance across multiple hybrid modal retrieval tasks. It can also serve as a base model for downstream fine-tuning for hybrid modal retrieval tasks.
#### Zero-shot Performance
- Statistical information of the zero-shot multi-modal retrieval benchmark datasets. During the zero-shot evaluation, we utilize the queries from the validation or test set of each dataset to perform retrieval assessments within the entire corpus of the respective dataset.
![Statistical information for the zero-shot multi-modal retrieval benchmark datasets.](./imgs/zs-benchmark.png)
- Zero-shot evaluation results with Recall@5 on various hybrid multi-modal retrieval benchmarks. The -MM notation indicates baseline models that have undergone multi-modal training on our generated data.
![Zero-shot evaluation results with Recall@5 on various hybrid multi-modal retrieval benchmarks.](./imgs/zs-performance.png)
#### Fine-tuning on Downstream Tasks
- Supervised fine-tuning performance on the WebQA dataset. All retrievals are performed on the entire deduplicated corpus.
![image.png](./imgs/SFT-WebQA.png)
- Supervised fine-tuning performance on the CIRR test set.
![image.png](./imgs/SFT-CIRR.png)
- Supervised fine-tuning performance on the ReMuQ test set.
![image.png](./imgs/SFT-ReMuQ.png)
## FAQ
**Q1: Can Visualized BGE be used for cross-modal retrieval (text to image)?**
A1: While it is technically possible, it is not the recommended use case. Our model focuses on augmenting hybrid-modal retrieval tasks with visual capabilities.
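If you still want to experiment with text-to-image retrieval, the same `encode` interface shown above can take a text-only query and image-only candidates. This is only a sketch based on the usage examples above, and retrieval quality for pure cross-modal search is not guaranteed:
``` python
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE

model = Visualized_BGE(model_name_bge="BAAI/bge-base-en-v1.5", model_weight="path: Visualized_base_en_v1.5.pth")
model.eval()
with torch.no_grad():
    query_emb = model.encode(text="a dog playing on the beach")  # text-only query
    candi_emb_1 = model.encode(image="./imgs/cir_candi_1.png")   # image-only candidates
    candi_emb_2 = model.encode(image="./imgs/cir_candi_2.png")

print(query_emb @ candi_emb_1.T, query_emb @ candi_emb_2.T)
```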
## Acknowledgement
The image token embedding model in this project is built upon the foundations laid by [EVA-CLIP](https://github.com/baaivision/EVA/tree/master/EVA-CLIP).
## Citation
If you find this repository useful, please consider giving a star ⭐ and a citation:
```
@article{zhou2024vista,
title={VISTA: Visualized Text Embedding For Universal Multi-Modal Retrieval},
author={Zhou, Junjie and Liu, Zheng and Xiao, Shitao and Zhao, Bo and Xiong, Yongping},
journal={arXiv preprint arXiv:2406.04292},
year={2024}
}
```
from .modeling import Visualized_BGE
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_eva_vision_and_transforms
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
# --------------------------------------------------------
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import os
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
except ImportError:
from timm.layers import drop_path, to_2tuple, trunc_normal_
from .transformer import PatchDropout
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
if os.getenv('ENV_TYPE') == 'deepspeed':
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
except:
from torch.utils.checkpoint import checkpoint
else:
from torch.utils.checkpoint import checkpoint
try:
import xformers.ops as xops
except ImportError:
xops = None
# print("Please 'pip install xformers'")
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
drop=0.,
subln=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commented out to follow the original BERT implementation
x = self.ffn_ln(x)
x = self.fc2(x)
x = self.drop(x)
return x
class SwiGLU(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
norm_layer=nn.LayerNorm, subln=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w1 = nn.Linear(in_features, hidden_features)
self.w2 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
self.w3 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x1 = self.w1(x)
x2 = self.w2(x)
hidden = self.act(x1) * x2
x = self.ffn_ln(hidden)
x = self.w3(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.subln = subln
if self.subln:
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
else:
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
# self.proj = nn.Linear(all_head_dim, all_head_dim)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.xattn = xattn
self.xattn_drop = attn_drop
self.rope = rope
def forward(self, x, rel_pos_bias=None, attn_mask=None):
B, N, C = x.shape
if self.subln:
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
else:
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
q, k, v = qkv[0], qkv[1], qkv[2]
if self.rope:
# slightly fast impl
q_t = q[:, :, 1:, :]
ro_q_t = self.rope(q_t)
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
k_t = k[:, :, 1:, :]
ro_k_t = self.rope(k_t)
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
if xops is not None:
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
x = xops.memory_efficient_attention(
q, k, v,
p=self.xattn_drop,
scale=self.scale,
)
x = x.reshape(B, N, -1)
x = self.inner_attn_ln(x)
x = self.proj(x)
x = self.proj_drop(x)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias.type_as(attn)
if attn_mask is not None:
attn_mask = attn_mask.bool()
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.inner_attn_ln(x)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
subln=False, naiveswiglu=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
if naiveswiglu:
self.mlp = SwiGLU(
in_features=dim,
hidden_features=mlp_hidden_dim,
subln=subln,
norm_layer=norm_layer,
)
else:
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
subln=subln,
drop=drop
)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
self.postnorm = postnorm
def forward(self, x, rel_pos_bias=None, attn_mask=None):
if self.gamma_1 is None:
if self.postnorm:
x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
if self.postnorm:
x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # [10, 3, 224, 224] -> [10, 196, 768]
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
def forward(self):
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
class EVAVisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
super().__init__()
self.image_size = img_size
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
else:
self.rel_pos_bias = None
if rope:
half_head_dim = embed_dim // num_heads // 2
hw_seq_len = img_size // patch_size
self.rope = VisionRotaryEmbeddingFast(
dim=half_head_dim,
pt_seq_len=pt_hw_seq_len,
ft_seq_len=hw_seq_len if intp_freq else None,
# patch_dropout=patch_dropout
)
else:
self.rope = None
self.naiveswiglu = naiveswiglu
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
for i in range(depth)])
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
# trunc_normal_(self.mask_token, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
if isinstance(self.head, nn.Linear):
trunc_normal_(self.head.weight, std=.02)
self.head.weight.data.mul_(init_scale)
self.head.bias.data.mul_(init_scale)
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
self.grad_checkpointing = grad_checkpointing
def fix_init_weight(self):
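# Rescale the residual-branch output projections by 1/sqrt(2 * layer_id), following
# the BEiT-style initialization this file is adapted from, to stabilize deep networks.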
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
if self.naiveswiglu:
rescale(layer.mlp.w3.weight.data, layer_id + 1)
else:
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def get_cast_dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
for param in self.parameters():
param.requires_grad = False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x, return_all_features=False):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
if os.getenv('RoPE') == '1':
if self.training and not isinstance(self.patch_dropout, nn.Identity):
x, patch_indices_keep = self.patch_dropout(x)
self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
else:
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
x = self.patch_dropout(x)
else:
x = self.patch_dropout(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
for blk in self.blocks:
if self.grad_checkpointing:
# x = checkpoint(blk, x, (rel_pos_bias,))
x = checkpoint(blk, x, rel_pos_bias)
else:
x = blk(x, rel_pos_bias=rel_pos_bias)
if not return_all_features:
x = self.norm(x)
if self.fc_norm is not None:
return self.fc_norm(x.mean(1))
else:
return x[:, 0]
return x
def forward(self, x, return_all_features=True):
if return_all_features:
return self.forward_features(x, return_all_features)
x = self.forward_features(x)
x = self.head(x)
return x
import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union, Dict, Any
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
get_cast_dtype
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
from .transform import image_transform
from .tokenizer import HFTokenizer, tokenize
from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = ('.json',)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f'*{ext}'))
for cf in config_files:
with open(cf, "r", encoding="utf8") as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
_rescan_model_configs() # initial populate of model config registry
def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())
def add_model_config(path):
""" add model config path or file and update registry """
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()
def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
def get_tokenizer(model_name):
config = get_model_config(model_name)
tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
return tokenizer
# loading openai CLIP weights when is_openai=True for training
def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
if is_openai:
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
state_dict = model.state_dict()
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
else:
checkpoint = torch.load(checkpoint_path, map_location=map_location)
for mk in model_key.split('|'):
if isinstance(checkpoint, dict) and mk in checkpoint:
state_dict = checkpoint[mk]
break
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith('module'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
for k in skip_list:
if k in list(state_dict.keys()):
logging.info(f"Removing key {k} from pretrained checkpoint")
del state_dict[k]
if os.getenv('RoPE') == '1':
for k in list(state_dict.keys()):
if 'freqs_cos' in k or 'freqs_sin' in k:
del state_dict[k]
return state_dict
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
# detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
state_dict['logit_scale'] = state_dict['text.logit_scale']
del state_dict['text.logit_scale']
# resize_clip_pos_embed for CLIP and open CLIP
if 'visual.positional_embedding' in state_dict:
resize_clip_pos_embed(state_dict, model)
# specified to eva_vit_model
elif 'visual.pos_embed' in state_dict:
resize_evaclip_pos_embed(state_dict, model)
# resize_clip_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
return incompatible_keys
def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if not k.startswith('visual.'):
del state_dict[k]
for k in list(state_dict.keys()):
if k.startswith('visual.'):
new_k = k[7:]
state_dict[new_k] = state_dict[k]
del state_dict[k]
return state_dict
def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if k.startswith('visual.'):
del state_dict[k]
return state_dict
def get_pretrained_tag(pretrained_model):
pretrained_model = pretrained_model.lower()
if "laion" in pretrained_model or "open_clip" in pretrained_model:
return "open_clip"
elif "openai" in pretrained_model:
return "clip"
elif "eva" in pretrained_model and "clip" in pretrained_model:
return "eva_clip"
else:
return "other"
def load_pretrained_checkpoint(
model,
visual_checkpoint_path,
text_checkpoint_path,
strict=True,
visual_model=None,
text_model=None,
model_key="model|module|state_dict",
skip_list=[]):
visual_tag = get_pretrained_tag(visual_model)
text_tag = get_pretrained_tag(text_model)
logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
visual_incompatible_keys, text_incompatible_keys = None, None
if visual_checkpoint_path:
if visual_tag == "eva_clip" or visual_tag == "open_clip":
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
elif visual_tag == "clip":
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
# resize_clip_pos_embed for CLIP and open CLIP
if 'positional_embedding' in visual_state_dict:
resize_visual_pos_embed(visual_state_dict, model)
# specified to EVA model
elif 'pos_embed' in visual_state_dict:
resize_eva_pos_embed(visual_state_dict, model)
visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
if text_checkpoint_path:
if text_tag == "eva_clip" or text_tag == "open_clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
elif text_tag == "clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
return visual_incompatible_keys, text_incompatible_keys
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = '',
pretrained_text: str = '',
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
is_only_visual: bool = False,
is_only_text: bool = False,
):
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
if isinstance(device, str):
device = torch.device(device)
if pretrained and pretrained.lower() == 'openai':
logging.info(f'Loading pretrained {model_name} from OpenAI.')
model = load_openai_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)
else:
model_cfg = get_model_config(model_name)
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
if 'rope' in model_cfg.get('vision_cfg', {}):
if model_cfg['vision_cfg']['rope']:
os.environ['RoPE'] = "1"
else:
os.environ['RoPE'] = "0"
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
cast_dtype = get_cast_dtype(precision)
custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
if custom_clip:
if 'hf_model_name' in model_cfg.get('text_cfg', {}):
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype, is_only_visual=is_only_visual, is_only_text=is_only_text)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
print("Not CustomCLIP: If you have set building only visual or text tower, you may still get a complete CLIP model.")
pretrained_cfg = {}
if pretrained:
checkpoint_path = ''
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model,
checkpoint_path,
model_key="model|module|state_dict",
strict=False
)
else:
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
logging.warning(error_str)
raise RuntimeError(error_str)
else:
visual_checkpoint_path = ''
text_checkpoint_path = ''
if pretrained_image:
pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
# pretrained weight loading for timm models set via vision_cfg
model_cfg['vision_cfg']['timm_model_pretrained'] = True
elif pretrained_image_cfg:
visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_image):
visual_checkpoint_path = pretrained_image
else:
logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
if pretrained_text:
pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
if pretrained_image_cfg:
text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_text):
text_checkpoint_path = pretrained_text
else:
logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
if visual_checkpoint_path:
logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
if text_checkpoint_path:
logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
if visual_checkpoint_path or text_checkpoint_path:
load_pretrained_checkpoint(
model,
visual_checkpoint_path,
text_checkpoint_path,
strict=False,
visual_model=pretrained_visual_model,
text_model=pretrained_text_model,
model_key="model|module|state_dict",
skip_list=skip_list
)
if "fp16" in precision or "bf16" in precision:
logging.info(f'convert precision to {precision}')
model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
model.to(device=device)
# set image / mean metadata from pretrained_cfg if available, or use default
if not is_only_text:
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
if jit:
model = torch.jit.script(model)
return model
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = '',
pretrained_text: str = '',
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
pretrained_image=pretrained_image,
pretrained_text=pretrained_text,
pretrained_hf=pretrained_hf,
pretrained_visual_model=pretrained_visual_model,
pretrained_text_model=pretrained_text_model,
cache_dir=cache_dir,
skip_list=skip_list,
)
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
return model, preprocess_train, preprocess_val
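# Usage sketch for the factory above (the config name and checkpoint path are
# placeholders; pick a config that actually exists under model_configs/):
#   model, preprocess_train, preprocess_val = create_model_and_transforms(
#       "EVA02-CLIP-B-16", pretrained="/path/to/checkpoint.pt",
#       precision="fp16", device="cuda")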
def create_eva_vision_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = '',
pretrained_text: str = '',
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
pretrained_image=pretrained_image,
pretrained_text=pretrained_text,
pretrained_hf=pretrained_hf,
pretrained_visual_model=pretrained_visual_model,
pretrained_text_model=pretrained_text_model,
cache_dir=cache_dir,
skip_list=skip_list,
is_only_visual=True, # only use visual tower
)
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
return model, preprocess_train, preprocess_val
def create_model_from_pretrained(
model_name: str,
pretrained: str,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
return_transform: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
is_frozen: bool = False,
):
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
raise RuntimeError(
f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
f' Use open_clip.list_pretrained() to find one.')
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
cache_dir=cache_dir,
)
if is_frozen:
for param in model.parameters():
param.requires_grad = False
if not return_transform:
return model
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
return model, preprocess
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens"
},
"pooler": "mean_pooler",
},
"bert": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
}
}
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
transformers = None
class BaseModelOutput:
pass
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
# TODO: ?last - for gpt-like models
_POOLERS = {}
def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
return cls
@register_pooler
class MeanPooler(nn.Module):
"""Mean pooling"""
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
@register_pooler
class MaxPooler(nn.Module):
"""Max pooling"""
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
"""CLS token pooling"""
def __init__(self, use_pooler_output=True):
super().__init__()
self.cls_token_position = 0
self.use_pooler_output = use_pooler_output
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
if (self.use_pooler_output and
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
(x.pooler_output is not None)
):
return x.pooler_output
return x.last_hidden_state[:, self.cls_token_position, :]
class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""
def __init__(
self,
model_name_or_path: str,
output_dim: int,
tokenizer_name: str = None,
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
pretrained: bool = True,
masked_language_modeling: bool = False):
super().__init__()
self.output_dim = output_dim
# TODO: find better way to get this information
uses_transformer_pooler = (pooler_type == "cls_pooler")
if transformers is None:
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
if config is None:
self.config = AutoConfig.from_pretrained(model_name_or_path)
if masked_language_modeling:
create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
AutoModelForMaskedLM.from_config, self.config)
else:
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
AutoModel.from_config, self.config)
# TODO: do all model configs have this attribute? PretrainedConfig defines is_encoder_decoder, so presumably yes
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
else:
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
self.config = config
if masked_language_modeling:
self.transformer = AutoModelForMaskedLM.from_config(config)
else:
self.transformer = AutoModel.from_config(config)
if pooler_type is None: # get default arch pooler
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
else:
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
self.proj = nn.Identity()
elif proj == 'linear':
self.proj = nn.Linear(d_model, output_dim, bias=False)
elif proj == 'mlp':
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=False),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=False),
)
# self.itm_proj = nn.Linear(d_model, 2, bias=False)
# self.mlm_proj = nn.Linear(d_model, self.config.vocab_size, bias=False)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
# image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
# attn_mask = (x != self.config.pad_token_id).long()
# out = self.transformer(
# input_ids=x,
# attention_mask=attn_mask,
# encoder_hidden_states = image_embeds,
# encoder_attention_mask = image_atts,
# )
# pooled_out = self.pooler(out, attn_mask)
# return self.itm_proj(pooled_out)
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
if masked_indices is None:
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
if targets is not None:
targets[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
input_ids[indices_replaced] = self.tokenizer.mask_token_id
# 10% of the time (i.e. half of the remaining 20%), we replace masked input tokens with a random word
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
input_ids[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
if targets is not None:
return input_ids, targets
else:
return input_ids
def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
labels = input_ids.clone()
attn_mask = (input_ids != self.config.pad_token_id).long()
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
probability_matrix = torch.full(labels.shape, mlm_probability)
input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
probability_matrix = probability_matrix)
mlm_output = self.transformer(input_ids,
attention_mask = attn_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
labels = labels,
)
return mlm_output.loss
# mlm_output = self.transformer(input_ids,
# attention_mask = attn_mask,
# encoder_hidden_states = image_embeds,
# encoder_attention_mask = image_atts,
# return_dict = True,
# ).last_hidden_state
# logits = self.mlm_proj(mlm_output)
# # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
# logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
# labels = labels[:, 1:].contiguous().view(-1)
# mlm_loss = F.cross_entropy(
# logits,
# labels,
# # label_smoothing=0.1,
# )
# return mlm_loss
def forward(self, x:TensorType) -> TensorType:
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
return self.proj(pooled_out)
def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
if not unlocked_layers: # full freezing
for n, p in self.transformer.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
return
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
embeddings = getattr(
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
modules = [embeddings, *layer_list][:-unlocked_layers]
# freeze layers
for module in modules:
for n, p in module.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.gradient_checkpointing_enable()
def get_num_layers(self):
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
return len(layer_list)
def init_parameters(self):
pass
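# Usage sketch (hedged; the model and tokenizer names are illustrative, not fixed by this repo):
#   encoder = HFTextEncoder("bert-base-uncased", output_dim=512,
#                           tokenizer_name="bert-base-uncased",
#                           proj="linear", pooler_type="mean_pooler")
#   toks = encoder.tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
#   feats = encoder(toks["input_ids"])   # mean-pooled, projected features of shape [1, 512]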
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from timm.loss import LabelSmoothingCrossEntropy
def gather_features(
image_features,
text_features,
local_loss=False,
gather_with_grad=False,
rank=0,
world_size=1,
use_horovod=False
):
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
if use_horovod:
assert hvd is not None, 'Please install horovod'
if gather_with_grad:
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
else:
with torch.no_grad():
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
else:
# We gather tensors from all gpus
if gather_with_grad:
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
# all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
# all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
else:
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
dist.all_gather(gathered_image_features, image_features)
dist.all_gather(gathered_text_features, text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
return all_image_features, all_text_features
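# Note (sketch): with local_loss=True each rank keeps gradients only for its own features and
# ClipLoss below builds an [N_local, N_local * world_size] logit matrix against the gathered
# features; with local_loss=False every rank reconstructs the full [N_global, N_global] matrix.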
class ClipLoss(nn.Module):
def __init__(
self,
local_loss=False,
gather_with_grad=False,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
smoothing=0.,
):
super().__init__()
self.local_loss = local_loss
self.gather_with_grad = gather_with_grad
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
# cache state
self.prev_num_logits = 0
self.labels = {}
def forward(self, image_features, text_features, logit_scale=1.):
device = image_features.device
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features, text_features,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
else:
logits_per_image = logit_scale * all_image_features @ all_text_features.T
logits_per_text = logits_per_image.T
else:
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T
# calculate ground-truth and cache if enabled
num_logits = logits_per_image.shape[0]
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1 and self.local_loss:
labels = labels + num_logits * self.rank
if self.cache_labels:
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
if self.label_smoothing_cross_entropy:
total_loss = (
self.label_smoothing_cross_entropy(logits_per_image, labels) +
self.label_smoothing_cross_entropy(logits_per_text, labels)
) / 2
else:
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
acc = None
i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
acc = {"i2t": i2t_acc, "t2i": t2i_acc}
return total_loss, acc
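# Usage sketch (single-process, hedged): with a batch of paired image/text embeddings the loss
# reduces to a symmetric InfoNCE over the in-batch similarity matrix, e.g.
#   loss_fn = ClipLoss()
#   img = F.normalize(torch.randn(8, 512), dim=-1)
#   txt = F.normalize(torch.randn(8, 512), dim=-1)
#   loss, acc = loss_fn(img, txt, logit_scale=100.0)   # acc == {"i2t": ..., "t2i": ...}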
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
try:
from .hf_model import HFTextEncoder
except ImportError:
HFTextEncoder = None
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .eva_vit_model import EVAVisionTransformer
from .transformer import LayerNorm, LayerNormFp32, QuickGELU, Attention, VisionTransformer, TextTransformer  # LayerNormFp32 is used for fp16/bf16 vision towers below
# try:
# from apex.normalization import FusedLayerNorm
# except:
FusedLayerNorm = LayerNorm
# print("Please 'pip install apex'")
try:
import xformers.ops as xops
except ImportError:
xops = None
# print("Please 'pip install xformers'")
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
drop_path_rate: Optional[float] = None # drop path rate
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
qkv_bias: bool = True
fusedLN: bool = False
xattn: bool = False
postnorm: bool = False
rope: bool = False
pt_hw_seq_len: int = 16 # 224/14
intp_freq: bool = False
naiveswiglu: bool = False
subln: bool = False
@dataclass
class CLIPTextCfg:
context_length: int = 77
vocab_size: int = 49408
width: int = 512
heads: int = 8
layers: int = 12
ls_init_value: Optional[float] = None # layer scale initial value
hf_model_name: str = None
hf_tokenizer_name: str = None
hf_model_pretrained: bool = True
proj: str = 'mlp'
pooler_type: str = 'mean_pooler'
masked_language_modeling: bool = False
fusedLN: bool = False
xattn: bool = False
attn_mask: bool = True
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == 'bf16':
cast_dtype = torch.bfloat16
elif precision == 'fp16':
cast_dtype = torch.float16
return cast_dtype
def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.eva_model_name:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNorm
visual = EVAVisionTransformer(
img_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
num_classes=embed_dim,
use_mean_pooling=vision_cfg.global_average_pool, #False
init_values=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
embed_dim=vision_cfg.width,
depth=vision_cfg.layers,
num_heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
qkv_bias=vision_cfg.qkv_bias,
drop_path_rate=vision_cfg.drop_path_rate,
norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
xattn=vision_cfg.xattn,
rope=vision_cfg.rope,
postnorm=vision_cfg.postnorm,
pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
intp_freq= vision_cfg.intp_freq,
naiveswiglu= vision_cfg.naiveswiglu,
subln= vision_cfg.subln
)
elif vision_cfg.timm_model_name:
visual = TimmModel(
vision_cfg.timm_model_name,
pretrained=vision_cfg.timm_model_pretrained,
pool=vision_cfg.timm_pool,
proj=vision_cfg.timm_proj,
proj_bias=vision_cfg.timm_proj_bias,
embed_dim=embed_dim,
image_size=vision_cfg.image_size
)
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
heads=vision_heads,
image_size=vision_cfg.image_size,
width=vision_cfg.width
)
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
global_average_pool=vision_cfg.global_average_pool,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return visual
def _build_text_tower(
embed_dim: int,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
if text_cfg.hf_model_name:
text = HFTextEncoder(
text_cfg.hf_model_name,
output_dim=embed_dim,
tokenizer_name=text_cfg.hf_tokenizer_name,
proj=text_cfg.proj,
pooler_type=text_cfg.pooler_type,
masked_language_modeling=text_cfg.masked_language_modeling
)
else:
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNorm
text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
width=text_cfg.width,
heads=text_cfg.heads,
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
xattn=text_cfg.xattn,
attn_mask=text_cfg.attn_mask,
)
return text
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
return {'logit_scale'}
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
return image_features, text_features, self.logit_scale.exp()
class CustomCLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
itm_task: bool = False,
is_only_visual: bool = False,
is_only_text: bool = False,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))  # learnable temperature (logit scale)
if is_only_visual:
self.text = None
if is_only_text:
self.visual = None
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
self.text.lock(unlocked_layers, freeze_layer_norm)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
if self.text is not None:
self.text.set_grad_checkpointing(enable)
@torch.jit.ignore
def no_weight_decay(self):
return {'logit_scale'}
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features
def forward(self, image, text):
if self.visual is not None:
image_features = self.encode_image(image, normalize=True)
else:
image_features = None
if self.text is not None:
text_features = self.encode_text(text, normalize=True)
else:
text_features = None
return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
def _convert_weights(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.to(dtype)
if l.bias is not None:
l.bias.data = l.bias.data.to(dtype)
if isinstance(l, (nn.MultiheadAttention, Attention)):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr, None)
if tensor is not None:
tensor.data = tensor.data.to(dtype)
if isinstance(l, nn.Parameter):
l.data = l.data.to(dtype)
for name in ["text_projection", "proj"]:
# nn.Module.apply only visits modules, so check the attribute (not l) for nn.Parameter
attr = getattr(l, name, None)
if isinstance(attr, nn.Parameter):
attr.data = attr.data.to(dtype)
model.apply(_convert_weights)
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
if 'text_projection' in state_dict:
# old format state_dict, move text tower -> .text
new_state_dict = {}
for k, v in state_dict.items():
if any(k.startswith(p) for p in (
'text_projection',
'positional_embedding',
'token_embedding',
'transformer',
'ln_final',
'logit_scale'
)):
k = 'text.' + k
new_state_dict[k] = v
return new_state_dict
return state_dict
def build_model_from_openai_state_dict(
state_dict: dict,
quick_gelu=True,
cast_dtype=torch.float16,
):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_size = vision_patch_size * grid_size
else:
counts: list = [
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_size = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
vision_cfg = CLIPVisionCfg(
layers=vision_layers,
width=vision_width,
patch_size=vision_patch_size,
image_size=image_size,
)
text_cfg = CLIPTextCfg(
context_length=context_length,
vocab_size=vocab_size,
width=transformer_width,
heads=transformer_heads,
layers=transformer_layers
)
model = CLIP(
embed_dim,
vision_cfg=vision_cfg,
text_cfg=text_cfg,
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
cast_dtype=cast_dtype,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
model.load_state_dict(state_dict)
return model.eval()
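# Usage sketch (hedged; assumes a locally available OpenAI CLIP JIT checkpoint such as "ViT-B-16.pt"):
#   jit_model = torch.jit.load("ViT-B-16.pt", map_location="cpu")
#   model = build_model_from_openai_state_dict(jit_model.state_dict())   # fp16 CLIP in eval mode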
def trace_model(model, batch_size=256, device=torch.device('cpu')):
model.eval()
image_size = model.visual.image_size
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
model = torch.jit.trace_module(
model,
inputs=dict(
forward=(example_images, example_text),
encode_text=(example_text,),
encode_image=(example_images,)
))
model.visual.image_size = image_size
return model
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 16,
"eva_model_name": "eva-clip-b-16",
"ls_init_value": 0.1,
"drop_path_rate": 0.0
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14,
"eva_model_name": "eva-clip-g-14-x",
"drop_path_rate": 0,
"xattn": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14,
"eva_model_name": "eva-clip-g-14-x",
"drop_path_rate": 0.4,
"xattn": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"head_width": 64,
"patch_size": 16,
"mlp_ratio": 2.6667,
"eva_model_name": "eva-clip-b-16-X",
"drop_path_rate": 0.0,
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true,
"patch_dropout": 0.5
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"xattn": true,
"fusedLN": true
}
}
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"drop_path_rate": 0,
"head_width": 64,
"mlp_ratio": 2.6667,
"patch_size": 14,
"eva_model_name": "eva-clip-l-14-336",
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}