{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vit_relpos_medium_patch16_cls_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "xlm-roberta-base",
"hf_tokenizer_name": "xlm-roberta-base",
"hf_pooler_type": "mean_pooler"
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "xlm-roberta-large",
"hf_tokenizer_name": "xlm-roberta-large",
"hf_pooler_type": "mean_pooler"
}
}
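# Hedged usage sketch (not part of the committed files): the JSON blocks above are
# open_clip model configs resolved by name. The second and third appear to correspond to
# the 'xlm-roberta-base-ViT-B-32' and 'xlm-roberta-large-ViT-H-14' entries in the
# pretrained registry further below. Building one with random weights needs `open_clip_torch`
# and `transformers`, and may fetch the underlying xlm-roberta files from the HF hub.
import open_clip
model, _, preprocess = open_clip.create_model_and_transforms('xlm-roberta-base-ViT-B-32')
tokenizer = open_clip.get_tokenizer('xlm-roberta-base-ViT-B-32')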
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
from open_clip.utils import freeze_batch_norm_2d
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.act2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.act3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.act3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0.,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x[0]
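# Hedged shape check (illustrative, not from the original file): attention-pool a 7x7 grid
# of 2048-d features, as a ModifiedResNet-50 produces for 224px inputs, into one 1024-d
# embedding per image.
_attn_pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
_pooled = _attn_pool(torch.randn(2, 2048, 7, 7))  # -> torch.Size([2, 1024])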
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
for param in self.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def stem(self, x):
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
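# Hedged smoke test (illustrative only): the hyperparameters below follow CLIP's RN50
# (4 stages of 3/4/6/3 bottleneck blocks, stem width 64, 32 attention-pool heads,
# 1024-d output); they are assumptions, not values defined in this file.
_rn50_like = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
_image_features = _rn50_like(torch.randn(2, 3, 224, 224))  # -> torch.Size([2, 1024])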
""" OpenAI pretrained model functions
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
import warnings
from typing import List, Optional, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
__all__ = ["list_openai_models", "load_openai_model"]
def list_openai_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list_pretrained_models_by_tag('openai')
def load_openai_model(
name: str,
precision: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
cache_dir: Optional[str] = None,
):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `list_openai_models()`, or the path to a model checkpoint containing the state_dict
precision: str
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
device : Union[str, torch.device]
The device to put the loaded model
cache_dir : Optional[str]
The directory to cache the downloaded model weights
Returns
-------
model : torch.nn.Module
The CLIP model
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if precision is None:
precision = 'fp32' if device == 'cpu' else 'fp16'
if get_pretrained_url(name, 'openai'):
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
state_dict = torch.load(model_path, map_location="cpu")
# Build a non-jit model from the OpenAI jitted model state dict
cast_dtype = get_cast_dtype(precision)
try:
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
except KeyError:
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
model = model.to(device)
# FIXME support pure fp16/bf16 precision modes
if precision != 'fp16':
model.float()
if precision == 'bf16':
# for bf16, convert back to low-precision
convert_weights_to_lp(model, dtype=torch.bfloat16)
# add mean / std attributes for consistency with OpenCLIP models
model.visual.image_mean = OPENAI_DATASET_MEAN
model.visual.image_std = OPENAI_DATASET_STD
return model
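if __name__ == "__main__":
    # Hedged usage sketch (not in the original file): load the OpenAI RN50 checkpoint on CPU
    # in fp32. The first call downloads the weights into ~/.cache/clip.
    clip_model = load_openai_model('RN50', precision='fp32', device='cpu')
    print(type(clip_model).__name__, sum(p.numel() for p in clip_model.parameters()))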
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
import numpy as np
import torch
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
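# Hedged shape check: a sin-cos table for a 14x14 patch grid (e.g. ViT-B/16 at 224px)
# with a zero row prepended for the class token.
_pos_embed_example = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
# _pos_embed_example.shape == (1 + 14 * 14, 768)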
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
if 'pos_embed' in checkpoint_model:
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
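# Hedged sketch (illustrative only): interpolate a 14x14 checkpoint table (plus one class
# token) up to the 16x16 grid a higher-resolution model expects. The dummy objects below
# only mimic the attributes interpolate_pos_embed reads.
from types import SimpleNamespace
_dummy_vit = SimpleNamespace(
    patch_embed=SimpleNamespace(num_patches=16 * 16),
    pos_embed=torch.zeros(1, 16 * 16 + 1, 768),
)
_checkpoint = {'pos_embed': torch.randn(1, 14 * 14 + 1, 768)}
interpolate_pos_embed(_dummy_vit, _checkpoint)
# _checkpoint['pos_embed'].shape is now torch.Size([1, 257, 768])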
import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Union
from tqdm import tqdm
from .constants import (
IMAGENET_MEAN,
IMAGENET_STD,
INCEPTION_MEAN,
INCEPTION_STD,
OPENAI_DATASET_MEAN,
OPENAI_DATASET_STD,
)
from .version import __version__
try:
from huggingface_hub import hf_hub_download
hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
def _pcfg(url='', hf_hub='', **kwargs):
# OpenAI / OpenCLIP defaults
return {
'url': url,
'hf_hub': hf_hub,
'mean': OPENAI_DATASET_MEAN,
'std': OPENAI_DATASET_STD,
'interpolation': 'bicubic',
'resize_mode': 'shortest',
**kwargs,
}
def _slpcfg(url='', hf_hub='', **kwargs):
# SigLIP defaults
return {
'url': url,
'hf_hub': hf_hub,
'mean': INCEPTION_MEAN,
'std': INCEPTION_STD,
'interpolation': 'bicubic',
'resize_mode': 'squash',
**kwargs,
}
def _apcfg(url='', hf_hub='', **kwargs):
# CLIPA defaults
return {
'url': url,
'hf_hub': hf_hub,
'mean': IMAGENET_MEAN,
'std': IMAGENET_STD,
'interpolation': 'bilinear',
'resize_mode': 'squash',
**kwargs,
}
def _mccfg(url='', hf_hub='', **kwargs):
# MobileCLIP
return {
'url': url,
'hf_hub': hf_hub,
'mean': (0., 0., 0.),
'std': (1., 1., 1.),
'interpolation': 'bilinear',
'resize_mode': 'shortest',
**kwargs,
}
_RN50 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
cc12m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)
_RN50_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
cc12m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)
_RN101 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)
_RN101_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)
_RN50x4 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
)
_RN50x16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
)
_RN50x64 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
)
_VITB32 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
laion2b_e16=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),
# DataComp-XL models
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),
# DataComp-M models
datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),
commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),
commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),
commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),
commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),
commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),
commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),
# DataComp-S models
datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),
commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),
commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),
commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),
commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),
commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),
commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),
)
_VITB32_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
metaclip_400m=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt"),
metaclip_fullcc=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt"),
)
_VITB32_256 = dict(
datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'),
)
_VITB16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
# DataComp-XL models
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),
# DataComp-L models
datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),
commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),
commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),
commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),
commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),
commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),
commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),
# DFN
dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/')
)
_VITB16_quickgelu = dict(
metaclip_400m=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt"),
metaclip_fullcc=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt"),
)
_VITB16_PLUS_240 = dict(
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)
_VITL14 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
laion2b_s32b_b82k=_pcfg(
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
mean=INCEPTION_MEAN, std=INCEPTION_STD),
# DataComp-XL models
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),
commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),
commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),
commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),
)
_VITL14_quickgelu = dict(
metaclip_400m=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt"),
metaclip_fullcc=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt"),
dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'),
)
_VITL14_336 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)
_VITH14 = dict(
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
)
_VITH14_quickgelu = dict(
metaclip_fullcc=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt"),
dfn5b=_pcfg(
hf_hub='apple/DFN5B-CLIP-ViT-H-14/',
interpolation="bicubic",
resize_mode="squash"
),
)
_VITH14_378_quickgelu = dict(
dfn5b=_pcfg(
hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',
interpolation="bicubic",
resize_mode="squash"
),
)
_VITg14 = dict(
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)
_VITbigG14 = dict(
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
)
_robertaViTB32 = dict(
laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)
_xlmRobertaBaseViTB32 = dict(
laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)
_xlmRobertaLargeFrozenViTH14 = dict(
frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)
_convnext_base = dict(
laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
)
_convnext_base_w = dict(
laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)
_convnext_base_w_320 = dict(
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)
_convnext_large_d = dict(
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
)
_convnext_large_d_320 = dict(
laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)
_convnext_xxlarge = dict(
laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
)
_coca_VITB32 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
)
_coca_VITL14 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
)
_PRETRAINED = {
"RN50": _RN50,
"RN50-quickgelu": _RN50_quickgelu,
"RN101": _RN101,
"RN101-quickgelu": _RN101_quickgelu,
"RN50x4": _RN50x4,
"RN50x16": _RN50x16,
"RN50x64": _RN50x64,
"ViT-B-32": _VITB32,
"ViT-B-32-256": _VITB32_256,
"ViT-B-32-quickgelu": _VITB32_quickgelu,
"ViT-B-16": _VITB16,
"ViT-B-16-quickgelu": _VITB16_quickgelu,
"ViT-B-16-plus-240": _VITB16_PLUS_240,
"ViT-L-14": _VITL14,
"ViT-L-14-quickgelu": _VITL14_quickgelu,
"ViT-L-14-336": _VITL14_336,
"ViT-H-14": _VITH14,
"ViT-H-14-quickgelu": _VITH14_quickgelu,
"ViT-H-14-378-quickgelu": _VITH14_378_quickgelu,
"ViT-g-14": _VITg14,
"ViT-bigG-14": _VITbigG14,
"roberta-ViT-B-32": _robertaViTB32,
"xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
"xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
"convnext_base": _convnext_base,
"convnext_base_w": _convnext_base_w,
"convnext_base_w_320": _convnext_base_w_320,
"convnext_large_d": _convnext_large_d,
"convnext_large_d_320": _convnext_large_d_320,
"convnext_xxlarge": _convnext_xxlarge,
"coca_ViT-B-32": _coca_VITB32,
"coca_ViT-L-14": _coca_VITL14,
"EVA01-g-14": dict(
# from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt
laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),
),
"EVA01-g-14-plus": dict(
# from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt
merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),
),
"EVA02-B-16": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt
merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),
),
"EVA02-L-14": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt
merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),
),
"EVA02-L-14-336": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt
merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),
),
"EVA02-E-14": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt
laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),
),
"EVA02-E-14-plus": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt
laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),
),
"ViT-B-16-SigLIP": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'),
),
"ViT-B-16-SigLIP-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'),
),
"ViT-B-16-SigLIP-i18n-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'),
),
"ViT-B-16-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'),
),
"ViT-B-16-SigLIP-512": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'),
),
"ViT-L-16-SigLIP-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'),
),
"ViT-L-16-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'),
),
"ViT-SO400M-14-SigLIP": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),
),
"ViT-SO400M-14-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
),
"ViT-L-14-CLIPA": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'),
),
"ViT-L-14-CLIPA-336": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'),
),
"ViT-H-14-CLIPA": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'),
),
"ViT-H-14-CLIPA-336": dict(
laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'),
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'),
),
"ViT-bigG-14-CLIPA": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'),
),
"ViT-bigG-14-CLIPA-336": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'),
),
"nllb-clip-base": dict(
v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),
),
"nllb-clip-large": dict(
v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),
),
"nllb-clip-base-siglip": dict(
v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'),
mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'),
),
"nllb-clip-large-siglip": dict(
v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),
mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),
),
"MobileCLIP-S1": dict(
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')),
"MobileCLIP-S2": dict(
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')),
"MobileCLIP-B": dict(
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'),
datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'),
),
"ViTamin-S": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'),
),
"ViTamin-S-LTT": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'),
),
"ViTamin-B": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'),
),
"ViTamin-B-LTT": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'),
),
"ViTamin-L": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'),
),
"ViTamin-L-256": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'),
),
"ViTamin-L-336": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'),
),
"ViTamin-L-384": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'),
),
"ViTamin-L2": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'),
),
"ViTamin-L2-256": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'),
),
"ViTamin-L2-336": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'),
),
"ViTamin-L2-384": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'),
),
"ViTamin-XL-256": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'),
),
"ViTamin-XL-336": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'),
),
"ViTamin-XL-384": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'),
),
}
def _clean_tag(tag: str):
# normalize pretrained tags
return tag.lower().replace('-', '_')
def list_pretrained(as_str: bool = False):
""" returns list of pretrained models
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
"""
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
def list_pretrained_models_by_tag(tag: str):
""" return all models having the specified pretrain tag """
models = []
tag = _clean_tag(tag)
for k in _PRETRAINED.keys():
if tag in _PRETRAINED[k]:
models.append(k)
return models
def list_pretrained_tags_by_model(model: str):
""" return all pretrain tags for the specified model architecture """
tags = []
if model in _PRETRAINED:
tags.extend(_PRETRAINED[model].keys())
return tags
def is_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return False
return _clean_tag(tag) in _PRETRAINED[model]
def get_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return {}
model_pretrained = _PRETRAINED[model]
return model_pretrained.get(_clean_tag(tag), {})
def get_pretrained_url(model: str, tag: str):
cfg = get_pretrained_cfg(model, _clean_tag(tag))
return cfg.get('url', '')
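# Hedged query examples against the registry above (no downloads are triggered here):
_example_tags = list_pretrained_tags_by_model('ViT-B-32')  # e.g. ['openai', 'laion400m_e31', ...]
_example_cfg = get_pretrained_cfg('ViT-B-32', 'openai')  # dict with 'url', 'mean', 'std', ...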
def download_pretrained_from_url(
url: str,
cache_dir: Union[str, None] = None,
):
if not cache_dir:
cache_dir = os.path.expanduser("~/.cache/clip")
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.basename(url)
if 'openaipublic' in url:
expected_sha256 = url.split("/")[-2]
elif 'mlfoundations' in url:
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
else:
expected_sha256 = ''
download_target = os.path.join(cache_dir, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if expected_sha256:
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
else:
return download_target
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
def download_pretrained_from_hf(
model_id: str,
filename: str = 'open_clip_pytorch_model.bin',
revision=None,
cache_dir: Union[str, None] = None,
):
has_hf_hub(True)
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
return cached_file
def download_pretrained(
cfg: Dict,
force_hf_hub: bool = False,
cache_dir: Union[str, None] = None,
):
target = ''
if not cfg:
return target
download_url = cfg.get('url', '')
download_hf_hub = cfg.get('hf_hub', '')
if download_hf_hub and force_hf_hub:
# use HF hub even if url exists
download_url = ''
if download_url:
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
elif download_hf_hub:
has_hf_hub(True)
# we assume the hf_hub entries in pretrained config combine model_id + filename in
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
model_id, filename = os.path.split(download_hf_hub)
if filename:
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
else:
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
return target
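if __name__ == "__main__":
    # Hedged sketch: resolve a pretrained tag from the registry above to a local checkpoint
    # path. This downloads the ViT-B-32 laion2b_s34b_b79k weights from the HF hub on first use.
    _cfg = get_pretrained_cfg('ViT-B-32', 'laion2b_s34b_b79k')
    print(download_pretrained(_cfg))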
import argparse
import json
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Union
import torch
try:
from huggingface_hub import (
create_repo,
get_hf_file_metadata,
hf_hub_download,
hf_hub_url,
repo_type_and_id_from_hf_id,
upload_folder,
list_repo_files,
)
from huggingface_hub.utils import EntryNotFoundError
_has_hf_hub = True
except ImportError:
_has_hf_hub = False
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
from .tokenizer import HFTokenizer
# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
HF_CONFIG_NAME = 'open_clip_config.json'
def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict]
):
preprocess_cfg = {
'mean': model.visual.image_mean,
'std': model.visual.image_std,
}
other_pp = getattr(model.visual, 'preprocess_cfg', {})
if 'interpolation' in other_pp:
preprocess_cfg['interpolation'] = other_pp['interpolation']
if 'resize_mode' in other_pp:
preprocess_cfg['resize_mode'] = other_pp['resize_mode']
hf_config = {
'model_cfg': model_config,
'preprocess_cfg': preprocess_cfg,
}
with config_path.open('w') as f:
json.dump(hf_config, f, indent=2)
def save_for_hf(
model,
tokenizer: HFTokenizer,
model_config: dict,
save_directory: str,
safe_serialization: Union[bool, str] = 'both',
skip_weights: bool = False,
):
config_filename = HF_CONFIG_NAME
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
if not skip_weights:
tensors = model.state_dict()
if safe_serialization is True or safe_serialization == "both":
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
if safe_serialization is False or safe_serialization == "both":
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
tokenizer.save_pretrained(save_directory)
config_path = save_directory / config_filename
save_config_for_hf(model, config_path, model_config=model_config)
def push_to_hf_hub(
model,
tokenizer,
model_config: Optional[dict],
repo_id: str,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_card: Optional[dict] = None,
safe_serialization: Union[bool, str] = 'both',
):
if not isinstance(tokenizer, HFTokenizer):
# FIXME this makes it awkward to push models with new tokenizers, come up with better soln.
# default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
# Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
# Infer complete repo_id from repo_url
# Can be different from the input `repo_id` if repo_owner was implicit
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
repo_id = f"{repo_owner}/{repo_name}"
# Check if repo already exists and determine what needs updating
repo_exists = False
repo_files = {}
try:
repo_files = set(list_repo_files(repo_id))
repo_exists = True
except Exception as e:
print('Repo does not exist', e)
try:
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
has_readme = True
except EntryNotFoundError:
has_readme = False
# Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
# Save model weights and config.
save_for_hf(
model,
tokenizer=tokenizer,
model_config=model_config,
save_directory=tmpdir,
safe_serialization=safe_serialization,
)
# Add readme if it does not exist
if not has_readme:
model_card = model_card or {}
model_name = repo_id.split('/')[-1]
readme_path = Path(tmpdir) / "README.md"
readme_text = generate_readme(model_card, model_name)
readme_path.write_text(readme_text)
# Upload model and return
return upload_folder(
repo_id=repo_id,
folder_path=tmpdir,
revision=revision,
create_pr=create_pr,
commit_message=commit_message,
)
def push_pretrained_to_hf_hub(
model_name,
pretrained: str,
repo_id: str,
precision: str = 'fp32',
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
image_interpolation: Optional[str] = None,
image_resize_mode: Optional[str] = None, # only effective for inference
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_card: Optional[dict] = None,
hf_tokenizer_self: bool = False,
**kwargs,
):
model, preprocess_eval = create_model_from_pretrained(
model_name,
pretrained=pretrained,
precision=precision,
image_mean=image_mean,
image_std=image_std,
image_interpolation=image_interpolation,
image_resize_mode=image_resize_mode,
**kwargs,
)
model_config = get_model_config(model_name)
if pretrained == 'openai':
model_config['quick_gelu'] = True
assert model_config
tokenizer = get_tokenizer(model_name)
if hf_tokenizer_self:
# make the hf tokenizer config in the uploaded model point to itself instead of the original location
model_config['text']['hf_tokenizer_name'] = repo_id
push_to_hf_hub(
model=model,
tokenizer=tokenizer,
model_config=model_config,
repo_id=repo_id,
commit_message=commit_message,
token=token,
revision=revision,
private=private,
create_pr=create_pr,
model_card=model_card,
safe_serialization='both',
)
def generate_readme(model_card: dict, model_name: str):
tags = model_card.pop('tags', ('clip',))
pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification')
readme_text = "---\n"
if tags:
readme_text += "tags:\n"
for t in tags:
readme_text += f"- {t}\n"
readme_text += "library_name: open_clip\n"
readme_text += f"pipeline_tag: {pipeline_tag}\n"
readme_text += f"license: {model_card.get('license', 'mit')}\n"
if 'details' in model_card and 'Dataset' in model_card['details']:
readme_text += 'datasets:\n'
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
readme_text += "---\n"
readme_text += f"# Model card for {model_name}\n"
if 'description' in model_card:
readme_text += f"\n{model_card['description']}\n"
if 'details' in model_card:
readme_text += f"\n## Model Details\n"
for k, v in model_card['details'].items():
if isinstance(v, (list, tuple)):
readme_text += f"- **{k}:**\n"
for vi in v:
readme_text += f" - {vi}\n"
elif isinstance(v, dict):
readme_text += f"- **{k}:**\n"
for ki, vi in v.items():
readme_text += f" - {ki}: {vi}\n"
else:
readme_text += f"- **{k}:** {v}\n"
if 'usage' in model_card:
readme_text += f"\n## Model Usage\n"
readme_text += model_card['usage']
readme_text += '\n'
if 'comparison' in model_card:
readme_text += f"\n## Model Comparison\n"
readme_text += model_card['comparison']
readme_text += '\n'
if 'citation' in model_card:
readme_text += f"\n## Citation\n"
if not isinstance(model_card['citation'], (list, tuple)):
citations = [model_card['citation']]
else:
citations = model_card['citation']
for c in citations:
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
parser.add_argument(
"--model", type=str, help="Name of the model to use.",
)
parser.add_argument(
"--pretrained", type=str,
help="Use a pretrained CLIP model weights with the specified tag or file path.",
)
parser.add_argument(
"--repo-id", type=str,
help="Destination HF Hub repo-id ie 'organization/model_id'.",
)
parser.add_argument(
"--precision", type=str, default='fp32',
)
parser.add_argument(
'--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override default image mean value of dataset')
parser.add_argument(
'--image-std', type=float, nargs='+', default=None, metavar='STD',
help='Override default image std deviation of dataset')
parser.add_argument(
'--image-interpolation',
default=None, type=str, choices=['bicubic', 'bilinear', 'random'],
help="image resize interpolation"
)
parser.add_argument(
'--image-resize-mode',
default=None, type=str, choices=['shortest', 'longest', 'squash'],
help="image resize mode during inference"
)
parser.add_argument(
"--hf-tokenizer-self",
default=False,
action="store_true",
help="make hf_tokenizer_name point in uploaded config point to itself"
)
args = parser.parse_args()
print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
# FIXME add support to pass model_card json / template from file via cmd line
push_pretrained_to_hf_hub(
args.model,
args.pretrained,
args.repo_id,
precision=args.precision,
image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
image_std=args.image_std,
image_interpolation=args.image_interpolation,
image_resize_mode=args.image_resize_mode,
)
print(f'{args.model} saved.')
""" timm model adapter
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in a CLIP model.
"""
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
try:
import timm
from timm.models.layers import Mlp, to_2tuple
try:
# old timm imports < 0.8.1
from timm.models.layers.attention_pool2d import RotAttentionPool2d
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
except ImportError:
# new timm imports >= 0.8.1
from timm.layers import RotAttentionPool2d
from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
timm = None
from .utils import freeze_batch_norm_2d
class TimmModel(nn.Module):
""" timm model adapter
"""
def __init__(
self,
model_name,
embed_dim,
image_size=224,
pool='avg',
proj='linear',
proj_bias=False,
drop=0.,
drop_path=None,
patch_drop=None,
pretrained=False,
):
super().__init__()
if timm is None:
raise RuntimeError("Please `pip install timm` to use timm models.")
self.image_size = to_2tuple(image_size)
# setup kwargs that may not be common across all models
timm_kwargs = {}
if drop_path is not None:
timm_kwargs['drop_path_rate'] = drop_path
if patch_drop is not None:
timm_kwargs['patch_drop_rate'] = patch_drop
custom_pool = pool in ('abs_attn', 'rot_attn')
if proj:
assert proj in ("linear", "mlp", "none")
extra_proj = proj in ("linear", "mlp")
if not extra_proj and not custom_pool:
# use network classifier head as projection if no proj specified and no custom pooling used
# if proj is explicitly set to 'none', features are passed through from the network trunk unchanged
proj_dim = 0 if proj == 'none' else embed_dim
self.trunk = timm.create_model(
model_name,
num_classes=proj_dim,
global_pool=pool,
pretrained=pretrained,
**timm_kwargs,
)
prev_chs = embed_dim
else:
self.trunk = timm.create_model(
model_name,
pretrained=pretrained,
**timm_kwargs,
)
feat_size = self.trunk.default_cfg.get('pool_size', None)
feature_ndim = 1 if not feat_size else 2
if custom_pool:
assert feature_ndim == 2
# if attn pooling used, remove both classifier and default pool
self.trunk.reset_classifier(0, global_pool='')
else:
# reset global pool if pool config set, otherwise leave as network default
reset_kwargs = dict(global_pool=pool) if pool else {}
self.trunk.reset_classifier(0, **reset_kwargs)
prev_chs = self.trunk.num_features
head_layers = OrderedDict()
# Add custom pooling to head
if pool == 'abs_attn':
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
prev_chs = embed_dim
elif pool == 'rot_attn':
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
prev_chs = embed_dim
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
if proj == 'linear':
head_layers['drop'] = nn.Dropout(drop)
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
elif proj == 'mlp':
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
self.head = nn.Sequential(head_layers)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
""" lock modules
Args:
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
"""
if not unlocked_groups:
# lock full model
for param in self.trunk.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self.trunk)
else:
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
try:
# FIXME import here until API stable and in an official release
from timm.models.helpers import group_parameters, group_modules
except ImportError:
raise RuntimeError(
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
matcher = self.trunk.group_matcher()
gparams = group_parameters(self.trunk, matcher)
max_layer_id = max(gparams.keys())
max_layer_id = max_layer_id - unlocked_groups
for group_idx in range(max_layer_id + 1):
group = gparams[group_idx]
for param in group:
self.trunk.get_parameter(param).requires_grad = False
if freeze_bn_stats:
gmodules = group_modules(self.trunk, matcher, reverse=True)
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
freeze_batch_norm_2d(self.trunk, gmodules)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
try:
self.trunk.set_grad_checkpointing(enable)
except Exception as e:
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
def forward(self, x):
x = self.trunk(x)
x = self.head(x)
return x
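# Hedged smoke test (requires `timm`; weights are randomly initialized, no download):
# wrap a timm ViT-B/16 trunk with average pooling and a linear projection to a 512-d
# embedding space, the same role the adapter plays inside a CLIP model.
_timm_tower = TimmModel('vit_base_patch16_224', embed_dim=512, image_size=224, pool='avg', proj='linear')
_timm_features = _timm_tower(torch.randn(2, 3, 224, 224))  # -> torch.Size([2, 512])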
""" CLIP tokenizer
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
import random
import string
from functools import lru_cache, partial
from typing import Callable, List, Optional, Union
import warnings
import ftfy
import numpy as np
import regex as re
import torch
# https://stackoverflow.com/q/62691279
os.environ["TOKENIZERS_PARALLELISM"] = "false"
_nltk_init = False
DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
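# Quick illustration: the mapping covers all 256 byte values and keeps printable ASCII intact.
assert len(bytes_to_unicode()) == 256 and bytes_to_unicode()[ord('a')] == 'a'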
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = " ".join(text.split())
text = text.strip()
return text
def _clean_canonicalize(x):
# basic, remove whitespace, remove punctuation, lower case
return canonicalize_text(basic_clean(x))
def _clean_lower(x):
# basic, remove whitespace, lower case
return whitespace_clean(basic_clean(x)).lower()
def _clean_whitespace(x):
# basic, remove whitespace
return whitespace_clean(basic_clean(x))
def get_clean_fn(type: str):
if type == 'canonicalize':
return _clean_canonicalize
elif type == 'lower':
return _clean_lower
elif type == 'whitespace':
return _clean_whitespace
else:
assert False, f"Invalid clean function ({type})."
def canonicalize_text(
text,
*,
keep_punctuation_exact_string=None,
trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
"""Returns canonicalized `text` (lowercase and punctuation removed).
From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
Args:
text: string to be canonicalized.
keep_punctuation_exact_string: If provided, then this exact string is kept.
For example providing '{}' will keep any occurrences of '{}' (but will
still remove '{' and '}' that appear separately).
"""
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(trans_punctuation)
for part in text.split(keep_punctuation_exact_string)
)
else:
text = text.translate(trans_punctuation)
text = text.lower()
text = " ".join(text.split())
return text.strip()
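# Quick illustration (hedged example): underscores become spaces, punctuation is stripped,
# text is lowercased and whitespace is collapsed.
_canon_example = canonicalize_text("A_Photo  of a Dog!")  # -> 'a photo of a dog'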
class SimpleTokenizer(object):
def __init__(
self,
bpe_path: str = default_bpe(),
additional_special_tokens: Optional[List[str]] = None,
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
clean: str = 'lower',
reduction_mask: str = ''
):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
special_tokens = ['<start_of_text>', '<end_of_text>']
if additional_special_tokens:
special_tokens += additional_special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t:t for t in special_tokens}
special = "|".join(special_tokens)
self.pat = re.compile(
special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE,
)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
self.sot_token_id = self.all_special_ids[0]
self.eot_token_id = self.all_special_ids[1]
self.context_length = context_length
self.clean_fn = get_clean_fn(clean)
self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except Exception:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = self.clean_fn(text)
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
""" Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
context_length = context_length or self.context_length
assert context_length, 'Please set a valid context length'
if self.reduction_fn is not None:
# use reduction strategy for tokenize if set, otherwise default to truncation below
return self.reduction_fn(
texts,
context_length=context_length,
sot_token_id=self.sot_token_id,
eot_token_id=self.eot_token_id,
encode_fn=self.encode,
)
all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = self.eot_token_id
result[i, :len(tokens)] = torch.tensor(tokens)
return result
_tokenizer = SimpleTokenizer()
def decode(output_ids: torch.Tensor):
output_ids = output_ids.cpu().numpy()
return _tokenizer.decode(output_ids)
def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:
return _tokenizer(texts, context_length=context_length)
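# Hedged usage sketch: the module-level helpers pad/truncate to the default 77-token context.
_example_ids = tokenize(["a photo of a cat", "a diagram of a dog"])  # LongTensor, shape (2, 77)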
def random_mask_tokenize(
texts: Union[str, List[str]],
context_length: int,
sot_token_id: int,
eot_token_id: int,
encode_fn: Callable,
shuffle: bool = False,
):
all_tokens = [encode_fn(text) for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
tokens = torch.tensor(tokens)
num_tokens = len(tokens)
if num_tokens > context_length - 2: # 2 for sot and eot token
num_keep = context_length - 2
indices = torch.randperm(len(tokens))
indices = indices[:num_keep]
if not shuffle:
indices = indices.msort()
tokens = tokens[indices]
num_tokens = num_keep
result[i, 0] = sot_token_id
result[i, 1:num_tokens + 1] = tokens
result[i, num_tokens + 1] = eot_token_id
return result
def simple_mask_tokenize(
texts: Union[str, List[str]],
context_length: int,
sot_token_id: int,
eot_token_id: int,
encode_fn: Callable,
):
all_tokens = [encode_fn(text) for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
num_tokens = len(tokens)
if num_tokens > context_length - 2: # 2 for sot and eot token
num_keep = context_length - 2
            start_index = random.randint(0, num_tokens - num_keep)  # randint's high value is inclusive
tokens = tokens[start_index: start_index + num_keep]
tokens = [sot_token_id] + tokens + [eot_token_id]
result[i, :len(tokens)] = torch.tensor(tokens)
return result
def syntax_mask_tokenize(
texts: Union[str, List[str]],
context_length: int,
sot_token_id: int,
eot_token_id: int,
encode_fn: Callable,
) -> torch.LongTensor:
""" Returns the tokenized representation of given input string(s).
Apply syntax masking before tokenize.
"""
import nltk
global _nltk_init
if not _nltk_init:
# run them for the first time
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
_nltk_init = True
def get_order(x):
if x.startswith('NN'):
return 1
elif x.startswith('JJ'):
return 2
elif x.startswith('VB'):
return 3
else:
return 4
# syntax masking
new_texts = []
for text in texts:
list_tokens = nltk.tokenize.word_tokenize(text)
pos_tags = nltk.pos_tag(list_tokens)
# sample the words by get_order method
order_list = [get_order(tag) for _, tag in pos_tags]
sorted_ids = np.argsort(np.array(order_list))
sampled_ids = sorted(sorted_ids[:context_length - 2]) # need 2 slots for sot and eot tokens
sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0) # sample the tokens
new_text = ''
for token in sampled_tokens:
new_text = new_text + str(token) + ' '
new_text = new_text.strip()
new_texts.append(new_text)
texts = new_texts
all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
        # still need to truncate first because some words produce two or more BPE tokens
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = eot_token_id
result[i, :len(tokens)] = torch.tensor(tokens)
return result
def get_reduction_mask_fn(type: str):
""" Choose strategy for dropping (masking) tokens to achieve target context length"""
assert type in ('simple', 'random', 'shuffle', 'syntax')
if type == 'simple':
return simple_mask_tokenize # randomly select block [start:end]
elif type == 'random':
return random_mask_tokenize # randomly drop tokens (keep order)
elif type == 'shuffle':
return partial(random_mask_tokenize, shuffle=True) # randomly drop tokens (shuffle order)
elif type == 'syntax':
return syntax_mask_tokenize # randomly drop prioritized by syntax
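# Illustrative sketch: a tokenizer that randomly drops tokens (keeping order) instead of
# truncating when a caption exceeds the context length. Assumes the SimpleTokenizer
# constructor accepts the context_length / reduction_mask arguments handled in __init__ above;
# the 'syntax' strategy additionally requires nltk.
def _example_reduction_tokenizer():
    tk = SimpleTokenizer(context_length=16, reduction_mask='random')
    long_caption = "a very long caption about a cat sitting on a mat " * 10
    return tk([long_caption]).shape  # torch.Size([1, 16])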
class HFTokenizer:
"""HuggingFace tokenizer wrapper"""
def __init__(
self,
tokenizer_name: str,
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
clean: str = 'whitespace',
strip_sep_token: bool = False,
language: Optional[str] = None,
**kwargs
):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **kwargs)
set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
if callable(set_lang_fn):
self.set_lang_fn = set_lang_fn
if language is not None:
self.set_language(language)
self.context_length = context_length
self.clean_fn = get_clean_fn(clean)
self.strip_sep_token = strip_sep_token
def save_pretrained(self, dest):
self.tokenizer.save_pretrained(dest)
def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        # same cleaning as for the default tokenizer, except lowercasing
        # adding lower() (for case-sensitive tokenizers) would make it more robust but less sensitive to nuance
if isinstance(texts, str):
texts = [texts]
context_length = context_length or self.context_length
assert context_length, 'Please set a valid context length in class init or call.'
texts = [self.clean_fn(text) for text in texts]
input_ids = self.tokenizer.batch_encode_plus(
texts,
return_tensors='pt',
max_length=context_length,
padding='max_length',
truncation=True,
).input_ids
if self.strip_sep_token:
input_ids = torch.where(
input_ids == self.tokenizer.sep_token_id,
torch.zeros_like(input_ids),
input_ids,
)
return input_ids
def set_language(self, src_lang):
if hasattr(self, 'set_lang_fn'):
self.set_lang_fn(src_lang)
else:
warnings.warn('Cannot set language for the tokenizer.')
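# Illustrative sketch: wrap a HuggingFace tokenizer (the xlm-roberta-base name matches the
# multilingual model configs above) and encode a batch to a fixed-length id tensor. Requires
# the `transformers` package and downloads the tokenizer files on first use.
def _example_hf_tokenizer():
    tk = HFTokenizer('xlm-roberta-base', context_length=32)
    ids = tk(["a photo of a cat", "ein Foto einer Katze"])
    return ids.shape  # torch.Size([2, 32])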
class SigLipTokenizer:
"""HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs
"""
VOCAB_FILES = {
# english, vocab_size=32_000
"c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model",
# used in multilingual models (mT5, PaLI), vocab_size=250_000
"mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
}
def __init__(
self,
tokenizer_name: str,
context_length: Optional[int] = 64,
):
from transformers import T5TokenizerFast
if tokenizer_name in self.VOCAB_FILES:
# FIXME temporary hack?
import tempfile
import fsspec
vocab_file = self.VOCAB_FILES[tokenizer_name]
with tempfile.NamedTemporaryFile('wb') as dst:
with fsspec.open(vocab_file, 'rb') as src:
dst.write(src.read())
self.tokenizer = T5TokenizerFast(dst.name, legacy=False)
else:
self.tokenizer = T5TokenizerFast(tokenizer_name, legacy=False)
self.tokenizer.pad_token_id = 1
self.tokenizer.eos_token_id = 1
self.context_length = context_length
def save_pretrained(self, dest):
self.tokenizer.save_pretrained(dest)
def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        # SigLIP canonicalizes text below (basic clean + lowercasing + punctuation stripping) before tokenizing
if isinstance(texts, str):
texts = [texts]
context_length = context_length or self.context_length
assert context_length, 'Please set a valid context length in class init or call.'
texts = [canonicalize_text(basic_clean(text)) for text in texts]
output = self.tokenizer(
texts,
return_tensors='pt',
max_length=context_length,
padding='max_length',
truncation=True,
)
return output.input_ids
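# Illustrative sketch: SigLIP-style tokenization using one of the sentencepiece vocab keys
# listed in VOCAB_FILES. The vocab file is fetched over HTTP on first use, so this needs
# network access plus the `transformers`, `sentencepiece`, and `fsspec` packages.
def _example_siglip_tokenizer():
    tk = SigLipTokenizer('c4-en')
    return tk(["a photo of a cat"]).shape  # torch.Size([1, 64]) with the default context length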
import numbers
import random
import warnings
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop, ColorJitter, Grayscale
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .utils import to_2tuple
@dataclass
class PreprocessCfg:
size: Union[int, Tuple[int, int]] = 224
mode: str = 'RGB'
mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
std: Tuple[float, ...] = OPENAI_DATASET_STD
interpolation: str = 'bicubic'
resize_mode: str = 'shortest'
fill_color: int = 0
def __post_init__(self):
assert self.mode in ('RGB',)
@property
def num_channels(self):
return 3
@property
def input_size(self):
return (self.num_channels,) + to_2tuple(self.size)
_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
def merge_preprocess_dict(
base: Union[PreprocessCfg, Dict],
overlay: Dict,
):
""" Merge overlay key-value pairs on top of base preprocess cfg or dict.
Input dicts are filtered based on PreprocessCfg fields.
"""
if isinstance(base, PreprocessCfg):
base_clean = asdict(base)
else:
base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
if overlay:
overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None}
base_clean.update(overlay_clean)
return base_clean
def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
return merge_preprocess_dict(base, kwargs)
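# Illustrative sketch: overlay a subset of preprocessing options on top of the defaults.
# Keys that are not PreprocessCfg fields (and None values) are filtered out, so unrelated
# kwargs can be passed through without effect.
def _example_merge_preprocess():
    merged = merge_preprocess_kwargs(PreprocessCfg(), size=256, interpolation='bilinear', unrelated_key=1)
    return merged['size'], merged['interpolation']  # (256, 'bilinear'); 'unrelated_key' is dropped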
@dataclass
class AugmentationCfg:
scale: Tuple[float, float] = (0.9, 1.0)
ratio: Optional[Tuple[float, float]] = None
color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
re_prob: Optional[float] = None
re_count: Optional[int] = None
use_timm: bool = False
# params for simclr_jitter_gray
    color_jitter_prob: Optional[float] = None
    gray_scale_prob: Optional[float] = None
def _setup_size(size, error_msg):
if isinstance(size, numbers.Number):
return int(size), int(size)
if isinstance(size, Sequence) and len(size) == 1:
return size[0], size[0]
if len(size) != 2:
raise ValueError(error_msg)
return size
class ResizeKeepRatio:
""" Resize and Keep Ratio
Copy & paste from `timm`
"""
def __init__(
self,
size,
longest=0.,
interpolation=InterpolationMode.BICUBIC,
random_scale_prob=0.,
random_scale_range=(0.85, 1.05),
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11)
):
if isinstance(size, (list, tuple)):
self.size = tuple(size)
else:
self.size = (size, size)
self.interpolation = interpolation
self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
self.random_scale_prob = random_scale_prob
self.random_scale_range = random_scale_range
self.random_aspect_prob = random_aspect_prob
self.random_aspect_range = random_aspect_range
@staticmethod
def get_params(
img,
target_size,
longest,
random_scale_prob=0.,
random_scale_range=(0.85, 1.05),
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11)
):
"""Get parameters
"""
source_size = img.size[::-1] # h, w
h, w = source_size
target_h, target_w = target_size
ratio_h = h / target_h
ratio_w = w / target_w
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
if random_scale_prob > 0 and random.random() < random_scale_prob:
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
ratio_factor = (ratio_factor, ratio_factor)
else:
ratio_factor = (1., 1.)
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
return size
def __call__(self, img):
"""
Args:
            img (PIL Image): Image to be resized.
        Returns:
            PIL Image: Image resized so that, depending on `longest`, the shortest or longest edge matches the target size (any cropping/padding is left to subsequent transforms)
"""
size = self.get_params(
img, self.size, self.longest,
self.random_scale_prob, self.random_scale_range,
self.random_aspect_prob, self.random_aspect_range
)
img = F.resize(img, size, self.interpolation)
return img
def __repr__(self):
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += f', interpolation={self.interpolation}'
        format_string += f', longest={self.longest:.3f})'
return format_string
def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
"""Center crops and/or pads the given image.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
img (PIL Image or Tensor): Image to be cropped.
output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
it is used for both directions.
fill (int, Tuple[int]): Padding color
Returns:
PIL Image or Tensor: Cropped image.
"""
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
output_size = (output_size[0], output_size[0])
_, image_height, image_width = F.get_dimensions(img)
crop_height, crop_width = output_size
if crop_width > image_width or crop_height > image_height:
padding_ltrb = [
(crop_width - image_width) // 2 if crop_width > image_width else 0,
(crop_height - image_height) // 2 if crop_height > image_height else 0,
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
img = F.pad(img, padding_ltrb, fill=fill)
_, image_height, image_width = F.get_dimensions(img)
if crop_width == image_width and crop_height == image_height:
return img
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return F.crop(img, crop_top, crop_left, crop_height, crop_width)
class CenterCropOrPad(torch.nn.Module):
"""Crops the given image at the center.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
"""
def __init__(self, size, fill=0):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.fill = fill
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
PIL Image or Tensor: Cropped image.
"""
return center_crop_or_pad(img, self.size, fill=self.fill)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
def _convert_to_rgb(image):
return image.convert('RGB')
class color_jitter(object):
"""
Apply Color Jitter to the PIL image with a specified probability.
"""
def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):
assert 0. <= p <= 1.
self.p = p
self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
class gray_scale(object):
"""
Apply Gray Scale to the PIL image with a specified probability.
"""
def __init__(self, p=0.2):
assert 0. <= p <= 1.
self.p = p
self.transf = Grayscale(num_output_channels=3)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
def image_transform(
image_size: Union[int, Tuple[int, int]],
is_train: bool,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_mode: Optional[str] = None,
interpolation: Optional[str] = None,
fill_color: int = 0,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
std = std or OPENAI_DATASET_STD
if not isinstance(std, (list, tuple)):
std = (std,) * 3
interpolation = interpolation or 'bicubic'
assert interpolation in ['bicubic', 'bilinear', 'random']
    # NOTE 'random' interpolation is only handled by the timm training transform; the torchvision transforms below fall back to BICUBIC
interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
resize_mode = resize_mode or 'shortest'
assert resize_mode in ('shortest', 'longest', 'squash')
if isinstance(aug_cfg, dict):
aug_cfg = AugmentationCfg(**aug_cfg)
else:
aug_cfg = aug_cfg or AugmentationCfg()
normalize = Normalize(mean=mean, std=std)
if is_train:
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
use_timm = aug_cfg_dict.pop('use_timm', False)
if use_timm:
from timm.data import create_transform # timm can still be optional
if isinstance(image_size, (tuple, list)):
assert len(image_size) >= 2
input_size = (3,) + image_size[-2:]
else:
input_size = (3, image_size, image_size)
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
# drop extra non-timm items
aug_cfg_dict.pop('color_jitter_prob', None)
aug_cfg_dict.pop('gray_scale_prob', None)
train_transform = create_transform(
input_size=input_size,
is_training=True,
hflip=0.,
mean=mean,
std=std,
re_mode='pixel',
interpolation=interpolation,
**aug_cfg_dict,
)
else:
train_transform = [
RandomResizedCrop(
image_size,
scale=aug_cfg_dict.pop('scale'),
interpolation=InterpolationMode.BICUBIC,
),
_convert_to_rgb,
]
if aug_cfg.color_jitter_prob:
assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
train_transform.extend([
color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)
])
if aug_cfg.gray_scale_prob:
train_transform.extend([
gray_scale(aug_cfg.gray_scale_prob)
])
train_transform.extend([
ToTensor(),
normalize,
])
train_transform = Compose(train_transform)
if aug_cfg_dict:
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
return train_transform
else:
if resize_mode == 'longest':
transforms = [
ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
CenterCropOrPad(image_size, fill=fill_color)
]
elif resize_mode == 'squash':
if isinstance(image_size, int):
image_size = (image_size, image_size)
transforms = [
Resize(image_size, interpolation=interpolation_mode),
]
else:
assert resize_mode == 'shortest'
if not isinstance(image_size, (tuple, list)):
image_size = (image_size, image_size)
if image_size[0] == image_size[1]:
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
transforms = [
Resize(image_size[0], interpolation=interpolation_mode)
]
else:
# resize shortest edge to matching target dim for non-square target
transforms = [ResizeKeepRatio(image_size)]
transforms += [CenterCrop(image_size)]
transforms.extend([
_convert_to_rgb,
ToTensor(),
normalize,
])
return Compose(transforms)
def image_transform_v2(
cfg: PreprocessCfg,
is_train: bool,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
return image_transform(
image_size=cfg.size,
is_train=is_train,
mean=cfg.mean,
std=cfg.std,
interpolation=cfg.interpolation,
resize_mode=cfg.resize_mode,
fill_color=cfg.fill_color,
aug_cfg=aug_cfg,
)
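# Illustrative sketch: build eval and train transforms from a PreprocessCfg and apply them to
# a dummy PIL image. Requires Pillow; the sizes and augmentation values are examples only.
def _example_image_transform():
    from PIL import Image
    cfg = PreprocessCfg(size=224, resize_mode='shortest')
    eval_tf = image_transform_v2(cfg, is_train=False)
    train_tf = image_transform_v2(cfg, is_train=True, aug_cfg={'scale': (0.8, 1.0)})
    img = Image.new('RGB', (320, 240))
    return eval_tf(img).shape, train_tf(img).shape  # both torch.Size([3, 224, 224])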
from collections import OrderedDict
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from .utils import to_2tuple
from .pos_embed import get_2d_sincos_pos_embed
class LayerNormFp32(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class QuickGELU(nn.Module):
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(self, prob, exclude_first_token=True):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.exclude_first_token = exclude_first_token # exclude CLS token
def forward(self, x):
if not self.training or self.prob == 0.:
return x
if self.exclude_first_token:
cls_tokens, x = x[:, :1], x[:, 1:]
else:
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
batch = x.size()[0]
num_tokens = x.size()[1]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
keep_prob = 1 - self.prob
num_patches_keep = max(1, int(num_tokens * keep_prob))
rand = torch.randn(batch, num_tokens)
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
x = x[batch_indices, patch_indices_keep]
if self.exclude_first_token:
x = torch.cat((cls_tokens, x), dim=1)
return x
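# Illustrative sketch: PatchDropout keeps the CLS token and a random subset of patch tokens
# during training only; in eval mode the input passes through unchanged. Shapes are examples.
def _example_patch_dropout():
    pd = PatchDropout(prob=0.5)
    x = torch.randn(2, 17, 64)  # [batch, 1 CLS + 16 patch tokens, dim]
    pd.train()
    dropped = pd(x)             # torch.Size([2, 9, 64]) -- 1 CLS + 8 surviving patches
    pd.eval()
    untouched = pd(x)           # torch.Size([2, 17, 64])
    return dropped.shape, untouched.shape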
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
scaled_cosine: bool = False,
scale_heads: bool = False,
logit_scale_max: float = math.log(1. / 0.01),
batch_first: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.logit_scale_max = logit_scale_max
self.batch_first = batch_first
self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
if qkv_bias:
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
else:
self.in_proj_bias = None
if self.scaled_cosine:
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
else:
self.logit_scale = None
self.attn_drop = nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
if self.batch_first:
x = x.transpose(0, 1)
L, N, C = x.shape
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)
k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)
v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)
if attn_mask is not None and attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
if self.logit_scale is not None:
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
attn = attn.view(N, self.num_heads, L, L) * logit_scale
attn = attn.view(-1, L, L)
if attn_mask is not None:
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
else:
if self.use_fsdpa:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = torch.bmm(q, k.transpose(-1, -2))
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
if self.head_scale is not None:
x = x.view(N, self.num_heads, L, C) * self.head_scale
x = x.view(-1, L, C)
x = x.transpose(0, 1).reshape(L, N, C)
if self.batch_first:
x = x.transpose(0, 1)
x = self.out_proj(x)
x = self.out_drop(x)
return x
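# Illustrative sketch: the custom Attention module in batch-first form with an additive
# causal mask. The dimensions are examples only; the output shape matches the input.
def _example_custom_attention():
    attn = Attention(dim=64, num_heads=4, batch_first=True)
    x = torch.randn(2, 10, 64)                              # [batch, seq, dim]
    causal = torch.full((10, 10), float('-inf')).triu_(1)   # -inf strictly above the diagonal
    return attn(x).shape, attn(x, attn_mask=causal).shape   # both torch.Size([2, 10, 64])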
class AttentionalPooler(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
n_head: int = 8,
n_queries: int = 256,
norm_layer: Callable = LayerNorm,
):
super().__init__()
self.query = nn.Parameter(torch.randn(n_queries, d_model))
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
self.ln_q = norm_layer(d_model)
self.ln_k = norm_layer(context_dim)
def forward(self, x: torch.Tensor):
N = x.shape[0]
x = self.ln_k(x)
q = self.ln_q(self.query)
out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]
return out
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
is_cross_attention: bool = False,
batch_first: bool = True,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def attention(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
)[0]
def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class CustomResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
scale_cosine_attn: bool = False,
scale_heads: bool = False,
scale_attn: bool = False,
scale_fc: bool = False,
batch_first: bool = True,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = Attention(
d_model,
n_head,
scaled_cosine=scale_cosine_attn,
scale_heads=scale_heads,
batch_first=batch_first,
)
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def get_reference_weight(self):
return self.mlp.c_fc.weight
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
def _expand_token(token, batch_size: int):
return token.view(1, 1, -1).expand(batch_size, -1, -1)
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
batch_first: bool = True,
):
super().__init__()
self.width = width
self.layers = layers
self.batch_first = batch_first
self.grad_checkpointing = False
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
batch_first=batch_first,
)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
return self.resblocks[0].mlp.c_fc.int8_original_dtype
return self.resblocks[0].mlp.c_fc.weight.dtype
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
if not self.batch_first:
x = x.transpose(0, 1).contiguous() # NLD -> LND
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
if not self.batch_first:
x = x.transpose(0, 1) # LND -> NLD
return x
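# Illustrative sketch: a small batch-first Transformer stack driven with a causal additive
# mask, mirroring how the text tower below uses it. Dimensions are examples only.
def _example_transformer_stack():
    model = Transformer(width=64, layers=2, heads=4)
    x = torch.randn(2, 12, 64)                            # [batch, seq, width]
    mask = torch.full((12, 12), float('-inf')).triu_(1)
    return model(x, attn_mask=mask).shape                 # torch.Size([2, 12, 64])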
class CustomTransformer(nn.Module):
""" A custom transformer that can use different block types. """
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
batch_first: bool = True,
block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',
):
super().__init__()
self.width = width
self.layers = layers
        self.batch_first = batch_first  # run transformer stack in batch-first (N, L, D) order
self.grad_checkpointing = False
if isinstance(block_types, str):
block_types = [block_types] * layers
assert len(block_types) == layers
def _create_block(bt: str):
if bt == 'CustomResidualAttentionBlock':
return CustomResidualAttentionBlock(
width,
heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
batch_first=batch_first,
)
            else:
                assert False, f'Unknown block type: {bt}'
self.resblocks = nn.ModuleList([
_create_block(bt)
for bt in block_types
])
def get_cast_dtype(self) -> torch.dtype:
weight = self.resblocks[0].get_reference_weight()
if hasattr(weight, 'int8_original_dtype'):
return weight.int8_original_dtype
return weight.dtype
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
if not self.batch_first:
x = x.transpose(0, 1) # NLD -> LND
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
if not self.batch_first:
            x = x.transpose(0, 1)    # LND -> NLD
return x
class VisionTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
ls_init_value: float = None,
attentional_pool: bool = False,
attn_pooler_queries: int = 256,
attn_pooler_heads: int = 8,
output_dim: int = 512,
patch_dropout: float = 0.,
no_ln_pre: bool = False,
pos_embed_type: str = 'learnable',
pool_type: str = 'tok',
final_ln_after_pool: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
):
super().__init__()
assert pool_type in ('tok', 'avg', 'none')
self.output_tokens = output_tokens
image_height, image_width = self.image_size = to_2tuple(image_size)
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
# class embeddings and positional embeddings
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
if pos_embed_type == 'learnable':
self.positional_embedding = nn.Parameter(
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
elif pos_embed_type == 'sin_cos_2d':
# fixed sin-cos embedding
assert self.grid_size[0] == self.grid_size[1],\
'currently sin cos 2d pos embedding only supports square input'
self.positional_embedding = nn.Parameter(
torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)
            pos_embed = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)
            self.positional_embedding.data.copy_(torch.from_numpy(pos_embed).float())
        else:
            raise ValueError(f'Unknown positional embedding type: {pos_embed_type}')
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
self.transformer = Transformer(
width,
layers,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
if attentional_pool:
if isinstance(attentional_pool, str):
self.attn_pool_type = attentional_pool
self.pool_type = 'none'
if attentional_pool in ('parallel', 'cascade'):
self.attn_pool = AttentionalPooler(
output_dim,
width,
n_head=attn_pooler_heads,
n_queries=attn_pooler_queries,
)
self.attn_pool_contrastive = AttentionalPooler(
output_dim,
width,
n_head=attn_pooler_heads,
n_queries=1,
)
                else:
                    assert False, f'Unexpected attentional_pool type: {attentional_pool}'
else:
self.attn_pool_type = ''
self.pool_type = pool_type
self.attn_pool = AttentionalPooler(
output_dim,
width,
n_head=attn_pooler_heads,
n_queries=attn_pooler_queries,
)
self.attn_pool_contrastive = None
pool_dim = output_dim
else:
self.attn_pool = None
pool_dim = width
self.pool_type = pool_type
self.ln_post = norm_layer(pool_dim)
self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))
self.init_parameters()
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
for param in self.parameters():
param.requires_grad = False
if unlocked_groups != 0:
groups = [
[
self.conv1,
self.class_embedding,
self.positional_embedding,
self.ln_pre,
],
*self.transformer.resblocks[:-1],
[
self.transformer.resblocks[-1],
self.ln_post,
],
self.proj,
]
def _unlock(x):
if isinstance(x, Sequence):
for g in x:
_unlock(g)
else:
if isinstance(x, torch.nn.Parameter):
x.requires_grad = True
else:
for p in x.parameters():
p.requires_grad = True
_unlock(groups[-unlocked_groups:])
def init_parameters(self):
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
# TODO experiment if default PyTorch init, below, or alternate init is best.
# nn.init.normal_(self.class_embedding, std=self.scale)
# nn.init.normal_(self.positional_embedding, std=self.scale)
#
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
# attn_std = self.transformer.width ** -0.5
# fc_std = (2 * self.transformer.width) ** -0.5
# for block in self.transformer.resblocks:
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
#
# if self.text_projection is not None:
# nn.init.normal_(self.text_projection, std=self.scale)
pass
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.pool_type == 'avg':
pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
elif self.pool_type == 'tok':
pooled, tokens = x[:, 0], x[:, 1:]
else:
pooled = tokens = x
return pooled, tokens
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
# shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.patch_dropout(x)
x = self.ln_pre(x)
x = self.transformer(x)
if self.attn_pool is not None:
if self.attn_pool_contrastive is not None:
# This is untested, WIP pooling that should match paper
x = self.ln_post(x) # TBD LN first or separate one after each pool?
tokens = self.attn_pool(x)
if self.attn_pool_type == 'parallel':
pooled = self.attn_pool_contrastive(x)
else:
assert self.attn_pool_type == 'cascade'
pooled = self.attn_pool_contrastive(tokens)
else:
# this is the original OpenCLIP CoCa setup, does not match paper
x = self.attn_pool(x)
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
elif self.final_ln_after_pool:
pooled, tokens = self._global_pool(x)
pooled = self.ln_post(pooled)
else:
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
if self.proj is not None:
pooled = pooled @ self.proj
if self.output_tokens:
return pooled, tokens
return pooled
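# Illustrative sketch: a deliberately tiny VisionTransformer producing pooled image
# embeddings; these hyperparameters do not correspond to any real model config.
def _example_vision_transformer():
    vit = VisionTransformer(
        image_size=32, patch_size=8, width=64, layers=2, heads=4,
        mlp_ratio=4.0, output_dim=32,
    )
    images = torch.randn(2, 3, 32, 32)
    return vit(images).shape  # torch.Size([2, 32])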
def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):
if pool_type == 'first':
pooled, tokens = x[:, 0], x[:, 1:]
elif pool_type == 'last':
pooled, tokens = x[:, -1], x[:, :-1]
elif pool_type == 'argmax':
# take features from the eot embedding (eot_token is the highest number in each sequence)
assert text is not None
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
else:
pooled = tokens = x
return pooled, tokens
class TextTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
context_length: int = 77,
vocab_size: int = 49408,
width: int = 512,
heads: int = 8,
layers: int = 12,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
output_dim: int = 512,
embed_cls: bool = False,
no_causal_mask: bool = False,
pad_id: int = 0,
pool_type: str = 'argmax',
proj_bias: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
):
super().__init__()
assert pool_type in ('first', 'last', 'argmax', 'none')
self.output_tokens = output_tokens
self.num_pos = self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim
self.heads = heads
self.pad_id = pad_id
self.pool_type = pool_type
self.token_embedding = nn.Embedding(vocab_size, width)
if embed_cls:
self.cls_emb = nn.Parameter(torch.empty(width))
self.num_pos += 1
else:
self.cls_emb = None
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
self.transformer = Transformer(
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.ln_final = norm_layer(width)
if no_causal_mask:
self.attn_mask = None
else:
self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)
if proj_bias:
self.text_projection = nn.Linear(width, output_dim)
else:
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
self.init_parameters()
def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if self.cls_emb is not None:
nn.init.normal_(self.cls_emb, std=0.01)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5)
if self.text_projection.bias is not None:
nn.init.zeros_(self.text_projection.bias)
else:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def build_causal_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float("-inf"))
        mask.triu_(1)  # zero the lower triangle so each token attends only to itself and earlier positions
return mask
def build_cls_mask(self, text, cast_dtype: torch.dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask
def forward(self, text):
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
attn_mask = self.attn_mask
if self.cls_emb is not None:
seq_len += 1
x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)
cls_mask = self.build_cls_mask(text, cast_dtype)
if attn_mask is not None:
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
x = self.transformer(x, attn_mask=attn_mask)
# x.shape = [batch_size, n_ctx, transformer.width]
if self.cls_emb is not None:
# presence of appended cls embed (CoCa) overrides pool_type, always take last token
pooled, tokens = text_global_pool(x, pool_type='last')
pooled = self.ln_final(pooled) # final LN applied after pooling in this case
else:
x = self.ln_final(x)
pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
pooled = self.text_projection(pooled)
else:
pooled = pooled @ self.text_projection
if self.output_tokens:
return pooled, tokens
return pooled
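# Illustrative sketch: a tiny TextTransformer over integer token ids; with the default
# 'argmax' pooling the features at the EOT position become the text embedding. Sizes are
# examples only and do not correspond to a real config.
def _example_text_transformer():
    txt = TextTransformer(context_length=16, vocab_size=1000, width=64, heads=4, layers=2, output_dim=32)
    tokens = torch.randint(1, 1000, (2, 16))
    return txt(tokens).shape  # torch.Size([2, 32])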
class MultimodalTransformer(Transformer):
def __init__(
self,
width: int,
layers: int,
heads: int,
context_length: int = 77,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_dim: int = 512,
batch_first: bool = True,
):
super().__init__(
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
batch_first=batch_first,
)
self.context_length = context_length
self.cross_attn = nn.ModuleList([
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
is_cross_attention=True,
batch_first=batch_first,
)
for _ in range(layers)
])
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
    def init_parameters(self):
        # note: MultimodalTransformer subclasses Transformer directly, so the self-attention
        # and cross-attention stacks live on `self`, not on a `self.transformer` submodule
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        for block in self.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        for block in self.cross_attn:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
        mask.triu_(1)  # zero the lower triangle so each token attends only to itself and earlier positions
return mask
def forward(self, image_embs, text_embs):
seq_len = text_embs.shape[1]
if not self.batch_first:
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
else:
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
if not self.batch_first:
text_embs = text_embs.permute(1, 0, 2) # LND -> NLD
out = self.ln_final(text_embs)
if self.text_projection is not None:
out = out @ self.text_projection
return out
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
from itertools import repeat
import collections.abc
import torch
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
def freeze_batch_norm_2d(module, module_match={}, name=''):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
module_match (dict): Dictionary of full module names to freeze (all if empty)
name (str): Full module name (prefix)
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
is_match = True
if module_match:
is_match = name in module_match
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for child_name, child in module.named_children():
full_child_name = '.'.join([name, child_name]) if name else child_name
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
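# Illustrative sketch: freezing BatchNorm layers in a small CNN; the BatchNorm2d submodule is
# replaced in place by a FrozenBatchNorm2d carrying copies of its weights and running stats.
def _example_freeze_bn():
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    net = freeze_batch_norm_2d(net)
    return type(net[1]).__name__  # 'FrozenBatchNorm2d'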
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
# Replaces all linear layers with linear_replacement
# TODO: add int8 support for other linear layers including attn and convnets
def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(module, linear_replacement, include_modules, copy_weights)
if isinstance(module, torch.nn.Linear) and name in include_modules:
old_module = model._modules[name]
model._modules[name] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
)
if copy_weights:
model._modules[name].weight.data.copy_(old_module.weight.data)
if model._modules[name].bias is not None:
model._modules[name].bias.data.copy_(old_module.bias)
return model
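# Illustrative sketch: replace_linear swaps the named nn.Linear submodules for a replacement
# class and copies their weights. Plain nn.Linear stands in here for an int8 linear layer,
# so the call is effectively a round-trip that leaves the module functionally unchanged.
def _example_replace_linear():
    from collections import OrderedDict
    mlp = nn.Sequential(OrderedDict([
        ('c_fc', nn.Linear(64, 256)),
        ('gelu', nn.GELU()),
        ('c_proj', nn.Linear(256, 64)),
    ]))
    replace_linear(mlp, nn.Linear)
    return isinstance(mlp.c_fc, nn.Linear)  # True; weights were copied from the originals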
def convert_int8_model_to_inference_mode(model):
for m in model.modules():
if hasattr(m, 'prepare_for_eval'):
int8_original_dtype = m.weight.dtype
m.prepare_for_eval()
m.int8_original_dtype = int8_original_dtype
__version__ = '2.26.1'
from functools import partial
from itertools import islice
from typing import Callable, List, Optional, Sequence, Union
import torch
import torch.nn.functional as F
def batched(iterable, n):
"""Batch data into lists of length *n*. The last batch may be shorter.
NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl
"""
it = iter(iterable)
while True:
batch = list(islice(it, n))
if not batch:
break
yield batch
def build_zero_shot_classifier(
model,
tokenizer,
classnames: Sequence[str],
templates: Sequence[Union[Callable, str]],
num_classes_per_batch: Optional[int] = 10,
device: Union[str, torch.device] = 'cpu',
use_tqdm: bool = False,
):
""" Build zero-shot classifier weights by iterating over class names in batches
Args:
model: CLIP model instance
tokenizer: CLIP tokenizer instance
classnames: A sequence of class (label) names
templates: A sequence of callables or format() friendly strings to produce templates per class name
num_classes_per_batch: The number of classes to batch together in each forward, all if None
device: Device to use.
use_tqdm: Enable TQDM progress bar.
"""
assert isinstance(templates, Sequence) and len(templates) > 0
assert isinstance(classnames, Sequence) and len(classnames) > 0
use_format = isinstance(templates[0], str)
num_templates = len(templates)
num_classes = len(classnames)
if use_tqdm:
import tqdm
num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1)
iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)
else:
iter_wrap = iter
def _process_batch(batch_classnames):
num_batch_classes = len(batch_classnames)
texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates]
texts = tokenizer(texts).to(device)
class_embeddings = model.encode_text(texts, normalize=True)
class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
class_embeddings = class_embeddings.T
return class_embeddings
with torch.no_grad():
if num_classes_per_batch:
batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))]
zeroshot_weights = torch.cat(batched_embeds, dim=1)
else:
zeroshot_weights = _process_batch(classnames)
return zeroshot_weights
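# Illustrative sketch of zero-shot classification. `model` and `tokenizer` are assumed to be
# an open_clip CLIP model (with encode_image / encode_text) and its matching tokenizer, and
# `images` a preprocessed image batch; the 100.0 logit scale is a conventional example value.
# IMAGENET_CLASSNAMES and OPENAI_IMAGENET_TEMPLATES are defined below.
def _example_zero_shot(model, tokenizer, images):
    classifier = build_zero_shot_classifier(
        model, tokenizer,
        classnames=IMAGENET_CLASSNAMES,
        templates=OPENAI_IMAGENET_TEMPLATES,
        num_classes_per_batch=10,
    )
    with torch.no_grad():
        image_features = F.normalize(model.encode_image(images), dim=-1)
        logits = 100.0 * image_features @ classifier
    return logits.argmax(dim=-1)  # predicted class index per image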
def build_zero_shot_classifier_legacy(
model,
tokenizer,
classnames: Sequence[str],
templates: Sequence[Union[Callable, str]],
device: Union[str, torch.device] = 'cpu',
use_tqdm: bool = False,
):
""" Build zero-shot classifier weights by iterating over class names 1 by 1
Args:
model: CLIP model instance
tokenizer: CLIP tokenizer instance
classnames: A sequence of class (label) names
templates: A sequence of callables or format() friendly strings to produce templates per class name
device: Device to use.
use_tqdm: Enable TQDM progress bar.
"""
assert isinstance(templates, Sequence) and len(templates) > 0
assert isinstance(classnames, Sequence) and len(classnames) > 0
if use_tqdm:
import tqdm
iter_wrap = tqdm.tqdm
else:
iter_wrap = iter
use_format = isinstance(templates[0], str)
with torch.no_grad():
zeroshot_weights = []
for classname in iter_wrap(classnames):
texts = [template.format(classname) if use_format else template(classname) for template in templates]
texts = tokenizer(texts).to(device) # tokenize
class_embeddings = model.encode_text(texts)
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
class_embedding /= class_embedding.norm()
zeroshot_weights.append(class_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
return zeroshot_weights
OPENAI_IMAGENET_TEMPLATES = (
lambda c: f'a bad photo of a {c}.',
lambda c: f'a photo of many {c}.',
lambda c: f'a sculpture of a {c}.',
lambda c: f'a photo of the hard to see {c}.',
lambda c: f'a low resolution photo of the {c}.',
lambda c: f'a rendering of a {c}.',
lambda c: f'graffiti of a {c}.',
lambda c: f'a bad photo of the {c}.',
lambda c: f'a cropped photo of the {c}.',
lambda c: f'a tattoo of a {c}.',
lambda c: f'the embroidered {c}.',
lambda c: f'a photo of a hard to see {c}.',
lambda c: f'a bright photo of a {c}.',
lambda c: f'a photo of a clean {c}.',
lambda c: f'a photo of a dirty {c}.',
lambda c: f'a dark photo of the {c}.',
lambda c: f'a drawing of a {c}.',
lambda c: f'a photo of my {c}.',
lambda c: f'the plastic {c}.',
lambda c: f'a photo of the cool {c}.',
lambda c: f'a close-up photo of a {c}.',
lambda c: f'a black and white photo of the {c}.',
lambda c: f'a painting of the {c}.',
lambda c: f'a painting of a {c}.',
lambda c: f'a pixelated photo of the {c}.',
lambda c: f'a sculpture of the {c}.',
lambda c: f'a bright photo of the {c}.',
lambda c: f'a cropped photo of a {c}.',
lambda c: f'a plastic {c}.',
lambda c: f'a photo of the dirty {c}.',
lambda c: f'a jpeg corrupted photo of a {c}.',
lambda c: f'a blurry photo of the {c}.',
lambda c: f'a photo of the {c}.',
lambda c: f'a good photo of the {c}.',
lambda c: f'a rendering of the {c}.',
lambda c: f'a {c} in a video game.',
lambda c: f'a photo of one {c}.',
lambda c: f'a doodle of a {c}.',
lambda c: f'a close-up photo of the {c}.',
lambda c: f'a photo of a {c}.',
lambda c: f'the origami {c}.',
lambda c: f'the {c} in a video game.',
lambda c: f'a sketch of a {c}.',
lambda c: f'a doodle of the {c}.',
lambda c: f'a origami {c}.',
lambda c: f'a low resolution photo of a {c}.',
lambda c: f'the toy {c}.',
lambda c: f'a rendition of the {c}.',
lambda c: f'a photo of the clean {c}.',
lambda c: f'a photo of a large {c}.',
lambda c: f'a rendition of a {c}.',
lambda c: f'a photo of a nice {c}.',
lambda c: f'a photo of a weird {c}.',
lambda c: f'a blurry photo of a {c}.',
lambda c: f'a cartoon {c}.',
lambda c: f'art of a {c}.',
lambda c: f'a sketch of the {c}.',
lambda c: f'a embroidered {c}.',
lambda c: f'a pixelated photo of a {c}.',
lambda c: f'itap of the {c}.',
lambda c: f'a jpeg corrupted photo of the {c}.',
lambda c: f'a good photo of a {c}.',
lambda c: f'a plushie {c}.',
lambda c: f'a photo of the nice {c}.',
lambda c: f'a photo of the small {c}.',
lambda c: f'a photo of the weird {c}.',
lambda c: f'the cartoon {c}.',
lambda c: f'art of the {c}.',
lambda c: f'a drawing of the {c}.',
lambda c: f'a photo of the large {c}.',
lambda c: f'a black and white photo of a {c}.',
lambda c: f'the plushie {c}.',
lambda c: f'a dark photo of a {c}.',
lambda c: f'itap of a {c}.',
lambda c: f'graffiti of the {c}.',
lambda c: f'a toy {c}.',
lambda c: f'itap of my {c}.',
lambda c: f'a photo of a cool {c}.',
lambda c: f'a photo of a small {c}.',
lambda c: f'a tattoo of the {c}.',
)
# a much smaller subset of above prompts
# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
SIMPLE_IMAGENET_TEMPLATES = (
lambda c: f'itap of a {c}.',
lambda c: f'a bad photo of the {c}.',
lambda c: f'a origami {c}.',
lambda c: f'a photo of the large {c}.',
lambda c: f'a {c} in a video game.',
lambda c: f'art of the {c}.',
lambda c: f'a photo of the small {c}.',
)
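# Illustrative sketch: each template is a callable taking a class name, so expanding one
# class with the simple template set yields seven prompt strings.
def _example_expand_templates():
    prompts = [t('goldfish') for t in SIMPLE_IMAGENET_TEMPLATES]
    return len(prompts), prompts[0]  # (7, 'itap of a goldfish.')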
IMAGENET_CLASSNAMES = (
"tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
"stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
"indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
"kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
"smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
"tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
"box turtle", "banded gecko", "green iguana", "Carolina anole",
"desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
"Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
"American alligator", "triceratops", "worm snake", "ring-necked snake",
"eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
"vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
"green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
"sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
"barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
"tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
"quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
"coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
"red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
"koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
"snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
"fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
"isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
"great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
"bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
"pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
"Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
"Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
"Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
"English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
"Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
"Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
"Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
"Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
"Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
"Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
"Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
"Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
"Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
"Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
"English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
"English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
"Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
"Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
"Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
"Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
"French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
"Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
"Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
"Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
"Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
"red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
"kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
"Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
"cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
"meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
"dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
"cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
"lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
"monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
"starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
"hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
"zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
"ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
"gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
"black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
"gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
"langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
"howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
"ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
"giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
"sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
"accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
"amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
"backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
"baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
"wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
"bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
"beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
"ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
"bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
"breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
"high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
"can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
"car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
"CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
"storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
"church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
"coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
"candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
"cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
"rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
"dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
"drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
"electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
"feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
"folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
"freight car", "French horn", "frying pan", "fur coat", "garbage truck",
"gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
"gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
"hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
"handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
"holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
"horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
"T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
"ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
"lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
"music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
"mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
"matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
"microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
"mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
"moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
"mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
"neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
"odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
"oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
"parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
"pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
"picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
"pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
"plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
"pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
"printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
"quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
"recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
"remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
"rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
"sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
"CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
"shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
"shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
"slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
"solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
"space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
"stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
"stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
"submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
"mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
"table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
"thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
"toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
"tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
"triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
"umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
"velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
"waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
"washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
"hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
"wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
"comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
"plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
"bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
"cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
"artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
"lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
"hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
"red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
"geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
"bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
"rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
"earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"
)
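# Illustrative sketch (not part of the original file): the templates above are callables that map a
# class name to a prompt string, so a zero-shot prompt set is simply their cross product with
# IMAGENET_CLASSNAMES. Only names defined in this file are assumed:
#
#   prompts_per_class = [
#       [template(classname) for template in SIMPLE_IMAGENET_TEMPLATES]
#       for classname in IMAGENET_CLASSNAMES
#   ]
#   # prompts_per_class[0][0] == 'itap of a tench.'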
import ast
import json
import logging
import math
import os
import random
import sys
import braceexpand
from dataclasses import dataclass
from multiprocessing import Value
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
import webdataset as wds
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from webdataset.filters import _shuffle
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
try:
import horovod.torch as hvd
except ImportError:
hvd = None
class CsvDataset(Dataset):
def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None):
logging.debug(f'Loading csv data from {input_filename}.')
df = pd.read_csv(input_filename, sep=sep)
self.images = df[img_key].tolist()
self.captions = df[caption_key].tolist()
self.transforms = transforms
logging.debug('Done loading data.')
self.tokenize = tokenizer
def __len__(self):
return len(self.captions)
def __getitem__(self, idx):
images = self.transforms(Image.open(str(self.images[idx])))
texts = self.tokenize([str(self.captions[idx])])[0]
return images, texts
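# Example usage (illustrative sketch, not part of the original file). The file name, column names and
# the `preprocess_train` / `tokenizer` objects below are assumptions, not defined in this module:
#
#   dataset = CsvDataset(
#       'train.tsv', transforms=preprocess_train,
#       img_key='filepath', caption_key='title', sep='\t', tokenizer=tokenizer,
#   )
#   image_tensor, text_tokens = dataset[0]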
class SharedEpoch:
def __init__(self, epoch: int = 0):
self.shared_epoch = Value('i', epoch)
def set_value(self, epoch):
self.shared_epoch.value = epoch
def get_value(self):
return self.shared_epoch.value
@dataclass
class DataInfo:
dataloader: DataLoader
sampler: DistributedSampler = None
shared_epoch: SharedEpoch = None
def set_epoch(self, epoch):
if self.shared_epoch is not None:
self.shared_epoch.set_value(epoch)
if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch)
def expand_urls(urls, weights=None):
if weights is None:
expanded_urls = wds.shardlists.expand_urls(urls)
return expanded_urls, None
if isinstance(urls, str):
urllist = urls.split("::")
weights = weights.split('::')
assert len(weights) == len(urllist),\
f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match."
weights = [float(weight) for weight in weights]
all_urls, all_weights = [], []
for url, weight in zip(urllist, weights):
expanded_url = list(braceexpand.braceexpand(url))
expanded_weights = [weight for _ in expanded_url]
all_urls.extend(expanded_url)
all_weights.extend(expanded_weights)
return all_urls, all_weights
else:
all_urls = list(urls)
return all_urls, weights
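# Example (illustrative sketch, not part of the original file): expand_urls accepts brace-notation
# strings with optional '::'-separated components and matching weights. The shard paths are hypothetical:
#
#   urls, weights = expand_urls(
#       'data/cc-{0000..0009}.tar::data/laion-{0000..0004}.tar', weights='1.0::2.0')
#   # len(urls) == 15; the first 10 shards carry weight 1.0, the remaining 5 carry weight 2.0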
def get_dataset_size(shards):
shards_list, _ = expand_urls(shards)
dir_path = os.path.dirname(shards_list[0])
sizes_filename = os.path.join(dir_path, 'sizes.json')
len_filename = os.path.join(dir_path, '__len__')
if os.path.exists(sizes_filename):
sizes = json.load(open(sizes_filename, 'r'))
total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
elif os.path.exists(len_filename):
# FIXME this used to be eval(open(...)) but that seemed rather unsafe
total_size = ast.literal_eval(open(len_filename, 'r').read())
else:
total_size = None # num samples undefined
    # some common dataset sizes (at the time of the authors' last download)
# CC3M (train): 2905954
# CC12M: 10968539
# LAION-400M: 407332084
# LAION-2B (english): 2170337258
num_shards = len(shards_list)
return total_size, num_shards
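# Illustrative note (not part of the original file): get_dataset_size looks for an optional
# 'sizes.json' next to the shards, mapping shard basename -> sample count, e.g.
#   {"cc-0000.tar": 10000, "cc-0001.tar": 9987}
# or a '__len__' file containing a single integer literal; otherwise it returns (None, num_shards).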
def get_imagenet(args, preprocess_fns, split):
assert split in ["train", "val", "v2"]
is_train = split == "train"
preprocess_train, preprocess_val = preprocess_fns
if split == "v2":
from imagenetv2_pytorch import ImageNetV2Dataset
dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
else:
if is_train:
data_path = args.imagenet_train
preprocess_fn = preprocess_train
else:
data_path = args.imagenet_val
preprocess_fn = preprocess_val
assert data_path
dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
if is_train:
idxs = np.zeros(len(dataset.targets))
target_array = np.array(dataset.targets)
k = 50
for c in range(1000):
m = target_array == c
n = len(idxs[m])
arr = np.zeros(n)
arr[:k] = 1
np.random.shuffle(arr)
idxs[m] = arr
idxs = idxs.astype('int')
sampler = SubsetRandomSampler(np.where(idxs)[0])
else:
sampler = None
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.workers,
sampler=sampler,
)
return DataInfo(dataloader=dataloader, sampler=sampler)
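# Illustrative note (not part of the original file): for the ImageNet train split above, the loop keeps
# at most k=50 randomly chosen images per class (50 * 1000 = 50,000 samples) via SubsetRandomSampler,
# presumably to keep any pass over the train split cheap relative to the full ~1.28M images.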
def count_samples(dataloader):
os.environ["WDS_EPOCH"] = "0"
n_elements, n_batches = 0, 0
for images, texts in dataloader:
n_batches += 1
n_elements += len(images)
assert len(images) == len(texts)
return n_elements, n_batches
def filter_no_caption_or_no_image(sample):
has_caption = ('txt' in sample)
has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
return has_caption and has_image
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
return True
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
:param lcase: convert suffixes to lower case (Default value = True)
"""
current_sample = None
for filesample in data:
assert isinstance(filesample, dict)
fname, value = filesample["fname"], filesample["data"]
prefix, suffix = keys(fname)
if prefix is None:
continue
if lcase:
suffix = suffix.lower()
# FIXME webdataset version throws if suffix in current_sample, but we have a potential for
# this happening in the current LAION400m dataset if a tar ends with same prefix as the next
# begins, rare, but can happen since prefix aren't unique across tar files in that dataset
if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffixes is None or suffix in suffixes:
current_sample[suffix] = value
if valid_sample(current_sample):
yield current_sample
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
streams = url_opener(src, handler=handler)
files = tar_file_expander(streams, handler=handler)
samples = group_by_keys_nothrow(files, handler=handler)
return samples
def pytorch_worker_seed(increment=0):
"""get dataloader worker seed from pytorch"""
worker_info = get_worker_info()
if worker_info is not None:
# favour using the seed already created for pytorch dataloader workers if it exists
seed = worker_info.seed
if increment:
# space out seed increments so they can't overlap across workers in different iterations
seed += increment * max(1, worker_info.num_workers)
return seed
# fallback to wds rank based seed
return wds.utils.pytorch_worker_seed()
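# Worked example (illustrative, not part of the original file): with 4 dataloader workers whose base
# seeds are s, s+1, s+2, s+3, calling pytorch_worker_seed(increment=epoch) adds epoch * 4 to each, so
# the seed used by any worker in one epoch cannot collide with another worker's seed in a different epoch.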
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
class detshuffle2(wds.PipelineStage):
def __init__(
self,
bufsize=1000,
initial=100,
seed=0,
epoch=-1,
):
self.bufsize = bufsize
self.initial = initial
self.seed = seed
self.epoch = epoch
def run(self, src):
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation, as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
rng = random.Random()
if self.seed < 0:
            # If seed is negative, we use the worker's seed; this will be different across all nodes/workers
            seed = pytorch_worker_seed(epoch)
        else:
            # This seed is deterministic AND the same across all nodes/workers in each epoch
seed = self.seed + epoch
rng.seed(seed)
return _shuffle(src, self.bufsize, self.initial, rng)
class ResampledShards2(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(
self,
urls,
weights=None,
nshards=sys.maxsize,
worker_seed=None,
deterministic=False,
epoch=-1,
):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls, weights = expand_urls(urls, weights)
self.urls = urls
self.weights = weights
if self.weights is not None:
assert len(self.urls) == len(self.weights),\
f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
assert isinstance(self.urls[0], str)
self.nshards = nshards
self.rng = random.Random()
self.worker_seed = worker_seed
self.deterministic = deterministic
self.epoch = epoch
def __iter__(self):
"""Return an iterator over the shards."""
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation, as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
if self.deterministic:
# reset seed w/ epoch if deterministic
if self.worker_seed is None:
                # pytorch worker seed should be deterministic due to being init by args.seed + rank + worker id
seed = pytorch_worker_seed(epoch)
else:
seed = self.worker_seed() + epoch
self.rng.seed(seed)
for _ in range(self.nshards):
if self.weights is None:
yield dict(url=self.rng.choice(self.urls))
else:
yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])
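# Example usage (illustrative sketch, not part of the original file); the shard pattern is hypothetical:
#
#   shards = ResampledShards2(
#       'data/laion-{0000..0099}.tar', deterministic=True, epoch=SharedEpoch(0))
#   # each dataloader worker then draws shard urls i.i.d. with replacement, reseeding per epoch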
def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
input_shards = args.train_data if is_train else args.val_data
assert input_shards is not None
resampled = getattr(args, 'dataset_resampled', False) and is_train
num_shards = None
if is_train:
if args.train_num_samples is not None:
num_samples = args.train_num_samples
else:
num_samples, num_shards = get_dataset_size(input_shards)
if not num_samples:
raise RuntimeError(
'Currently, the number of dataset samples must be specified for the training dataset. '
'Please specify it via `--train-num-samples` if no dataset length info is present.')
else:
# Eval will just exhaust the iterator if the size is not specified.
num_samples = args.val_num_samples or 0
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc
if is_train and args.train_data_upsampling_factors is not None:
assert resampled, "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
if resampled:
pipeline = [ResampledShards2(
input_shards,
weights=args.train_data_upsampling_factors,
deterministic=True,
epoch=shared_epoch,
)]
else:
pipeline = [wds.SimpleShardList(input_shards)]
# at this point we have an iterator over all the shards
if is_train:
if not resampled:
pipeline.extend([
detshuffle2(
bufsize=_SHARD_SHUFFLE_SIZE,
initial=_SHARD_SHUFFLE_INITIAL,
seed=args.seed,
epoch=shared_epoch,
),
wds.split_by_node,
wds.split_by_worker,
])
pipeline.extend([
# at this point, we have an iterator over the shards assigned to each worker at each node
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue),
wds.shuffle(
bufsize=_SAMPLE_SHUFFLE_SIZE,
initial=_SAMPLE_SHUFFLE_INITIAL,
),
])
else:
pipeline.extend([
wds.split_by_worker,
# at this point, we have an iterator over the shards assigned to each worker
wds.tarfile_to_samples(handler=log_and_continue),
])
pipeline.extend([
wds.select(filter_no_caption_or_no_image),
wds.decode("pilrgb", handler=log_and_continue),
wds.rename(image="jpg;png;jpeg;webp", text="txt"),
wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
wds.to_tuple("image", "text"),
wds.batched(args.batch_size, partial=not is_train)
])
dataset = wds.DataPipeline(*pipeline)
if is_train:
if not resampled:
num_shards = num_shards or len(expand_urls(input_shards)[0])
assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
# roll over and repeat a few samples to get same number of full batches on each node
round_fn = math.floor if floor else math.ceil
global_batch_size = args.batch_size * args.world_size
num_batches = round_fn(num_samples / global_batch_size)
num_workers = max(1, args.workers)
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
num_batches = num_worker_batches * num_workers
num_samples = num_batches * global_batch_size
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this
else:
# last batches are partial, eval is done on single (master) node
num_batches = math.ceil(num_samples / args.batch_size)
dataloader = wds.WebLoader(
dataset,
batch_size=None,
shuffle=False,
num_workers=args.workers,
persistent_workers=args.workers > 0,
)
# FIXME not clear which approach is better, with_epoch before vs after dataloader?
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169
# if is_train:
# # roll over and repeat a few samples to get same number of full batches on each node
# global_batch_size = args.batch_size * args.world_size
# num_batches = math.ceil(num_samples / global_batch_size)
# num_workers = max(1, args.workers)
# num_batches = math.ceil(num_batches / num_workers) * num_workers
# num_samples = num_batches * global_batch_size
# dataloader = dataloader.with_epoch(num_batches)
# else:
# # last batches are partial, eval is done on single (master) node
# num_batches = math.ceil(num_samples / args.batch_size)
# add meta-data to dataloader instance for convenience
dataloader.num_batches = num_batches
dataloader.num_samples = num_samples
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
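# Worked example (illustrative, not part of the original file) of the train-side rounding above,
# assuming num_samples=100_000, batch_size=64, world_size=4, workers=4 and floor=False:
#   global_batch_size  = 64 * 4 = 256
#   num_batches        = ceil(100_000 / 256) = 391
#   num_worker_batches = ceil(391 / 4) = 98   ->  num_batches = 98 * 4 = 392
#   num_samples        = 392 * 256 = 100_352  (a few samples are repeated to fill full batches)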
def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
input_filename = args.train_data if is_train else args.val_data
assert input_filename
dataset = CsvDataset(
input_filename,
preprocess_fn,
img_key=args.csv_img_key,
caption_key=args.csv_caption_key,
sep=args.csv_separator,
tokenizer=tokenizer
)
num_samples = len(dataset)
sampler = DistributedSampler(dataset) if args.distributed and is_train else None
shuffle = is_train and sampler is None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=shuffle,
num_workers=args.workers,
pin_memory=True,
sampler=sampler,
drop_last=is_train,
)
dataloader.num_samples = num_samples
dataloader.num_batches = len(dataloader)
return DataInfo(dataloader, sampler)
class SyntheticDataset(Dataset):
def __init__(
self,
transform=None,
image_size=(224, 224),
caption="Dummy caption",
dataset_size=100,
tokenizer=None,
):
self.transform = transform
self.image_size = image_size
self.caption = caption
self.image = Image.new('RGB', image_size)
self.dataset_size = dataset_size
self.preprocess_txt = lambda text: tokenizer(text)[0]
def __len__(self):
return self.dataset_size
    def __getitem__(self, idx):
        # start from the constant dummy image so `image` is defined even when no transform is set
        image = self.image
        if self.transform is not None:
            image = self.transform(image)
        return image, self.preprocess_txt(self.caption)
def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
image_size = preprocess_fn.transforms[0].size
dataset = SyntheticDataset(
transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer)
num_samples = len(dataset)
sampler = DistributedSampler(dataset) if args.distributed and is_train else None
shuffle = is_train and sampler is None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=shuffle,
num_workers=args.workers,
pin_memory=True,
sampler=sampler,
drop_last=is_train,
)
dataloader.num_samples = num_samples
dataloader.num_batches = len(dataloader)
return DataInfo(dataloader, sampler)
def get_dataset_fn(data_path, dataset_type):
if dataset_type == "webdataset":
return get_wds_dataset
elif dataset_type == "csv":
return get_csv_dataset
elif dataset_type == "synthetic":
return get_synthetic_dataset
elif dataset_type == "auto":
ext = data_path.split('.')[-1]
if ext in ['csv', 'tsv']:
return get_csv_dataset
elif ext in ['tar']:
return get_wds_dataset
else:
raise ValueError(
f"Tried to figure out dataset type, but failed for extension {ext}.")
else:
raise ValueError(f"Unsupported dataset type: {dataset_type}")
def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
preprocess_train, preprocess_val = preprocess_fns
data = {}
if args.train_data or args.dataset_type == "synthetic":
data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
if args.val_data:
data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
args, preprocess_val, is_train=False, tokenizer=tokenizer)
if args.imagenet_val is not None:
data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
if args.imagenet_v2 is not None:
data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")
return data
import os
import torch
import torch.distributed as dist
try:
import horovod.torch as hvd
except ImportError:
hvd = None
def is_global_master(args):
return args.rank == 0
def is_local_master(args):
return args.local_rank == 0
def is_master(args, local=False):
return is_local_master(args) if local else is_global_master(args)
def is_using_horovod():
# NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
# Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
pmi_vars = ["PMI_RANK", "PMI_SIZE"]
if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
return True
else:
return False
def is_using_distributed():
if 'WORLD_SIZE' in os.environ:
return int(os.environ['WORLD_SIZE']) > 1
if 'SLURM_NTASKS' in os.environ:
return int(os.environ['SLURM_NTASKS']) > 1
return False
def world_info_from_env():
local_rank = 0
for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
if v in os.environ:
local_rank = int(os.environ[v])
break
global_rank = 0
for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
if v in os.environ:
global_rank = int(os.environ[v])
break
world_size = 1
for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
if v in os.environ:
world_size = int(os.environ[v])
break
return local_rank, global_rank, world_size
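# Illustrative note (not part of the original file): the loops above prefer torchrun-style variables and
# fall back to MPI/SLURM ones, e.g. torchrun sets LOCAL_RANK / RANK / WORLD_SIZE, while under `srun`
# only SLURM_LOCALID / SLURM_PROCID / SLURM_NTASKS may be present.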
def init_distributed_device(args):
# Distributed training = training on more than one GPU.
# Works in both single and multi-node scenarios.
args.distributed = False
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
if args.horovod:
assert hvd is not None, "Horovod is not installed"
hvd.init()
args.local_rank = int(hvd.local_rank())
args.rank = hvd.rank()
args.world_size = hvd.size()
args.distributed = True
os.environ['LOCAL_RANK'] = str(args.local_rank)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
elif is_using_distributed():
if 'SLURM_PROCID' in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ['LOCAL_RANK'] = str(args.local_rank)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
if torch.cuda.is_available():
if args.distributed and not args.no_set_device_rank:
device = 'cuda:%d' % args.local_rank
else:
device = 'cuda:0'
torch.cuda.set_device(device)
else:
device = 'cpu'
args.device = device
device = torch.device(device)
return device
def broadcast_object(args, obj, src=0):
# broadcast a pickle-able python object from rank-0 to all ranks
if args.horovod:
return hvd.broadcast_object(obj, root_rank=src)
else:
if args.rank == src:
objects = [obj]
else:
objects = [None]
dist.broadcast_object_list(objects, src=src)
return objects[0]
def all_gather_object(args, obj, dst=0):
# gather a pickle-able python object across all ranks
if args.horovod:
return hvd.allgather_object(obj)
else:
objects = [None for _ in range(args.world_size)]
dist.all_gather_object(objects, obj)
return objects
import logging
import os
import multiprocessing
import subprocess
import time
import fsspec
import torch
from tqdm import tqdm
def remote_sync_s3(local_dir, remote_dir):
# skip epoch_latest which can change during sync.
result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if result.returncode != 0:
logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
return False
logging.info(f"Successfully synced with S3 bucket")
return True
def remote_sync_fsspec(local_dir, remote_dir):
# FIXME currently this is slow and not recommended. Look into speeding up.
a = fsspec.get_mapper(local_dir)
b = fsspec.get_mapper(remote_dir)
for k in a:
# skip epoch_latest which can change during sync.
if 'epoch_latest.pt' in k:
continue
logging.info(f'Attempting to sync {k}')
if k in b and len(a[k]) == len(b[k]):
logging.debug(f'Skipping remote sync for {k}.')
continue
        try:
            b[k] = a[k]
            logging.info(f'Successful sync for {k}.')
        except Exception as e:
            logging.error(f'Error during remote sync for {k}: {e}')
            return False
return True
def remote_sync(local_dir, remote_dir, protocol):
logging.info('Starting remote sync.')
if protocol == 's3':
return remote_sync_s3(local_dir, remote_dir)
elif protocol == 'fsspec':
return remote_sync_fsspec(local_dir, remote_dir)
else:
logging.error('Remote protocol not known')
return False
def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
while True:
time.sleep(sync_every)
remote_sync(local_dir, remote_dir, protocol)
def start_sync_process(sync_every, local_dir, remote_dir, protocol):
p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))
return p
# Note: we are not currently using this save function.
def pt_save(pt_obj, file_path):
    of = fsspec.open(file_path, "wb")
    with of as f:
        # write through the opened fsspec file handle rather than the raw path
        torch.save(pt_obj, f)
def pt_load(file_path, map_location=None):
if file_path.startswith('s3'):
logging.info('Loading remote checkpoint, which may take a bit.')
of = fsspec.open(file_path, "rb")
with of as f:
out = torch.load(f, map_location=map_location)
return out
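# Example usage (illustrative, not part of the original file); the checkpoint path is hypothetical:
#
#   checkpoint = pt_load('s3://my-bucket/logs/run1/checkpoints/epoch_32.pt', map_location='cpu')
#   # local paths work too, since fsspec dispatches on the URL scheme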
def check_exists(file_path):
try:
with fsspec.open(file_path):
pass
except FileNotFoundError:
return False
return True