Commit d5878167 authored by mashun1's avatar mashun1
Browse files

llava-next

parents
Pipeline #2589 failed with stages
in 0 seconds
import hashlib
import os
import urllib
import warnings
from typing import Dict, Union
from tqdm import tqdm
try:
from huggingface_hub import hf_hub_download
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
def _pcfg(url="", hf_hub="", filename="", mean=None, std=None):
    """Build one pretrained-checkpoint config entry.

    The returned dict carries exactly the keys ``url``, ``hf_hub``, ``mean``
    and ``std``. ``filename`` is accepted but is not stored in the entry.
    """
    entry = {"url": url, "hf_hub": hf_hub}
    entry["mean"] = mean
    entry["std"] = std
    return entry
# Checkpoint sources for the B/32 tower, keyed by pretrain tag. Registered in
# _PRETRAINED under "OpenaiCLIP-B-32" and "OpenCLIP-B-32".
# The laion400m_* entries point at the same quickgelu files as _VITB32_quickgelu.
# An hf_hub value ending in "/" names only the repo; download_pretrained then
# uses the default filename of download_pretrained_from_hf.
_VITB32 = dict(
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
laion2b_e16=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"),
)
# Checkpoint sources for the B/32 quickgelu variant, keyed by pretrain tag.
# Registered in _PRETRAINED under "OpenaiCLIP-B-32-quickgelu" and
# "OpenCLIP-B-32-quickgelu".
_VITB32_quickgelu = dict(
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
)
# Checkpoint sources for the B/16 tower, keyed by pretrain tag. Registered in
# _PRETRAINED under "OpenaiCLIP-B-16" and "OpenCLIP-B-16".
_VITB16 = dict(
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"),
)
# EVA02 B/16 checkpoints on the Hugging Face hub ("org/repo/filename" form).
# Registered in _PRETRAINED under "EVA02-B-16" and "EVA02-CLIP-B-16".
# Tags "eva"/"eva02" share one file, as do "eva_clip"/"eva02_clip".
_EVAB16 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
)
# Checkpoint sources registered in _PRETRAINED under "OpenCLIP-B-16-plus-240".
_VITB16_PLUS_240 = dict(
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)
# Checkpoint sources for the L/14 tower, keyed by pretrain tag. Registered in
# _PRETRAINED under "OpenaiCLIP-L-14" and "OpenCLIP-L-14".
# laion2b_s32b_b82k is the only entry in this table that sets explicit
# mean/std values (0.5 per channel).
_VITL14 = dict(
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)
# EVA02 L/14 checkpoints on the Hugging Face hub. Registered in _PRETRAINED
# under "EVA02-L-14" and "EVA02-CLIP-L-14".
_EVAL14 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
)
# Checkpoint source registered in _PRETRAINED under "OpenaiCLIP-L-14-336".
_VITL14_336 = dict(
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)
# EVA02 L/14 336 checkpoints on the Hugging Face hub. Registered in
# _PRETRAINED under "EVA02-CLIP-L-14-336".
_EVAL14_336 = dict(
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
eva_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
eva02_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
)
# Checkpoint source registered in _PRETRAINED under "OpenCLIP-H-14".
# The trailing "/" means no filename is given in the hf_hub spec.
_VITH14 = dict(
laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"),
)
# Checkpoint sources registered in _PRETRAINED under "OpenCLIP-g-14".
# The trailing "/" means no filename is given in the hf_hub spec.
_VITg14 = dict(
laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"),
laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"),
)
# EVA01 g/14 checkpoints on the Hugging Face hub. Registered in _PRETRAINED
# under "EVA01-CLIP-g-14".
# NOTE(review): the "eva" entry gives no filename, so download_pretrained will
# request the default "open_clip_pytorch_model.bin" from QuanSun/EVA-CLIP,
# unlike "eva01" which names EVA01_g_psz14.pt. Confirm this is intended.
_EVAg14 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
)
# EVA01 g/14 "plus" checkpoints on the Hugging Face hub. Registered in
# _PRETRAINED under "EVA01-CLIP-g-14-plus".
# NOTE(review): the "eva" entry gives no filename, so download_pretrained will
# request the default "open_clip_pytorch_model.bin" from QuanSun/EVA-CLIP.
# Confirm this is intended.
_EVAg14_PLUS = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
)
# Checkpoint source registered in _PRETRAINED under "OpenCLIP-bigG-14".
# The trailing "/" means no filename is given in the hf_hub spec.
_VITbigG14 = dict(
laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"),
)
# EVA02 bigE/14 checkpoints on the Hugging Face hub. Registered in
# _PRETRAINED under "EVA02-CLIP-bigE-14".
_EVAbigE14 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
)
# EVA02 bigE/14 "plus" checkpoints on the Hugging Face hub. Registered in
# _PRETRAINED under "EVA02-CLIP-bigE-14-plus".
_EVAbigE14_PLUS = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
)
# EVA-CLIP-8B checkpoints on the Hugging Face hub. Registered in _PRETRAINED
# under "EVA-CLIP-8B".
_EVA_8B = dict(
eva=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_8B_psz14.bin"),
eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_CLIP_8B_psz14_s9B.pt"),
)
# EVA-CLIP-8B-448 checkpoint on the Hugging Face hub. Registered in
# _PRETRAINED under both "EVA-CLIP-8B-448" and "EVA-CLIP-8B-plus".
_EVA_8B_PLUS = dict(
eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B-448/EVA_CLIP_8B_psz14_plus_s0.6B.pt"),
)
# Registry mapping a model name to its table of {pretrain tag: config entry}.
# Model names are looked up exactly as given (case-sensitive); only the tag is
# normalized through _clean_tag by the lookup helpers in this module.
# Several names share one table object. The commented-out "ViT-*" keys are
# names that are not registered here.
_PRETRAINED = {
# "ViT-B-32": _VITB32,
"OpenaiCLIP-B-32": _VITB32,
"OpenCLIP-B-32": _VITB32,
# "ViT-B-32-quickgelu": _VITB32_quickgelu,
"OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
"OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
# "ViT-B-16": _VITB16,
"OpenaiCLIP-B-16": _VITB16,
"OpenCLIP-B-16": _VITB16,
"EVA02-B-16": _EVAB16,
"EVA02-CLIP-B-16": _EVAB16,
# "ViT-B-16-plus-240": _VITB16_PLUS_240,
"OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
# "ViT-L-14": _VITL14,
"OpenaiCLIP-L-14": _VITL14,
"OpenCLIP-L-14": _VITL14,
"EVA02-L-14": _EVAL14,
"EVA02-CLIP-L-14": _EVAL14,
# "ViT-L-14-336": _VITL14_336,
"OpenaiCLIP-L-14-336": _VITL14_336,
"EVA02-CLIP-L-14-336": _EVAL14_336,
# "ViT-H-14": _VITH14,
# "ViT-g-14": _VITg14,
"OpenCLIP-H-14": _VITH14,
"OpenCLIP-g-14": _VITg14,
"EVA01-CLIP-g-14": _EVAg14,
"EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
# "ViT-bigG-14": _VITbigG14,
"OpenCLIP-bigG-14": _VITbigG14,
"EVA02-CLIP-bigE-14": _EVAbigE14,
"EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
"EVA-CLIP-8B": _EVA_8B,
"EVA-CLIP-8B-448": _EVA_8B_PLUS,
"EVA-CLIP-8B-plus": _EVA_8B_PLUS,
}
def _clean_tag(tag: str):
    """Normalize a pretrain tag: lowercase it and turn "-" into "_"."""
    lowered = tag.lower()
    return lowered.replace("-", "_")
def list_pretrained(as_str: bool = False):
    """Return every registered (model, pretrain tag) combination.

    Each item is a ``(model_name, pretrain_tag)`` tuple by default, or a
    ``'name:tag'`` string when ``as_str`` is true. Items follow the order of
    ``_PRETRAINED`` and of each model's tag table.
    """
    combos = []
    for model_name, tag_table in _PRETRAINED.items():
        for tag in tag_table.keys():
            if as_str:
                combos.append(":".join([model_name, tag]))
            else:
                combos.append((model_name, tag))
    return combos
def list_pretrained_models_by_tag(tag: str):
    """Return the names of all models that have the given pretrain tag.

    The tag is normalized with ``_clean_tag`` before the lookup.
    """
    wanted = _clean_tag(tag)
    return [name for name, tag_table in _PRETRAINED.items() if wanted in tag_table]
def list_pretrained_tags_by_model(model: str):
    """Return the pretrain tags registered for ``model`` (empty list if unknown)."""
    if model not in _PRETRAINED:
        return []
    return list(_PRETRAINED[model].keys())
def is_pretrained_cfg(model: str, tag: str):
    """Return True when ``model`` is registered and has the (normalized) ``tag``."""
    return model in _PRETRAINED and _clean_tag(tag) in _PRETRAINED[model]
def get_pretrained_cfg(model: str, tag: str):
    """Return the config entry for ``model``/``tag``, or ``{}`` if either is unknown.

    The tag is normalized with ``_clean_tag``; the model name is used as given.
    """
    tag_table = _PRETRAINED.get(model)
    if tag_table is None:
        return {}
    return tag_table.get(_clean_tag(tag), {})
def get_pretrained_url(model: str, tag: str):
    """Return the ``url`` field for ``model``/``tag``, or ``""`` when there is none."""
    normalized_tag = _clean_tag(tag)
    entry = get_pretrained_cfg(model, normalized_tag)
    return entry.get("url", "")
def _sha256_hexdigest(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of the file at ``path``, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_pretrained_from_url(
    url: str,
    cache_dir: Union[str, None] = None,
):
    """Download ``url`` into ``cache_dir`` (once) and return the local file path.

    Args:
        url: Location of the checkpoint. For URLs containing "openaipublic"
            the second-to-last path component is used as the expected SHA256;
            for URLs containing "mlfoundations" the text after the last "-"
            in the file stem is used. Other URLs are not verified.
        cache_dir: Target directory; defaults to ``~/.cache/clip``.

    Returns:
        Path of the cached file. An existing file is reused when no checksum
        is expected or when its digest starts with the expected value.

    Raises:
        RuntimeError: If the target path exists but is not a regular file, or
            if the freshly downloaded file fails the checksum.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if "openaipublic" in url:
        expected_sha256 = url.split("/")[-2]
    elif "mlfoundations" in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ""

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if not expected_sha256:
            return download_target
        # The expected value may be a prefix of the digest, hence startswith.
        if _sha256_hexdigest(download_target).startswith(expected_sha256):
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Content-Length is optional; tqdm accepts total=None for an unknown size.
        content_length = source.headers.get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit="iB", unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and not _sha256_hexdigest(download_target).startswith(expected_sha256):
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
def has_hf_hub(necessary=False):
    """Report whether ``huggingface_hub`` was importable.

    Raises RuntimeError when it was not importable and ``necessary`` is true;
    otherwise returns the module-level availability flag.
    """
    if necessary and not _has_hf_hub:
        # the hub package is required to continue but is not installed
        raise RuntimeError("Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.")
    return _has_hf_hub
def download_pretrained_from_hf(
    model_id: str,
    filename: str = "open_clip_pytorch_model.bin",
    revision=None,
    cache_dir: Union[str, None] = None,
):
    """Fetch ``filename`` from hub repo ``model_id`` and return the cached path.

    Raises RuntimeError (via ``has_hf_hub``) when ``huggingface_hub`` is not
    installed.
    """
    has_hf_hub(True)
    return hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
def download_pretrained(
    cfg: Dict,
    force_hf_hub: bool = False,
    cache_dir: Union[str, None] = None,
):
    """Download the checkpoint described by ``cfg`` and return its local path.

    The ``url`` field is preferred unless ``force_hf_hub`` is set and an
    ``hf_hub`` field is present. Returns ``""`` when ``cfg`` is empty or names
    no source.
    """
    if not cfg:
        return ""

    direct_url = cfg.get("url", "")
    hub_spec = cfg.get("hf_hub", "")

    # The hub wins over an existing url only when explicitly forced.
    prefer_hub = bool(hub_spec and force_hf_hub)
    if direct_url and not prefer_hub:
        return download_pretrained_from_url(direct_url, cache_dir=cache_dir)

    if not hub_spec:
        return ""

    has_hf_hub(True)
    # hf_hub entries combine model id and filename as 'org/model_name/filename.pt'.
    # A trailing slash ('org/model_name/') leaves the filename empty, which selects
    # the default filename of download_pretrained_from_hf.
    model_id, filename = os.path.split(hub_spec)
    if filename:
        return download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
    return download_pretrained_from_hf(model_id, cache_dir=cache_dir)
from math import pi
import torch
from torch import nn
from einops import rearrange, repeat
import logging
def broadcat(tensors, dim=-1):
    """Concatenate ``tensors`` along ``dim`` after broadcasting all other axes.

    All tensors must have the same number of dimensions, and every axis other
    than ``dim`` may take at most two distinct sizes across the inputs. Each
    tensor is expanded to the largest size on those axes and the results are
    concatenated along ``dim``.
    """
    ranks = {len(tensor.shape) for tensor in tensors}
    assert len(ranks) == 1, "tensors must all have the same number of dimensions"
    rank = next(iter(ranks))
    if dim < 0:
        dim = dim + rank

    # sizes_by_axis[a] holds the size of axis `a` for every input, in input order.
    sizes_by_axis = list(zip(*(list(tensor.shape) for tensor in tensors)))
    for axis, sizes in enumerate(sizes_by_axis):
        if axis != dim:
            assert len(set(sizes)) <= 2, "invalid dimensions for broadcastable concatentation"

    concat_sizes = sizes_by_axis[dim]
    broadcast_sizes = [max(sizes) for sizes in sizes_by_axis]

    expanded = []
    for index, tensor in enumerate(tensors):
        target_shape = list(broadcast_sizes)
        # The concatenation axis keeps each tensor's own size.
        target_shape[dim] = concat_sizes[index]
        expanded.append(tensor.expand(*target_shape))
    return torch.cat(expanded, dim=dim)
def rotate_half(x):
    """Map each consecutive pair ``(a, b)`` on the last axis to ``(-b, a)``.

    The last axis is viewed as pairs, each pair is rotated, and the result is
    flattened back to the original layout.
    """
    paired = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = paired.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")
class VisionRotaryEmbedding(nn.Module):
    """Rotary position embedding over a 2-D grid with precomputed cos/sin tables.

    The tables are registered as the buffers ``freqs_cos`` and ``freqs_sin``
    with shape ``(ft_seq_len, ft_seq_len, D)``, where ``D`` is the
    concatenation of the per-axis frequency features.

    Args:
        dim: Feature size used to derive the base frequencies.
        pt_seq_len: Grid side length the positions are scaled to.
        ft_seq_len: Grid side length actually used; defaults to ``pt_seq_len``.
        custom_freqs: Optional frequencies used as-is instead of ``freqs_for``.
        freqs_for: One of ``"lang"``, ``"pixel"`` or ``"constant"``.
        theta: Base for the ``"lang"`` frequencies.
        max_freq: Upper bound parameter for the ``"pixel"`` frequencies.
        num_freqs: Number of frequencies for ``"constant"``.

    Raises:
        ValueError: If ``freqs_for`` is not a known modality.
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for="lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        # Compare against None: the truth value of a multi-element tensor raises.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Positions 0..ft_seq_len-1 rescaled into the pt_seq_len range.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Outer product of positions and frequencies; each value is repeated
        # twice so it lines up with the pairs rotated by rotate_half.
        freqs_h = torch.einsum("..., f -> ... f", t, freqs)
        freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)

        freqs_w = torch.einsum("..., f -> ... f", t, freqs)
        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)

        # Combine the row and column tables into one (H, W, D) table.
        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")

    def forward(self, t, start_index=0):
        """Rotate the slice of ``t``'s last axis starting at ``start_index``.

        Features before and after the rotated slice pass through unchanged.
        """
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim=-1)
class VisionRotaryEmbeddingFast(nn.Module):
    """Rotary position embedding with flattened, precomputed cos/sin tables.

    The buffers ``freqs_cos`` and ``freqs_sin`` have shape
    ``(ft_seq_len * ft_seq_len, D)``: the 2-D grid is flattened to one row per
    position.

    Args:
        dim: Feature size used to derive the base frequencies.
        pt_seq_len: Grid side length the positions are scaled to.
        ft_seq_len: Grid side length actually used; defaults to ``pt_seq_len``.
        custom_freqs: Optional frequencies used as-is instead of ``freqs_for``.
        freqs_for: One of ``"lang"``, ``"pixel"`` or ``"constant"``.
        theta: Base for the ``"lang"`` frequencies.
        max_freq: Upper bound parameter for the ``"pixel"`` frequencies.
        num_freqs: Number of frequencies for ``"constant"``.
        patch_dropout: Stored on the module as ``self.patch_dropout``.

    Raises:
        ValueError: If ``freqs_for`` is not a known modality.
    """

    def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0):
        super().__init__()
        # Compare against None: the truth value of a multi-element tensor raises.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Positions 0..ft_seq_len-1 rescaled into the pt_seq_len range.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum("..., f -> ... f", t, freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        # Row and column tables combined into (H, W, D), then flattened to (H*W, D).
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.patch_dropout = patch_dropout

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")

    def forward(self, t, patch_indices_keep=None):
        """Apply the rotation to ``t``.

        When ``patch_indices_keep`` is given, only the table rows for the kept
        positions of each batch element are used.
        NOTE(review): the indexing assumes ``t`` is laid out as
        (batch, heads, tokens, dim) - confirm against the caller.
        """
        if patch_indices_keep is None:
            return t * self.freqs_cos + rotate_half(t) * self.freqs_sin

        batch = t.size()[0]
        batch_indices = torch.arange(batch)
        batch_indices = batch_indices[..., None]

        # Tile the table over the first two axes of t so it can be indexed per sample.
        freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
        freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])

        freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
        freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
        freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
        freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")

        return t * freqs_cos + rotate_half(t) * freqs_sin
""" timm model adapter
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
try:
import timm
from timm.models.layers import Mlp, to_2tuple
try:
# old timm imports < 0.8.1
from timm.models.layers.attention_pool2d import RotAttentionPool2d
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
except ImportError:
# new timm imports >= 0.8.1
from timm.layers import RotAttentionPool2d
from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
timm = None
from .utils import freeze_batch_norm_2d
class TimmModel(nn.Module):
    """Adapter that wraps a timm model as a trunk followed by a projection head.

    # FIXME this adapter is a work in progress, may change in ways that break weight compat
    """

    def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        self.image_size = to_2tuple(image_size)
        self.trunk = timm.create_model(model_name, pretrained=pretrained)

        feat_size = self.trunk.default_cfg.get("pool_size", None)
        has_spatial_features = bool(feat_size)
        uses_attn_pool = pool in ("abs_attn", "rot_attn")

        if uses_attn_pool:
            # attention pooling needs 2-D features; drop both classifier and default pool
            assert has_spatial_features
            self.trunk.reset_classifier(0, global_pool="")
        else:
            # override the global pool only if a pool is configured, else keep the network default
            pool_kwargs = {"global_pool": pool} if pool else {}
            self.trunk.reset_classifier(0, **pool_kwargs)

        channels = self.trunk.num_features
        layers = OrderedDict()

        if pool == "abs_attn":
            layers["pool"] = AbsAttentionPool2d(channels, feat_size=feat_size, out_features=embed_dim)
            channels = embed_dim
        elif pool == "rot_attn":
            layers["pool"] = RotAttentionPool2d(channels, out_features=embed_dim)
            channels = embed_dim
        else:
            assert proj, "projection layer needed if non-attention pooling is used."

        # NOTE attention pool ends with a projection layer, so proj should usually be '' with such pooling
        if proj == "linear":
            layers["drop"] = nn.Dropout(drop)
            layers["proj"] = nn.Linear(channels, embed_dim, bias=proj_bias)
        elif proj == "mlp":
            layers["mlp"] = Mlp(channels, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))

        self.head = nn.Sequential(layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze trunk parameters.

        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
            freeze_bn_stats (bool): also pass the frozen part to freeze_batch_norm_2d
        """
        if not unlocked_groups:
            # freeze the whole trunk
            for parameter in self.trunk.parameters():
                parameter.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
            return

        # NOTE: partial freeze requires latest timm (master) branch and is subject to change
        try:
            # FIXME import here until API stable and in an official release
            from timm.models.helpers import group_parameters, group_modules
        except ImportError:
            raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`")

        matcher = self.trunk.group_matcher()
        grouped_params = group_parameters(self.trunk, matcher)
        last_frozen_group = max(grouped_params.keys()) - unlocked_groups

        for group_idx in range(last_frozen_group + 1):
            for param_name in grouped_params[group_idx]:
                self.trunk.get_parameter(param_name).requires_grad = False

        if freeze_bn_stats:
            module_groups = group_modules(self.trunk, matcher, reverse=True)
            frozen_modules = {name for name, group_id in module_groups.items() if group_id <= last_frozen_group}
            freeze_batch_norm_2d(self.trunk, frozen_modules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Forward the request to the trunk; log a warning if the trunk rejects it."""
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception:
            logging.warning("grad checkpointing not supported for this timm image tower, continuing without...")

    def forward(self, x):
        """Run the trunk and then the head."""
        return self.head(self.trunk(x))
""" CLIP tokenizer
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache
from typing import Union, List
import ftfy
import regex as re
import torch
# https://stackoverflow.com/q/62691279
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@lru_cache()
def default_bpe():
    """Return the path of ``bpe_simple_vocab_16e6.txt.gz`` next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping every byte value (0-255) to a one-character string.

    The reversible BPE codes operate on unicode strings, so each byte needs a
    printable stand-in. Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ"
    map to their own character; every other byte is assigned the next unused
    code point starting at 256, which keeps whitespace and control characters
    out of the BPE input.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = list(printable)
    code_points = list(printable)
    self_mapped = set(printable)

    shift = 0
    for byte in range(2**8):
        if byte in self_mapped:
            continue
        byte_values.append(byte)
        code_points.append(2**8 + shift)
        shift += 1

    return {byte: chr(code_point) for byte, code_point in zip(byte_values, code_points)}
def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    ``word`` is a non-empty sequence of symbols (symbols being variable-length
    strings); an empty sequence raises IndexError.
    """
    followers = word[1:]
    # Each follower is paired with the symbol right before it.
    leaders = (word[0],) + tuple(followers[:-1])
    return set(zip(leaders, followers))
def basic_clean(text):
    """Run ftfy's text fixing, unescape HTML entities twice, and strip the ends."""
    repaired = ftfy.fix_text(text)
    unescaped_once = html.unescape(repaired)
    unescaped_twice = html.unescape(unescaped_once)
    return unescaped_twice.strip()
def whitespace_clean(text):
    """Collapse every whitespace run to one space and strip both ends."""
    return re.sub(r"\s+", " ", text).strip()
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer built from a gzip-compressed merges file.

    Args:
        bpe_path: Path to the gzip-compressed merges file.
        special_tokens: Optional extra special tokens, appended after
            ``<start_of_text>`` and ``<end_of_text>``.
    """

    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Context manager so the gzip handle is closed once the merges are read.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split("\n")
        # Skip the first line and keep a fixed number of merge rules.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary order: byte symbols, end-of-word byte symbols, merges, special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        if not special_tokens:
            special_tokens = ["<start_of_text>", "<end_of_text>"]
        else:
            special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        # Escape the tokens so regex metacharacters in them are matched literally.
        special = "|".join(re.escape(t) for t in special_tokens)
        self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        """Return the space-separated BPE symbols for one token (cached per token)."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Apply the lowest-ranked (earliest learned) merge present in the word.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lowercase ``text``, then return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map the token's UTF-8 bytes to their printable stand-ins first.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; end-of-word markers become spaces."""
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
        return text
# Module-level tokenizer shared by tokenize(). It is built at import time with
# the default arguments, so importing this module reads the default BPE file.
_tokenizer = SimpleTokenizer()
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Tokenize one string or a list of strings into a padded id tensor.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor of token ids, shape = [number of input strings, context_length].
    Shorter sequences are right-padded with zeros; longer ones are cut to
    ``context_length`` with the end-of-text token placed in the last slot.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<start_of_text>"]
    eot_token = _tokenizer.encoder["<end_of_text>"]

    encoded = []
    for text in texts:
        encoded.append([sot_token] + _tokenizer.encode(text) + [eot_token])

    result = torch.zeros(len(encoded), context_length, dtype=torch.long)
    for row, ids in enumerate(encoded):
        if len(ids) > context_length:
            ids = ids[:context_length]  # Truncate
            ids[-1] = eot_token
        result[row, : len(ids)] = torch.tensor(ids)

    return result
class HFTokenizer:
    """Wrapper around a HuggingFace ``AutoTokenizer`` loaded by name."""

    def __init__(self, tokenizer_name: str):
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
        """Clean the texts and return their ids padded/truncated to ``context_length``."""
        # Same cleaning as the default tokenizer, but without lowercasing.
        # Lowercasing (for case-sensitive tokenizers) would be more robust but less sensitive to nuance.
        if isinstance(texts, str):
            texts = [texts]
        cleaned = [whitespace_clean(basic_clean(text)) for text in texts]
        encoding = self.tokenizer(cleaned, return_tensors="pt", max_length=context_length, padding="max_length", truncation=True)
        return encoding.input_ids
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
class ResizeMaxSize(nn.Module):
    """Resize an image so its longest edge equals ``max_size``, then pad to a square.

    Args:
        max_size: Target edge length in pixels; must be an ``int``.
        interpolation: Interpolation mode passed to ``F.resize``.
        fn: ``"min"`` or ``"max"``; stored as the matching builtin on ``self.fn``.
            NOTE(review): ``forward`` does not consult ``self.fn`` and always
            scales by the longest edge - confirm whether "min" should be honored.
        fill: Fill value passed to ``F.pad``.

    Raises:
        TypeError: If ``max_size`` is not an ``int``.
    """

    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        self.fn = min if fn == "min" else max
        self.fill = fill

    def forward(self, img):
        """Return ``img`` resized and padded to ``max_size`` x ``max_size``."""
        if isinstance(img, torch.Tensor):
            # Tensor images are laid out as [..., H, W]; the leading axes are not spatial.
            height, width = img.shape[-2:]
        else:
            width, height = img.size
        scale = self.max_size / float(max(height, width))
        new_size = (height, width)
        if scale != 1.0:
            new_size = tuple(round(dim * scale) for dim in (height, width))
            img = F.resize(img, new_size, self.interpolation)
        # Pad even when no resize was needed, so a non-square input whose
        # longest edge already equals max_size still comes out square.
        pad_h = self.max_size - new_size[0]
        pad_w = self.max_size - new_size[1]
        if pad_h or pad_w:
            img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
        return img
def _convert_to_rgb(image):
    """Return the result of ``image.convert("RGB")``."""
    rgb_image = image.convert("RGB")
    return rgb_image
# class CatGen(nn.Module):
# def __init__(self, num=4):
# self.num = num
# def mixgen_batch(image, text):
# batch_size = image.shape[0]
# index = np.random.permutation(batch_size)
# cat_images = []
# for i in range(batch_size):
# # image mixup
# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
# # text concat
# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
# text = torch.stack(text)
# return image, text
def image_transform(
    image_size: int,
    is_train: bool,
    mean: Optional[Tuple[float, ...]] = None,
    std: Optional[Tuple[float, ...]] = None,
    resize_longest_max: bool = False,
    fill_color: int = 0,
):
    """Build the torchvision preprocessing pipeline for training or evaluation.

    Args:
        image_size: Target size; a square ``(s, s)`` list/tuple is reduced to ``s``.
        is_train: Use ``RandomResizedCrop`` when true, otherwise the eval pipeline.
        mean: Normalization mean; falls back to ``OPENAI_DATASET_MEAN`` when
            falsy. A scalar is repeated three times.
        std: Normalization std; falls back to ``OPENAI_DATASET_STD`` when
            falsy. A scalar is repeated three times.
        resize_longest_max: In eval mode, use ``ResizeMaxSize`` instead of
            ``Resize`` followed by ``CenterCrop``.
        fill_color: Fill value handed to ``ResizeMaxSize``.

    Returns:
        A ``Compose`` ending with RGB conversion, ``ToTensor`` and ``Normalize``.
    """
    if not mean:
        mean = OPENAI_DATASET_MEAN
    if not isinstance(mean, (list, tuple)):
        mean = (mean,) * 3

    if not std:
        std = OPENAI_DATASET_STD
    if not isinstance(std, (list, tuple)):
        std = (std,) * 3

    is_square_pair = isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]
    if is_square_pair:
        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
        image_size = image_size[0]

    normalize = Normalize(mean=mean, std=std)

    if is_train:
        steps = [RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC)]
    elif resize_longest_max:
        steps = [ResizeMaxSize(image_size, fill=fill_color)]
    else:
        steps = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]

    steps.extend([_convert_to_rgb, ToTensor(), normalize])
    return Compose(steps)
import os
import logging
from collections import OrderedDict
import math
from typing import Callable, Optional, Sequence
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
try:
from timm.models.layers import trunc_normal_
except:
from timm.layers import trunc_normal_
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
from .utils import to_2tuple
if os.getenv("ENV_TYPE") == "deepspeed":
try:
import deepspeed
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
except:
print("Please 'pip install deepspeed'")
deepspeed = None
from torch.utils.checkpoint import checkpoint
else:
from torch.utils.checkpoint import checkpoint
try:
import xformers.ops as xops
except ImportError:
xops = None
# print("Please 'pip install xformers'")
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that computes in float32 and casts the result back to the input type."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor):
        # Affine parameters are optional; cast them only when present.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normed = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return normed.type_as(x)
class LayerNorm(nn.LayerNorm):
    """LayerNorm whose output is cast back to the dtype of its input."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(input_dtype)
class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)``."""

    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class LayerScale(nn.Module):
    """Multiply the input by a learnable per-channel scale ``gamma``.

    ``gamma`` starts at ``init_values`` for each of the ``dim`` channels. With
    ``inplace=True`` the input tensor itself is scaled and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class PatchDropout(nn.Module):
    """
    Randomly keep a subset of tokens during training.

    https://arxiv.org/abs/2212.00794

    In eval mode, or with ``prob == 0.0``, the input is returned unchanged.
    When the environment variable ``RoPE`` equals "1", the training-mode
    result is a tuple ``(tokens, kept_indices)`` instead of just the tokens.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.0
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token
        logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")

    def forward(self, x):
        if not self.training or self.prob == 0.0:
            return x

        if self.exclude_first_token:
            # Set the first token aside so it is never dropped.
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch, num_tokens = x.size()[0], x.size()[1]
        batch_indices = torch.arange(batch)[..., None]

        # Keep at least one token per sample.
        num_patches_keep = max(1, int(num_tokens * (1 - self.prob)))

        # Positions of the top-k random scores are the tokens that survive.
        scores = torch.randn(batch, num_tokens)
        patch_indices_keep = scores.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        if self.training and os.getenv("RoPE") == "1":
            return x, patch_indices_keep
        return x
def _in_projection_packed(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: Optional[torch.Tensor] = None,
):
    """
    Project q, k and v with one packed weight ``w`` and optional packed bias ``b``.

    https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726

    The branch is chosen by object identity: one fused projection when q, k
    and v are the same tensor, a q projection plus a fused k/v projection when
    only k and v are the same tensor, and three separate projections
    otherwise. Returns the three projected tensors.
    """
    E = q.size(-1)

    if k is v and q is k:
        # self-attention: one matmul, then split into three
        return F.linear(q, w, b).chunk(3, dim=-1)

    if k is v:
        # encoder-decoder attention: q separately, k/v fused
        w_q, w_kv = w.split([E, E * 2])
        b_q, b_kv = (None, None) if b is None else b.split([E, E * 2])
        return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)

    w_q, w_k, w_v = w.chunk(3)
    b_q, b_k, b_v = (None, None, None) if b is None else b.chunk(3)
    return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
class Attention(nn.Module):
    """Multi-head self-attention over ``(L, N, C)`` inputs.

    Supports optional scaled-cosine logits (``scaled_cosine``), a learnable
    per-head output scale (``scale_heads``) and an xformers
    ``memory_efficient_attention`` path (``xattn``).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=False, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False, rope=False):
        """
        Args:
            dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            qkv_bias: Whether to create the packed ``in_proj_bias``.
            scaled_cosine: Use cosine-similarity logits with a learnable,
                clamped per-head ``logit_scale``.
            scale_heads: Create a learnable per-head ``head_scale``.
            logit_scale_max: Upper clamp applied to ``logit_scale`` before ``exp``.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability after the output projection.
            xattn: Route attention through ``xops.memory_efficient_attention``.
            rope: Stored as ``self.rope``; not read anywhere in this class.
        """
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.logit_scale_max = logit_scale_max

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop
        self.rope = rope

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Apply self-attention.

        Args:
            x: Tensor unpacked as ``(L, N, C)``.
            attn_mask: Optional mask. On the non-xformers path a bool mask is
                converted to an additive ``-inf`` mask and added to the
                logits. On the xformers path its mere presence selects
                ``xops.LowerTriangularMask()``; its values are not used.

        Returns:
            Tensor of shape ``(L, N, C)``.
        """
        L, N, C = x.shape
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)

        if self.xattn:
            q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
            k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
            v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
            x = xops.memory_efficient_attention(
                q,
                k,
                v,
                # The functional API cannot see module state, so dropout has
                # to be switched off explicitly outside training.
                p=self.xattn_drop if self.training else 0.0,
                scale=self.scale if self.logit_scale is None else None,
                attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
            )
            if self.head_scale is not None:
                # NOTE(review): assumes the xformers output keeps the
                # (N, L, heads, head_dim) layout of its inputs - confirm.
                x = x * self.head_scale.view(1, 1, self.num_heads, 1)
        else:
            q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
            k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
            v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)

            if self.logit_scale is not None:
                attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
                logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
                attn = attn.view(N, self.num_heads, L, L) * logit_scale
                attn = attn.view(-1, L, L)
            else:
                q = q * self.scale
                attn = torch.bmm(q, k.transpose(-1, -2))

            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                    new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                    attn_mask = new_attn_mask
                attn += attn_mask

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = torch.bmm(attn, v)
            if self.head_scale is not None:
                # x is (N * heads, L, head_dim): the last dimension is the
                # per-head width, not the full model width C.
                x = x.view(N, self.num_heads, L, self.head_dim) * self.head_scale
                x = x.view(-1, L, self.head_dim)

        x = x.transpose(0, 1).reshape(L, N, C)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
class CustomAttention(nn.Module):
    """Multi-head attention with separate query / key / value inputs.

    Supports optional scaled-cosine logits (``scaled_cosine``), a learnable
    per-head output scale (``scale_heads``) and an xformers
    ``memory_efficient_attention`` path (``xattn``).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=True, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False):
        """
        Args:
            dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            qkv_bias: Whether to create the packed ``in_proj_bias``.
            scaled_cosine: Use cosine-similarity logits with a learnable,
                clamped per-head ``logit_scale``.
            scale_heads: Create a learnable per-head ``head_scale``.
            logit_scale_max: Upper clamp applied to ``logit_scale`` before ``exp``.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability after the output projection.
            xattn: Route attention through ``xops.memory_efficient_attention``.
        """
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.logit_scale_max = logit_scale_max

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Attend from ``query`` to ``key`` / ``value``.

        The three inputs are projected by ``_in_projection_packed``; each
        projection is unpacked as ``(sequence, batch, channels)``.

        Args:
            query, key, value: Input tensors.
            attn_mask: Optional mask. On the non-xformers path a bool mask is
                converted to an additive ``-inf`` mask and added to the
                logits. On the xformers path its mere presence selects
                ``xops.LowerTriangularMask()``; its values are not used.

        Returns:
            Tensor of shape ``(N_q, B_q, C_q)``.
        """
        q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
        N_q, B_q, C_q = q.shape
        N_k, B_k, C_k = k.shape
        N_v, B_v, C_v = v.shape

        if self.xattn:
            # B, N, C -> B, N, num_heads, C
            q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
            k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
            v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
            x = xops.memory_efficient_attention(
                q,
                k,
                v,
                # The functional API cannot see module state, so dropout has
                # to be switched off explicitly outside training.
                p=self.xattn_drop if self.training else 0.0,
                scale=self.scale if self.logit_scale is None else None,
                attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
            )
            if self.head_scale is not None:
                # NOTE(review): assumes the xformers output keeps the
                # (B, N, heads, head_dim) layout of its inputs - confirm.
                x = x * self.head_scale.view(1, 1, self.num_heads, 1)
        else:
            # B*H, L, C
            q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
            k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
            v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)

            if self.logit_scale is not None:
                # B*H, N_q, N_k
                attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
                logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
                attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
                attn = attn.view(-1, N_q, N_k)
            else:
                q = q * self.scale
                attn = torch.bmm(q, k.transpose(-1, -2))

            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                    new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                    attn_mask = new_attn_mask
                attn += attn_mask

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = torch.bmm(attn, v)
            if self.head_scale is not None:
                # x is (B_q * heads, N_q, head_dim): the last dimension is the
                # per-head width, not the full projected width C_q.
                x = x.view(B_q, self.num_heads, N_q, self.head_dim) * self.head_scale
                x = x.view(-1, N_q, self.head_dim)

        x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
class CustomResidualAttentionBlock(nn.Module):
    """Pre-norm residual block built on ``CustomAttention`` plus an MLP.

    Query, key and value are passed separately; with ``cross_attn`` the key
    and value get their own norm layers, otherwise all three share ``ln_1``.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        scale_cosine_attn: bool = False,
        scale_heads: bool = False,
        scale_attn: bool = False,
        scale_fc: bool = False,
        cross_attn: bool = False,
        xattn: bool = False,
    ):
        """
        Args:
            d_model: Channel width of the block.
            n_head: Number of attention heads.
            mlp_ratio: Hidden MLP width as a multiple of ``d_model``.
            ls_init_value: Initial ``LayerScale`` value; ``None`` uses ``nn.Identity``.
            act_layer: Factory for the MLP activation.
            norm_layer: Factory for normalisation layers.
            scale_cosine_attn: Forwarded to ``CustomAttention`` as ``scaled_cosine``.
            scale_heads: Forwarded to ``CustomAttention``.
            scale_attn: Add a norm layer on the attention output.
            scale_fc: Add a norm layer inside the MLP after ``c_fc``.
            cross_attn: Give key and value their own norm layers.
            xattn: Forwarded to ``CustomAttention``.
        """
        super().__init__()

        def make_layer_scale():
            if ls_init_value is None:
                return nn.Identity()
            return LayerScale(d_model, ls_init_value)

        self.ln_1 = norm_layer(d_model)
        if cross_attn:
            self.ln_1_k = norm_layer(d_model)
            self.ln_1_v = norm_layer(d_model)
        else:
            self.ln_1_k = self.ln_1
            self.ln_1_v = self.ln_1

        self.attn = CustomAttention(
            d_model,
            n_head,
            qkv_bias=True,
            attn_drop=0.0,
            proj_drop=0.0,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
            xattn=xattn,
        )
        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = make_layer_scale()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        mlp_layers = OrderedDict(
            [
                ("c_fc", nn.Linear(d_model, mlp_width)),
                ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()),
                ("gelu", act_layer()),
                ("c_proj", nn.Linear(mlp_width, d_model)),
            ]
        )
        self.mlp = nn.Sequential(mlp_layers)
        self.ls_2 = make_layer_scale()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Return ``q`` after the attention and MLP residual updates."""
        attended = self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)
        q = q + self.ls_1(self.ln_attn(attended))
        mlp_out = self.mlp(self.ln_2(q))
        return q + self.ls_2(mlp_out)
class CustomTransformer(nn.Module):
    """Stack of ``CustomResidualAttentionBlock`` layers.

    The key and value streams stay fixed across the stack; only the query
    stream is updated from block to block.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        scale_cosine_attn: bool = True,
        scale_heads: bool = False,
        scale_attn: bool = False,
        scale_fc: bool = False,
        cross_attn: bool = False,
        xattn: bool = False,
    ):
        """
        Args:
            width: Channel width of every block.
            layers: Number of residual blocks.
            heads: Attention heads per block.

        The remaining arguments are forwarded unchanged to every
        ``CustomResidualAttentionBlock``.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False
        self.xattn = xattn

        blocks = []
        for _ in range(layers):
            blocks.append(
                CustomResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    scale_cosine_attn=scale_cosine_attn,
                    scale_heads=scale_heads,
                    scale_attn=scale_attn,
                    scale_fc=scale_fc,
                    cross_attn=cross_attn,
                    xattn=xattn,
                )
            )
        self.resblocks = nn.ModuleList(blocks)

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's ``mlp.c_fc`` weight."""
        first_block = self.resblocks[0]
        return first_block.mlp.c_fc.weight.dtype

    def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
        """Run ``q`` through every block.

        When both ``k`` and ``v`` are ``None`` they default to the incoming
        ``q``. Blocks are wrapped in ``checkpoint`` when
        ``grad_checkpointing`` is set and the module is not being scripted.
        """
        if k is None and v is None:
            k = v = q

        for block in self.resblocks:
            use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
            if use_checkpoint:
                q = checkpoint(block, q, k, v, attn_mask)
            else:
                q = block(q, k, v, attn_mask=attn_mask)
        return q
class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention followed by an MLP.

    Attention is the local ``Attention`` module when ``xattn`` is set and
    ``nn.MultiheadAttention`` otherwise.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        xattn: bool = False,
    ):
        """
        Args:
            d_model: Channel width of the block.
            n_head: Number of attention heads.
            mlp_ratio: Hidden MLP width as a multiple of ``d_model``.
            ls_init_value: Initial ``LayerScale`` value; ``None`` uses ``nn.Identity``.
            act_layer: Factory for the MLP activation.
            norm_layer: Factory for normalisation layers.
            xattn: Select the ``Attention`` module with ``xattn=True``.
        """
        super().__init__()

        def make_layer_scale():
            if ls_init_value is None:
                return nn.Identity()
            return LayerScale(d_model, ls_init_value)

        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(d_model, n_head, xattn=True) if xattn else nn.MultiheadAttention(d_model, n_head)
        self.ls_1 = make_layer_scale()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        mlp_layers = OrderedDict(
            [
                ("c_fc", nn.Linear(d_model, mlp_width)),
                ("gelu", act_layer()),
                ("c_proj", nn.Linear(mlp_width, d_model)),
            ]
        )
        self.mlp = nn.Sequential(mlp_layers)
        self.ls_2 = make_layer_scale()
        self.xattn = xattn

    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run the configured attention module on ``x``.

        A provided ``attn_mask`` is first cast to ``x``'s dtype.
        """
        if attn_mask is not None:
            attn_mask = attn_mask.to(x.dtype)
        if self.xattn:
            return self.attn(x, attn_mask=attn_mask)
        # nn.MultiheadAttention returns (output, weights); keep the output only.
        attn_output = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)
        return attn_output[0]

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Return ``x`` after the attention and MLP residual updates."""
        attended = self.attention(self.ln_1(x), attn_mask=attn_mask)
        x = x + self.ls_1(attended)
        mlp_out = self.mlp(self.ln_2(x))
        return x + self.ls_2(mlp_out)
class Transformer(nn.Module):
    """Stack of ``ResidualAttentionBlock`` layers applied in sequence."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        xattn: bool = False,
    ):
        """
        Args:
            width: Channel width of every block.
            layers: Number of residual blocks.
            heads: Attention heads per block.

        The remaining arguments are forwarded unchanged to every
        ``ResidualAttentionBlock``.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        blocks = []
        for _ in range(layers):
            blocks.append(
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    xattn=xattn,
                )
            )
        self.resblocks = nn.ModuleList(blocks)

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's ``mlp.c_fc`` weight."""
        first_block = self.resblocks[0]
        return first_block.mlp.c_fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run ``x`` through every block.

        Blocks are wrapped in ``checkpoint`` when ``grad_checkpointing`` is
        set and the module is not being scripted.
        """
        for block in self.resblocks:
            use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
            if use_checkpoint:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x
class VisionTransformer(nn.Module):
    """ViT image encoder.

    Pipeline: patch convolution, class token, learned positional embedding,
    optional patch dropout, ``Transformer``, then pooling and projection.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        ls_init_value: float = None,
        patch_dropout: float = 0.0,
        global_average_pool: bool = False,
        output_dim: int = 512,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        xattn: bool = False,
    ):
        """
        Args:
            image_size: Input resolution (expanded with ``to_2tuple``).
            patch_size: Patch edge length (expanded with ``to_2tuple``).
            width: Transformer channel width.
            layers: Number of transformer blocks.
            heads: Attention heads per block.
            mlp_ratio: Hidden MLP width as a multiple of ``width``.
            ls_init_value: Initial ``LayerScale`` value, forwarded to the transformer.
            patch_dropout: Token drop probability; values ``<= 0`` disable it.
            global_average_pool: Pool by averaging all tokens instead of taking token 0.
            output_dim: Width of the projected output.
            act_layer: Factory for MLP activations.
            norm_layer: Factory for normalisation layers.
            xattn: Forwarded to the transformer.
        """
        super().__init__()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        grid_h = self.image_size[0] // self.patch_size[0]
        grid_w = self.image_size[1] // self.patch_size[1]
        self.grid_size = (grid_h, grid_w)
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        num_positions = self.grid_size[0] * self.grid_size[1] + 1
        self.positional_embedding = nn.Parameter(scale * torch.randn(num_positions, width))

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        if patch_dropout > 0.0:
            self.patch_dropout = PatchDropout(patch_dropout)
        else:
            self.patch_dropout = nn.Identity()
        self.ln_pre = norm_layer(width)

        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            xattn=xattn,
        )

        self.global_average_pool = global_average_pool
        self.ln_post = norm_layer(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze all parameters, then re-enable the last ``unlocked_groups`` groups.

        Groups, in order: the stem (conv, class and positional embeddings,
        ``ln_pre``), each transformer block except the last, the last block
        together with ``ln_post``, and ``proj``. ``freeze_bn_stats`` is
        accepted but not read.
        """
        for param in self.parameters():
            param.requires_grad = False

        if unlocked_groups == 0:
            return

        stem = [self.conv1, self.class_embedding, self.positional_embedding, self.ln_pre]
        head = [self.transformer.resblocks[-1], self.ln_post]
        groups = [stem, *self.transformer.resblocks[:-1], head, self.proj]

        def _unlock(x):
            if isinstance(x, Sequence):
                for member in x:
                    _unlock(member)
            elif isinstance(x, torch.nn.Parameter):
                x.requires_grad = True
            else:
                for p in x.parameters():
                    p.requires_grad = True

        _unlock(groups[-unlocked_groups:])

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return self.transformer.layers

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the transformer."""
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        return {"positional_embedding", "class_embedding"}

    def forward(self, x: torch.Tensor, return_all_features: bool = False):
        """Encode a batch of images.

        With ``return_all_features`` the per-token transformer output is
        returned as is; otherwise tokens are pooled, normalised with
        ``ln_post`` and projected with ``proj``.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        batch, channels = x.shape[0], x.shape[1]
        x = x.reshape(batch, channels, -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # Adding zeros broadcasts the class embedding to one token per image.
        cls_token = self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.patch_dropout(x)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        if return_all_features:
            return x

        if self.global_average_pool:
            x = x.mean(dim=1)  # x = x[:,1:,:].mean(dim=1)
        else:
            x = x[:, 0]

        x = self.ln_post(x)

        if self.proj is not None:
            x = x @ self.proj

        return x
class TextTransformer(nn.Module):
    """Text encoder.

    Pipeline: token and positional embeddings, ``Transformer``, final norm,
    then projection of the token at ``text.argmax(dim=-1)``.
    """

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        ls_init_value: float = None,
        output_dim: int = 512,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        xattn: bool = False,
        attn_mask: bool = True,
    ):
        """
        Args:
            context_length: Number of positions in the positional embedding and mask.
            vocab_size: Size of the token embedding table.
            width: Transformer channel width.
            heads: Attention heads per block.
            layers: Number of transformer blocks.
            ls_init_value: Initial ``LayerScale`` value, forwarded to the transformer.
            output_dim: Width of the projected output.
            act_layer: Factory for MLP activations.
            norm_layer: Factory for normalisation layers.
            xattn: Forwarded to the transformer.
            attn_mask: Register the mask from ``build_attention_mask`` as a
                non-persistent buffer; otherwise ``self.attn_mask`` is ``None``.
        """
        super().__init__()
        self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim

        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            xattn=xattn,
        )

        self.xattn = xattn
        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        if attn_mask:
            self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
        else:
            self.attn_mask = None

        self.init_parameters()

    def init_parameters(self):
        """Initialise embeddings, block weights and the projection with normal draws."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        width = self.transformer.width
        attn_std = width**-0.5
        proj_std = (width**-0.5) * ((2 * self.transformer.layers) ** -0.5)
        fc_std = (2 * width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the transformer."""
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        # return {'positional_embedding', 'token_embedding'}
        return {"positional_embedding"}

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return self.transformer.layers

    def build_attention_mask(self):
        """Return an additive causal mask of shape ``(context_length, context_length)``.

        Entries strictly above the diagonal are ``-inf``; the rest are 0.
        """
        # pytorch uses additive attention mask; fill with -inf
        size = (self.context_length, self.context_length)
        mask = torch.full(size, float("-inf"))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    def forward(self, text, return_all_features: bool = False):
        """Encode token ids.

        With ``return_all_features`` the normalised per-token features are
        returned; otherwise the feature at each row's ``argmax`` token id is
        projected with ``text_projection``.
        """
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        # x = self.transformer(x) # no attention mask is applied
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        if return_all_features:
            return x

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        rows = torch.arange(x.shape[0])
        eot_positions = text.argmax(dim=-1)
        return x[rows, eot_positions] @ self.text_projection
from itertools import repeat
import collections.abc
import logging
import math
import numpy as np
import torch
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
import torch.nn.functional as F
# open CLIP
def resize_clip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
    """Resize ``state_dict["visual.positional_embedding"]`` to ``model.visual.grid_size``.

    The entry is replaced in place. Nothing happens when the key is missing,
    when ``model.visual`` has no ``grid_size``, or when the sequence length
    already matches. ``seq_dim`` is accepted but not read.
    """
    key = "visual.positional_embedding"
    # Rescale the grid of position embeddings when loading from state_dict
    checkpoint_embed = state_dict.get(key, None)
    if checkpoint_embed is None or not hasattr(model.visual, "grid_size"):
        return

    target_grid = to_2tuple(model.visual.grid_size)
    num_prefix = 1  # FIXME detect different token configs (ie no class token, or more)
    if target_grid[0] * target_grid[1] + num_prefix == checkpoint_embed.shape[0]:
        return

    if num_prefix:
        prefix_embed = checkpoint_embed[:num_prefix]
        grid_embed = checkpoint_embed[num_prefix:]
    else:
        prefix_embed, grid_embed = None, checkpoint_embed

    # The checkpoint grid is taken to be square.
    source_grid = to_2tuple(int(math.sqrt(len(grid_embed))))
    logging.info("Resizing position embedding grid-size from %s to %s", source_grid, target_grid)

    # (tokens, C) -> (1, C, H, W) for F.interpolate, then back to (tokens, C).
    grid_embed = grid_embed.reshape(1, source_grid[0], source_grid[1], -1).permute(0, 3, 1, 2)
    grid_embed = F.interpolate(grid_embed, size=target_grid, mode=interpolation, align_corners=True)
    grid_embed = grid_embed.permute(0, 2, 3, 1).reshape(1, target_grid[0] * target_grid[1], -1)[0]

    if prefix_embed is None:
        state_dict[key] = grid_embed
    else:
        state_dict[key] = torch.cat([prefix_embed, grid_embed], dim=0)
def resize_visual_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
    """Resize ``state_dict["positional_embedding"]`` to ``model.visual.grid_size``.

    The entry is replaced in place. Nothing happens when the key is missing,
    when ``model.visual`` has no ``grid_size``, or when the sequence length
    already matches. ``seq_dim`` is accepted but not read.
    """
    key = "positional_embedding"
    # Rescale the grid of position embeddings when loading from state_dict
    checkpoint_embed = state_dict.get(key, None)
    if checkpoint_embed is None or not hasattr(model.visual, "grid_size"):
        return

    target_grid = to_2tuple(model.visual.grid_size)
    num_prefix = 1  # FIXME detect different token configs (ie no class token, or more)
    if target_grid[0] * target_grid[1] + num_prefix == checkpoint_embed.shape[0]:
        return

    if num_prefix:
        prefix_embed = checkpoint_embed[:num_prefix]
        grid_embed = checkpoint_embed[num_prefix:]
    else:
        prefix_embed, grid_embed = None, checkpoint_embed

    # The checkpoint grid is taken to be square.
    source_grid = to_2tuple(int(math.sqrt(len(grid_embed))))
    logging.info("Resizing position embedding grid-size from %s to %s", source_grid, target_grid)

    # (tokens, C) -> (1, C, H, W) for F.interpolate, then back to (tokens, C).
    grid_embed = grid_embed.reshape(1, source_grid[0], source_grid[1], -1).permute(0, 3, 1, 2)
    grid_embed = F.interpolate(grid_embed, size=target_grid, mode=interpolation, align_corners=True)
    grid_embed = grid_embed.permute(0, 2, 3, 1).reshape(1, target_grid[0] * target_grid[1], -1)[0]

    if prefix_embed is None:
        state_dict[key] = grid_embed
    else:
        state_dict[key] = torch.cat([prefix_embed, grid_embed], dim=0)
def resize_evaclip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
    """Resize ``visual.pos_embed`` and the patch-embedding kernel in ``state_dict``.

    When the checkpoint grid differs from the model's, the position tokens
    are resampled bicubically and ``visual.patch_embed.proj.weight`` is
    resampled to the model's patch size. The first token is copied
    unchanged. ``interpolation`` and ``seq_dim`` are accepted but not read.
    """
    pos_key = "visual.pos_embed"
    proj_key = "visual.patch_embed.proj.weight"

    # interpolate position embedding
    if pos_key not in state_dict:
        return

    pos_embed_checkpoint = state_dict[pos_key]
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.visual.patch_embed.num_patches
    # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
    num_extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches**0.5)

    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    # class_token and dist_token are kept unchanged
    prefix = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    grid = pos_embed_checkpoint[:, num_extra_tokens:]
    grid = grid.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(grid, size=(new_size, new_size), mode="bicubic", align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    state_dict[pos_key] = torch.cat((prefix, grid), dim=1)

    # Resample the patch-embedding kernel to the model's patch size.
    kernel = state_dict[proj_key]
    target_patch = model.visual.patch_embed.patch_size
    state_dict[proj_key] = torch.nn.functional.interpolate(kernel.float(), size=target_patch, mode="bicubic", align_corners=False)
def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
    """Resize ``pos_embed`` and the patch-embedding kernel in ``state_dict``.

    When the checkpoint grid differs from the model's, the position tokens
    are resampled bicubically and ``patch_embed.proj.weight`` is resampled
    to the model's patch size. The first token is copied unchanged.
    ``interpolation`` and ``seq_dim`` are accepted but not read.
    """
    pos_key = "pos_embed"
    proj_key = "patch_embed.proj.weight"

    # interpolate position embedding
    if pos_key not in state_dict:
        return

    pos_embed_checkpoint = state_dict[pos_key]
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.visual.patch_embed.num_patches
    # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
    num_extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches**0.5)

    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    # class_token and dist_token are kept unchanged
    prefix = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    grid = pos_embed_checkpoint[:, num_extra_tokens:]
    grid = grid.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(grid, size=(new_size, new_size), mode="bicubic", align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    state_dict[pos_key] = torch.cat((prefix, grid), dim=1)

    # Resample the patch-embedding kernel to the model's patch size.
    kernel = state_dict[proj_key]
    target_patch = model.visual.patch_embed.patch_size
    state_dict[proj_key] = torch.nn.functional.interpolate(kernel.float(), size=target_patch, mode="bicubic", align_corners=False)
def resize_rel_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
    """Adapt relative-position tables and ``pos_embed`` in ``state_dict`` to ``model``.

    Three things happen in place:

    * every key containing ``relative_position_index`` is removed;
    * every ``relative_position_bias_table`` whose grid size differs from
      the model's is resampled with a cubic spline on a geometrically
      spaced source grid, with the trailing extra-token rows copied unchanged;
    * ``pos_embed`` (and ``patch_embed.proj.weight``) are resampled
      bicubically when the patch grid differs.

    ``interpolation`` and ``seq_dim`` are accepted but not read.

    Raises:
        NotImplementedError: If the model's patch grid is not square.
    """
    all_keys = list(state_dict.keys())  # snapshot: entries are popped while looping
    for key in all_keys:
        if "relative_position_index" in key:
            state_dict.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = state_dict[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = model.visual.state_dict()[key].size()
            dst_patch_shape = model.visual.patch_embed.patch_shape
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                # Local import: SciPy is only needed when a table is resized.
                from scipy.interpolate import RectBivariateSpline

                print("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size))
                # Split at an explicit index; negative slicing would be wrong
                # if there were zero extra tokens ([-0:] selects everything).
                num_grid_rows = src_num_pos - num_extra_tokens
                extra_tokens = rel_pos_bias[num_grid_rows:, :]
                rel_pos_bias = rel_pos_bias[:num_grid_rows, :]

                def geometric_progression(a, r, n):
                    return a * (1.0 - r**n) / (1.0 - r)

                # Bisect for the ratio q whose geometric series over half the
                # source grid reaches half the target grid.
                left, right = 1.01, 1.5
                while right - left > 1e-6:
                    q = (left + right) / 2.0
                    gp = geometric_progression(1, q, src_size // 2)
                    if gp > dst_size // 2:
                        right = q
                    else:
                        left = q

                # if q > 1.090307:
                #     q = 1.090307

                # Source coordinates: symmetric about 0, spacing grows by q.
                dis = []
                cur = 1
                for i in range(src_size // 2):
                    dis.append(cur)
                    cur += q ** (i + 1)

                r_ids = [-_ for _ in reversed(dis)]

                x = r_ids + [0] + dis
                y = r_ids + [0] + dis

                # Target coordinates: unit-spaced integers centred on 0.
                t = dst_size // 2.0
                dx = np.arange(-t, t + 0.1, 1.0)
                dy = np.arange(-t, t + 0.1, 1.0)

                print("Original positions = %s" % str(x))
                print("Target positions = %s" % str(dx))

                all_rel_pos_bias = []

                for i in range(num_attn_heads):
                    # NumPy needs a CPU tensor that is detached from autograd.
                    z = rel_pos_bias[:, i].reshape(src_size, src_size).detach().float().cpu().numpy()
                    # z[j, i] is the value at (x[i], y[j]), so rows go with y.
                    # SciPy documents RectBivariateSpline as the replacement
                    # for interp2d on a regular grid.
                    spline = RectBivariateSpline(y, x, z, kx=3, ky=3)
                    resized = torch.Tensor(spline(dy, dx))
                    all_rel_pos_bias.append(resized.contiguous().view(-1, 1).to(rel_pos_bias.device))

                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                state_dict[key] = new_rel_pos_bias

    # interpolate position embedding
    if "pos_embed" in state_dict:
        pos_embed_checkpoint = state_dict["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.visual.patch_embed.num_patches
        num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            state_dict["pos_embed"] = new_pos_embed

            patch_embed_proj = state_dict["patch_embed.proj.weight"]
            patch_size = model.visual.patch_embed.patch_size
            state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
def freeze_batch_norm_2d(module, module_match=None, name=""):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict): Dictionary of full module names to freeze (all if empty or None)
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    # An empty or missing filter means every batch-norm layer is a match.
    is_match = True
    if module_match:
        is_match = name in module_match

    batch_norm_types = (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
    if is_match and isinstance(module, batch_norm_types):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for child_name, child in module.named_children():
            full_child_name = ".".join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            # Only re-register children that were actually replaced.
            if new_child is not child:
                res.add_module(child_name, new_child)
    return res
# From PyTorch internals
def _ntuple(n):
    """Build a converter that expands a non-iterable value into an ``n``-tuple.

    The returned callable hands any ``collections.abc.Iterable`` (strings
    included) back unchanged and repeats every other value ``n`` times.
    """

    def parse(x):
        if not isinstance(x, collections.abc.Iterable):
            return tuple(repeat(x, n))
        return x

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)


def to_ntuple(n, x):
    """Expand ``x`` into an ``n``-tuple with the rules of ``_ntuple``."""
    return _ntuple(n)(x)
def is_logging(args):
    """Return the ``is_master(args, local=False)`` rank predicate.

    The predicate is true when ``args.rank == 0`` or, with ``local=True``,
    when ``args.local_rank == 0``.

    NOTE(review): the ``args`` passed to ``is_logging`` itself is never
    read, and the predicate is returned rather than called, so the result is
    always truthy. Confirm against callers whether that is intended.
    """

    def is_master(args, local=False):
        rank = args.local_rank if local else args.rank
        return rank == 0

    return is_master
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor.

    Forward concatenates the tensors gathered from every rank along
    dimension 0. Backward hands each rank the slice of the incoming gradient
    that corresponds to its own rows.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """

    @staticmethod
    def forward(ctx, tensor, rank, world_size):
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(gathered, tensor)
        # Remember which rows of the concatenation belong to this rank.
        ctx.rank = rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(gathered, 0)

    @staticmethod
    def backward(ctx, grad_output):
        start = ctx.batch_size * ctx.rank
        stop = ctx.batch_size * (ctx.rank + 1)
        # No gradients for the ``rank`` and ``world_size`` arguments.
        return (grad_output[start:stop], None, None)


allgather = AllGather.apply
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
# not tested yet
import math
from transformers import CLIPImageProcessor
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from .eva_clip import create_model_and_transforms, get_model_config
import torch
import torchvision
import time
from llava.utils import rank0_print
class EvaViTWrapper(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.pretrained = args.vision_tower_pretrained
self.args = args
self.select_layer = args.mm_vision_select_layer
if self.select_layer < -1:
self.select_layer += 1
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
self.model_config = get_model_config(self.vision_tower_name)
if not delay_load:
rank0_print(f"Loading vision tower: {vision_tower}")
self.load_model()
elif getattr(args, "unfreeze_mm_vision_tower", False):
# TODO: better detector is needed.
rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
self.load_model()
elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
self.load_model()
def load_model(self):
rank0_print(f"Loading: {self.vision_tower_name}")
rank0_print(f"Pretrained: {self.pretrained}")
time_start = time.time()
model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16")
time_end = time.time()
rank0_print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s")
self.device = next(model.parameters()).device
self.dtype = next(model.parameters()).dtype
if self.device.type != "meta":
model = model.to("cuda")
self.vision_tower = model.visual
resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
self.resize_transform_size = resize_transform.size
self.image_processor = CLIPImageProcessor.from_pretrained(
"openai/clip-vit-large-patch14",
crop_size=resize_transform.size,
size={"shortest_edge": resize_transform.size},
image_mean=list(normalize_transform.mean),
image_std=list(normalize_transform.std),
)
rank0_print(f"Loaded image processor: {self.image_processor}")
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_features):
    """Pick the token subset configured by ``self.select_feature``.

    ``"patch"`` drops the first token along dim 1, ``"cls_patch"`` returns the
    input untouched; any other value raises ``ValueError``.

    NOTE: the multi-layer variants (``slicefour_*``, ``slice_m25811_f6_*``) and
    the ``select_layer`` indexing are disabled in this implementation, so the
    input is treated as a single feature tensor.
    """
    mode = self.select_feature
    if mode == "cls_patch":
        return image_features
    if mode == "patch":
        return image_features[:, 1:]
    raise ValueError(f"Unexpected select feature: {mode}")
def train(self, mode=True):
    """Record the training flag while keeping a loaded vision tower in eval mode.

    Only ``self.training`` is updated; when the tower has been loaded it is
    forced to ``eval()`` regardless of ``mode``.

    Returns:
        ``self``, so that chained calls such as ``module = module.train()``
        keep working (the same contract as ``torch.nn.Module.train``).
    """
    self.training = mode
    if self.is_loaded:
        self.vision_tower.eval()
    return self
def forward(self, images):
    """Encode one batch tensor, or a list of tensors, with the vision tower.

    Args:
        images: a tensor, or a list of tensors that are encoded one by one.

    Returns:
        The selected features cast to ``self.dtype``: a single tensor for a
        tensor input, or a list with one entry per element for a list input.
    """

    def encode(batch):
        features = self.vision_tower.forward_features(batch.to(self.dtype), return_all_features=True)
        return self.feature_select(features).to(self.dtype)

    if isinstance(images, list):
        # Each result gets its own name: the previous code rebound the
        # accumulator list to a tensor inside the loop and then failed on
        # `.append`.
        return [encode(image) for image in images]
    return encode(images)
@property
def dummy_feature(self):
    """All-zero placeholder of shape ``(1, hidden_size)`` on the tower's device/dtype."""
    shape = (1, self.hidden_size)
    return torch.zeros(shape, dtype=self.dtype, device=self.device)
@property
def hidden_size(self):
    """``width`` entry of the model's ``vision_cfg``."""
    vision_cfg = self.model_config["vision_cfg"]
    return vision_cfg["width"]
@property
def num_patches(self):
    """Number of patches in the square grid: ``(image_size // patch_size) ** 2``."""
    vision_cfg = self.model_config["vision_cfg"]
    per_side = vision_cfg["image_size"] // vision_cfg["patch_size"]
    return per_side**2
@property
def num_patches_per_side(self):
    """Patches along one image edge: ``image_size // patch_size``."""
    vision_cfg = self.model_config["vision_cfg"]
    return vision_cfg["image_size"] // vision_cfg["patch_size"]
@property
def config(self):
    """The model configuration dict stored in ``self.model_config``."""
    model_config = self.model_config
    return model_config
@property
def image_size(self):
    """``image_size`` entry of the model's ``vision_cfg``."""
    vision_cfg = self.model_config["vision_cfg"]
    return vision_cfg["image_size"]
import torch
import torch.nn as nn
from .eva_clip_processors import EvaClipImageTrainProcessor
from .eva_vit import EVAEncoderWrapper
from .factory import list_models, add_model_config, get_model_config
from llava.utils import rank0_print
class EvaClipVisionTower(nn.Module):
    """Vision tower backed by an ``EVAEncoderWrapper``.

    The configuration is resolved eagerly through ``get_model_config``; the
    weights are loaded either immediately or, with ``delay_load=True``, only
    when ``args`` indicates that the checkpoint carries vision tower weights.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.vision_tower_pretrained = args.vision_tower_pretrained
        self.config = get_model_config(vision_tower)

        # Work out why (if at all) the weights must be loaded right now.
        if not delay_load:
            load_reason = f"Loading EVA ViT: {self.vision_tower_name}"
        elif getattr(args, "unfreeze_mm_vision_tower", False):
            # TODO: better detector is needed.
            load_reason = f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True."
        elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
            load_reason = f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`."
        else:
            load_reason = None

        if load_reason is None:
            # Deferred loading: only the configuration is kept around.
            self.cfg_only = self.config
            return

        rank0_print(load_reason)
        self.load_model()

    def load_model(self, device_map=None):
        """Build the image processor and the frozen EVA encoder (``device_map`` is unused)."""
        rank0_print(f"Pretrained: {self.vision_tower_pretrained}")
        image_size = self.config["vision_cfg"]["image_size"]
        self.image_processor = EvaClipImageTrainProcessor(image_size)
        self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config)
        rank0_print(f"Loaded image processor: {self.image_processor}")
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def forward(self, images):
        """Encode a tensor, or each tensor of a list (adding a leading dim of 1 to each).

        Inputs are moved to the tower's device/dtype and the outputs are cast
        back to the dtype of the corresponding input.
        """

        def encode(batch):
            moved = batch.to(device=self.device, dtype=self.dtype)
            return self.vision_tower(moved).to(batch.dtype)

        if type(images) is list:
            return [encode(image.unsqueeze(0)) for image in images]
        return encode(images)

    @property
    def dtype(self):
        """dtype reported by the wrapped vision tower."""
        tower = self.vision_tower
        return tower.dtype

    @property
    def device(self):
        """Device reported by the wrapped vision tower."""
        tower = self.vision_tower
        return tower.device

    @property
    def hidden_size(self):
        """``width`` entry of ``vision_cfg``."""
        vision_cfg = self.config["vision_cfg"]
        return vision_cfg["width"]

    @property
    def num_patches(self):
        """``(image_size // patch_size) ** 2``."""
        vision_cfg = self.config["vision_cfg"]
        per_side = vision_cfg["image_size"] // vision_cfg["patch_size"]
        return per_side**2

    @property
    def num_patches_per_side(self):
        """``image_size // patch_size``."""
        vision_cfg = self.config["vision_cfg"]
        return vision_cfg["image_size"] // vision_cfg["patch_size"]

    @property
    def image_size(self):
        """``image_size`` entry of ``vision_cfg``."""
        vision_cfg = self.config["vision_cfg"]
        return vision_cfg["image_size"]
"""
# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
"""
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers.image_processing_utils import BatchFeature
from PIL import Image
from transformers.image_transforms import convert_to_rgb
class BaseProcessor:
    """Callable processor whose default transform returns its input unchanged."""

    def __init__(self):
        # Identity transform; subclasses replace `self.transform`.
        self.transform = lambda item: item

    def __call__(self, item):
        """Run ``item`` through ``self.transform``."""
        transform = self.transform
        return transform(item)
class EvaClipImageBaseProcessor(BaseProcessor):
    """Base image processor holding the normalisation statistics.

    Args:
        mean: per-channel mean; defaults to ``(0.48145466, 0.4578275, 0.40821073)``.
        std: per-channel std; defaults to ``(0.26862954, 0.26130258, 0.27577711)``.
    """

    def __init__(self, mean=None, std=None):
        # Initialise the base class so `self.transform` always exists and the
        # inherited `__call__` works even when a subclass sets no transform.
        super().__init__()
        self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
        self.normalize = transforms.Normalize(self.mean, self.std)

    @property
    def image_mean(self):
        """Per-channel mean used by ``self.normalize``."""
        return self.mean

    @property
    def image_std(self):
        """Per-channel std used by ``self.normalize`` (counterpart of ``image_mean``)."""
        return self.std
class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
    """Image processor: RGB conversion, bicubic resize, center crop, tensor conversion, normalisation.

    ``min_scale`` and ``max_scale`` are accepted but not used by this
    implementation.
    """

    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
        super().__init__(mean=mean, std=std)

        resize = transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC)
        pipeline = [
            convert_to_rgb,
            resize,
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(pipeline)
        self.image_size = image_size

    def preprocess(self, images, return_tensors):
        """Transform one PIL image or a list of images into a ``BatchFeature``.

        The result stores the transformed images (as numpy arrays) under
        ``"pixel_values"``; ``return_tensors`` is forwarded as ``tensor_type``.
        """
        if isinstance(images, Image.Image):
            batch = [images]
        else:
            assert isinstance(images, list)
            batch = images

        pixel_values = []
        for image in batch:
            pixel_values.append(self.transform(image).numpy())
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

    def __call__(self, item):
        """Apply the transform pipeline to a single item."""
        transform = self.transform
        return transform(item)

    @property
    def crop_size(self):
        """Square crop size as a ``height``/``width`` dict."""
        side = self.image_size
        return {"height": side, "width": side}

    @property
    def size(self):
        """Resize target as a ``shortest_edge`` dict."""
        side = self.image_size
        return {"shortest_edge": side}
"""
# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
"""
from math import pi
import torch
from torch import nn
from einops import rearrange, repeat
import logging
from llava.utils import rank0_print
def broadcat(tensors, dim=-1):
    """Concatenate tensors along ``dim`` after expanding the other axes to a common size.

    All tensors must have the same number of dimensions. On every axis other
    than ``dim`` at most two distinct sizes may occur; each tensor is expanded
    to the largest one before ``torch.cat`` is applied.
    """
    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, "tensors must all have the same number of dimensions"
    rank = next(iter(ranks))
    if dim < 0:
        dim = dim + rank

    # sizes_per_axis[a] lists the size of axis `a` for every tensor.
    sizes_per_axis = list(zip(*(list(t.shape) for t in tensors)))
    broadcastable = [len(set(sizes)) <= 2 for axis, sizes in enumerate(sizes_per_axis) if axis != dim]
    assert all(broadcastable), "invalid dimensions for broadcastable concatentation"

    own_sizes = sizes_per_axis[dim]
    expanded = []
    for position, tensor in enumerate(tensors):
        # Keep the tensor's own extent along `dim`, use the maximum elsewhere.
        target_shape = [own_sizes[position] if axis == dim else max(sizes) for axis, sizes in enumerate(sizes_per_axis)]
        expanded.append(tensor.expand(*target_shape))
    return torch.cat(expanded, dim=dim)
def rotate_half(x):
    """Map each consecutive pair ``(a, b)`` of the last axis to ``(-b, a)``."""
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")
class VisionRotaryEmbeddingFast(nn.Module):
    """Rotary position embedding with cos/sin tables precomputed over a 2D grid.

    The tables are built once in ``__init__`` for an ``ft_seq_len x ft_seq_len``
    grid of positions, flattened over the grid, and registered as the buffers
    ``freqs_cos`` and ``freqs_sin``.

    Args:
        dim: size used to derive the base frequencies (``dim // 2`` of them for
            the ``"lang"`` and ``"pixel"`` modes).
        pt_seq_len: side length the positions are scaled to.
        ft_seq_len: side length of the actual grid; defaults to ``pt_seq_len``.
        custom_freqs: optional precomputed frequencies that bypass ``freqs_for``.
        freqs_for: ``"lang"``, ``"pixel"`` or ``"constant"``.
        theta, max_freq, num_freqs: parameters of the respective modes.
        patch_dropout: stored on the module; not used by ``forward``.

    Raises:
        ValueError: if ``freqs_for`` is not one of the known modes.
    """

    def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0):
        super().__init__()
        # Compare against None: the truth value of a multi-element tensor is
        # ambiguous and raises, so `if custom_freqs:` rejected real inputs.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Positions of the ft grid rescaled into the pt range.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum("..., f -> ... f", t, freqs)
        # Duplicate every frequency so it covers one rotated pair.
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        # Combine the two spatial axes, concatenated on the last dim.
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.patch_dropout = patch_dropout

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")

    def forward(self, t, patch_indices_keep=None):
        """Apply the rotary embedding to ``t``.

        When ``patch_indices_keep`` is given, the cos/sin tables are gathered
        per sample with those indices before being applied; otherwise the full
        tables are broadcast against ``t``.
        """
        if patch_indices_keep is not None:
            batch = t.size()[0]
            batch_indices = torch.arange(batch)
            batch_indices = batch_indices[..., None]

            freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
            freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])

            freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
            freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
            freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
            freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")

            return t * freqs_cos + rotate_half(t) * freqs_sin

        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
class LayerNorm(nn.LayerNorm):
    """LayerNorm variant whose output is cast back to the dtype of its input."""

    def forward(self, x: torch.Tensor):
        normalized = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normalized.to(x.dtype)
class PatchDropout(nn.Module):
    """
    Randomly keeps a subset of tokens during training.

    https://arxiv.org/abs/2212.00794
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.0
        self.prob = prob
        # When set, the first (CLS) token is never dropped.
        self.exclude_first_token = exclude_first_token
        logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")

    def forward(self, x):
        """Drop tokens along dim 1; identity in eval mode or when ``prob`` is 0.

        Returns ``(x, patch_indices_keep)`` instead of ``x`` when training with
        the environment variable ``RoPE`` set to ``"1"``.
        """
        if self.prob == 0.0 or not self.training:
            return x

        if self.exclude_first_token:
            cls_tokens = x[:, :1]
            x = x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch, num_tokens = x.size()[0], x.size()[1]
        row_selector = torch.arange(batch)[..., None]

        num_patches_keep = max(1, int(num_tokens * (1 - self.prob)))

        # Keep the tokens with the highest random scores.
        scores = torch.randn(batch, num_tokens)
        patch_indices_keep = scores.topk(num_patches_keep, dim=-1).indices
        x = x[row_selector, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        if self.training and os.getenv("RoPE") == "1":
            return x, patch_indices_keep
        return x
# --------------------------------------------------------
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import os
import torch.nn as nn
import torch.nn.functional as F
try:
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
except:
from timm.layers import drop_path, to_2tuple, trunc_normal_
if os.getenv("ENV_TYPE") == "deepspeed":
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
except:
from torch.utils.checkpoint import checkpoint
else:
from torch.utils.checkpoint import checkpoint
try:
import xformers.ops as xops
except ImportError:
xops = None
# print("Please 'pip install xformers'")
class DropPath(nn.Module):
    """Stochastic depth: drops the residual branch per sample via ``drop_path``."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        probability, is_training = self.drop_prob, self.training
        return drop_path(x, probability, is_training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
class Mlp(nn.Module):
    """Two-layer feed-forward block with an optional norm before the second layer.

    Computes ``drop(fc2(ffn_ln(act(fc1(x)))))``; ``ffn_ln`` is an identity
    unless ``subln`` is set. Hidden and output sizes fall back to
    ``in_features`` when not given.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        drop=0.0,
        subln=False,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        if subln:
            self.ffn_ln = norm_layer(hidden_features)
        else:
            self.ffn_ln = nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Dropout is applied only after fc2; the dropout between the
        # activation and fc2 of the original BERT-style MLP is left out.
        hidden = self.act(self.fc1(x))
        projected = self.fc2(self.ffn_ln(hidden))
        return self.drop(projected)
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``drop(w3(ffn_ln(act(w1(x)) * w2(x))))``.

    ``ffn_ln`` is an identity unless ``subln`` is set. Hidden and output sizes
    fall back to ``in_features`` when not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, norm_layer=nn.LayerNorm, subln=False):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        if subln:
            self.ffn_ln = norm_layer(hidden_features)
        else:
            self.ffn_ln = nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        gate = self.w1(x)
        value = self.w2(x)
        gated = self.act(gate) * value
        projected = self.w3(self.ffn_ln(gated))
        return self.drop(projected)
class Attention(nn.Module):
    """Multi-head self-attention with optional RoPE, relative position bias and xformers kernel.

    Args:
        dim: input/output feature size.
        num_heads: number of attention heads.
        qkv_bias: create learnable biases for q and v (k never has a bias).
        qk_scale: overrides the default ``head_dim ** -0.5`` scaling.
        attn_drop: attention dropout probability.
        proj_drop: dropout probability after the output projection.
        window_size: ``(Wh, Ww)``; when given, a learned relative position bias
            table (including CLS entries) is created.
        attn_head_dim: overrides the per-head size ``dim // num_heads``.
        xattn: use ``xops.memory_efficient_attention`` when xformers is available.
        rope: optional callable applied to the q/k tokens after the first one.
        subln: use separate q/k/v projections and a norm before the output
            projection.
        norm_layer: factory for that inner norm.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.subln = subln
        if self.subln:
            self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
        else:
            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)

        # Biases live outside the Linear layers so that k can stay bias-free.
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            # The 3 extra entries cover cls->token, token->cls and cls->cls.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop

        self.rope = rope

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        """Attend over ``x`` of shape ``(B, N, C)`` and return a tensor of the same layout.

        Args:
            x: input tokens.
            rel_pos_bias: optional bias added to the attention logits
                (standard path only).
            attn_mask: optional ``(B, N)`` mask; positions that are false are
                excluded as keys (standard path only).

        NOTE: the xformers path uses neither ``rel_pos_bias``, ``attn_mask``
        nor the module's own relative position bias table.
        """
        B, N, C = x.shape
        if self.subln:
            q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
            k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
            v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

            q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        else:
            qkv_bias = None
            if self.q_bias is not None:
                # k has no bias: pad its slot with zeros.
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))

            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, num_heads, N, C
            q, k, v = qkv[0], qkv[1], qkv[2]

        if self.rope:
            # Rotate every token except the first (CLS) one.
            q_t = q[:, :, 1:, :]
            ro_q_t = self.rope(q_t)
            q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)

            k_t = k[:, :, 1:, :]
            ro_k_t = self.rope(k_t)
            k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)

        if self.xattn and xops is not None:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)

            x = xops.memory_efficient_attention(
                q,
                k,
                v,
                # The kernel has no notion of train/eval mode, so dropout must
                # be switched off explicitly outside training (nn.Dropout on
                # the standard path does this by itself).
                p=self.xattn_drop if self.training else 0.0,
                scale=self.scale,
            )
            x = x.reshape(B, N, -1)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

            if self.relative_position_bias_table is not None:
                relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)

            if rel_pos_bias is not None:
                attn = attn + rel_pos_bias.type_as(attn)

            if attn_mask is not None:
                attn_mask = attn_mask.bool()
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        # Output stage shared by both attention paths.
        x = self.inner_attn_ln(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Transformer block: attention and feed-forward branches, each added residually.

    ``postnorm`` selects whether the norms are applied to the branch output
    (post-norm) or to the branch input (pre-norm). A positive ``init_values``
    enables learnable per-channel scales (``gamma_1``/``gamma_2``) on the
    branches. ``naiveswiglu`` swaps the ``Mlp`` for a ``SwiGLU``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
        xattn=False,
        rope=None,
        postnorm=False,
        subln=False,
        naiveswiglu=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
            xattn=xattn,
            rope=rope,
            subln=subln,
            norm_layer=norm_layer,
        )
        # Stochastic depth on the residual branches (identity when disabled).
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        if naiveswiglu:
            self.mlp = SwiGLU(in_features=dim, hidden_features=mlp_hidden_dim, subln=subln, norm_layer=norm_layer)
        else:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop)

        use_layer_scale = init_values is not None and init_values > 0
        if use_layer_scale:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

        self.postnorm = postnorm

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        # Attention branch.
        if self.postnorm:
            branch = self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
        else:
            branch = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
        if self.gamma_1 is not None:
            branch = self.gamma_1 * branch
        x = x + self.drop_path(branch)

        # Feed-forward branch. The layer-scale switch is keyed on gamma_1,
        # which is always set together with gamma_2.
        if self.postnorm:
            branch = self.norm2(self.mlp(x))
        else:
            branch = self.mlp(self.norm2(x))
        if self.gamma_1 is not None:
            branch = self.gamma_2 * branch
        x = x + self.drop_path(branch)
        return x
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project each to ``embed_dim``."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        grid_h = self.img_size[0] // self.patch_size[0]
        grid_w = self.img_size[1] // self.patch_size[1]
        self.patch_shape = (grid_h, grid_w)
        self.num_patches = grid_w * grid_h

        # A strided convolution performs the patch split and projection at once.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x, **kwargs):
        """Return patch embeddings of shape ``(B, num_patches, embed_dim)`` for a ``(B, C, H, W)`` input."""
        _, _, height, width = x.shape
        # FIXME look at relaxing size constraints
        expected_h, expected_w = self.img_size[0], self.img_size[1]
        assert height == expected_h and width == expected_w, f"Input image size ({height}*{width}) doesn't match model ({expected_h}*{expected_w})."
        embedded = self.proj(x)
        return embedded.flatten(2).transpose(1, 2)
class RelativePositionBias(nn.Module):
    """Learned relative position bias over a ``(Wh, Ww)`` window plus a CLS token.

    ``forward()`` returns a ``(num_heads, Wh*Ww + 1, Wh*Ww + 1)`` tensor
    gathered from the bias table.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        win_h, win_w = window_size[0], window_size[1]
        self.window_size = window_size
        # Three extra entries: cls->token, token->cls and cls->cls.
        self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
        self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads))

        # Pair-wise offsets between all window positions: (2, Wh*Ww, Wh*Ww).
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))
        flat = torch.flatten(grid, 1)
        offsets = flat[:, :, None] - flat[:, None, :]

        # Shift both offsets to start at 0 and fold them into one table index.
        row_part = (offsets[0] + (win_h - 1)) * (2 * win_w - 1)
        col_part = offsets[1] + (win_w - 1)

        num_tokens = win_h * win_w + 1
        index = torch.zeros(size=(num_tokens, num_tokens), dtype=offsets.dtype)
        index[1:, 1:] = row_part + col_part
        index[0, 0:] = self.num_relative_distance - 3
        index[0:, 0] = self.num_relative_distance - 2
        index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", index)

    def forward(self):
        num_tokens = self.window_size[0] * self.window_size[1] + 1
        flat_index = self.relative_position_index.view(-1)
        bias = self.relative_position_bias_table[flat_index].view(num_tokens, num_tokens, -1)
        return bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww + 1, Wh*Ww + 1
class EVAVisionTransformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage

    The constructor builds the patch embedding, a CLS token, optional absolute
    position embeddings, optional shared relative position bias, optional
    rotary embedding, ``depth`` transformer blocks and a linear head.

    NOTE: ``forward_features`` skips the last transformer block, so the
    returned features come from the block before it.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_values=None,
        patch_dropout=0.0,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
        use_shared_rel_pos_bias=False,
        rope=False,
        use_mean_pooling=True,
        init_scale=0.001,
        grad_checkpointing=False,
        xattn=False,
        postnorm=False,
        pt_hw_seq_len=16,
        intp_freq=False,
        naiveswiglu=False,
        subln=False,
    ):
        super().__init__()
        # `CLIPVisionCfg.drop_path_rate` defaults to None, which torch.linspace
        # below cannot handle; treat it as "no stochastic depth".
        if drop_path_rate is None:
            drop_path_rate = 0.0

        self.image_size = img_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        if rope:
            half_head_dim = embed_dim // num_heads // 2
            hw_seq_len = img_size // patch_size
            self.rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=pt_hw_seq_len,
                ft_seq_len=hw_seq_len if intp_freq else None,
            )
        else:
            self.rope = None

        self.naiveswiglu = naiveswiglu

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    init_values=init_values,
                    window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
                    xattn=xattn,
                    rope=self.rope,
                    postnorm=postnorm,
                    subln=subln,
                    naiveswiglu=naiveswiglu,
                )
                for i in range(depth)
            ]
        )
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)

        trunc_normal_(self.cls_token, std=0.02)

        self.apply(self._init_weights)
        self.fix_init_weight()

        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=0.02)
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()

        self.grad_checkpointing = grad_checkpointing

    def fix_init_weight(self):
        """Divide each block's output projections by ``sqrt(2 * layer_number)`` (1-based)."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            if self.naiveswiglu:
                rescale(layer.mlp.w3.weight.data, layer_id + 1)
            else:
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's final feed-forward projection."""
        mlp = self.blocks[0].mlp
        # SwiGLU names its output projection `w3`; the plain Mlp uses `fc2`
        # (same distinction as in `fix_init_weight`).
        out_proj = mlp.w3 if self.naiveswiglu else mlp.fc2
        return out_proj.weight.dtype

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze all parameters; partial unlocking is not supported."""
        assert unlocked_groups == 0, "partial locking not currently supported for this model"
        for param in self.parameters():
            param.requires_grad = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=""):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, return_all_features=False):
        """Embed ``x`` and run it through all blocks except the last one.

        Args:
            x: image batch accepted by ``self.patch_embed``.
            return_all_features: return the full token sequence (CLS first)
                instead of a pooled vector.

        Returns:
            The token sequence when ``return_all_features`` is true; otherwise
            ``fc_norm`` of the token mean when mean pooling is configured, or
            the normed CLS token.
        """
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        if os.getenv("RoPE") == "1":
            # NOTE(review): this env-gated path applies `self.rope` to the full
            # token sequence and requires `self.rope` to be set - confirm it is
            # still intended before relying on it.
            if self.training and not isinstance(self.patch_dropout, nn.Identity):
                x, patch_indices_keep = self.patch_dropout(x)
                x = self.rope.forward(x, patch_indices_keep=patch_indices_keep)
            else:
                x = self.rope.forward(x, patch_indices_keep=None)
                x = self.patch_dropout(x)
        else:
            x = self.patch_dropout(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for i, blk in enumerate(self.blocks):
            # The final block is skipped (see class docstring).
            if i == len(self.blocks) - 1:
                continue
            if self.grad_checkpointing:
                # Pass the bias itself: wrapping it in a tuple handed the block
                # a non-None tuple, which the standard attention path then
                # tried to treat as a tensor.
                x = checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias=rel_pos_bias)

        if not return_all_features:
            x = self.norm(x)
            if self.fc_norm is not None:
                return self.fc_norm(x.mean(1))
            else:
                return x[:, 0]
        return x

    def forward(self, x, return_all_features=False):
        """Return token features when ``return_all_features`` is set, else the head output."""
        if return_all_features:
            return self.forward_features(x, return_all_features)
        x = self.forward_features(x)
        x = self.head(x)
        return x
def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = None):
    """Load a checkpoint file and return a cleaned-up state dict.

    Args:
        checkpoint_path: path of the checkpoint file.
        map_location: forwarded to ``torch.load`` (the OpenAI/TorchScript path
            always loads on CPU).
        model_key: ``|``-separated candidate keys; the first one present in a
            dict checkpoint selects the nested state dict.
        is_openai: load a TorchScript archive via ``torch.jit.load`` and drop
            its non-weight entries.
        skip_list: keys to remove from the result (``None`` means none).

    Returns:
        The state dict with a leading ``"module."`` removed from its keys,
        the ``skip_list`` keys removed and, when the environment variable
        ``RoPE`` is ``"1"``, all ``freqs_cos``/``freqs_sin`` entries removed.
    """
    skip_list = [] if skip_list is None else skip_list

    if is_openai:
        model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
        state_dict = model.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for mk in model_key.split("|"):
                if mk in checkpoint:
                    state_dict = checkpoint[mk]
                    break

        # Drop the wrapper prefix. Only keys that really carry "module." are
        # renamed; blindly cutting 7 characters off every key corrupted keys
        # without the prefix. An empty state dict is passed through unchanged.
        prefix = "module."
        first_key = next(iter(state_dict), None)
        if first_key is not None and first_key.startswith("module"):
            state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

    for k in skip_list:
        if k in state_dict:
            logging.info(f"Removing key {k} from pretrained checkpoint")
            del state_dict[k]

    if os.getenv("RoPE") == "1":
        # The rotary tables are rebuilt by the model, so stored copies are dropped.
        for k in list(state_dict.keys()):
            if "freqs_cos" in k or "freqs_sin" in k:
                del state_dict[k]
    return state_dict
def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
    """Return the state dict produced by ``load_state_dict`` for the given checkpoint.

    NOTE: the filtering/renaming of keys with a ``"visual."`` prefix is
    disabled in this implementation, so the state dict is returned as loaded.
    """
    return load_state_dict(
        checkpoint_path,
        map_location=map_location,
        is_openai=is_openai,
        skip_list=skip_list,
    )
from dataclasses import dataclass
from typing import Optional, Tuple, Union
try:
from apex.normalization import FusedLayerNorm
except:
FusedLayerNorm = LayerNorm
# print("Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source .")
@dataclass
class CLIPVisionCfg:
    """Configuration of the vision tower, consumed by ``_build_vision_tower``.

    ``_build_vision_tower`` derives the number of attention heads as
    ``width // head_width`` and forwards most of the remaining fields to
    ``EVAVisionTransformer``. The ``timm_*`` fields are not read anywhere in
    this file.
    """

    layers: Union[Tuple[int, int, int, int], int] = 12  # passed as `depth`
    width: int = 768  # passed as `embed_dim`
    head_width: int = 64  # per-head size; heads = width // head_width
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.0  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
    drop_path_rate: Optional[float] = None  # drop path rate
    timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = "avg"  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = "linear"  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    eva_model_name: str = None  # a valid eva model name overrides layers, width, patch_size
    qkv_bias: bool = True
    fusedLN: bool = False  # handed to create_norm_layer_factory, which currently ignores it
    xattn: bool = False  # use the xformers attention kernel when available
    postnorm: bool = False
    rope: bool = False  # enable the rotary position embedding
    pt_hw_seq_len: int = 16  # 224/14
    intp_freq: bool = False  # when set, the rope grid uses image_size // patch_size
    naiveswiglu: bool = False  # use SwiGLU instead of Mlp in the blocks
    subln: bool = False
def create_norm_layer_factory(use_fused_ln, eps=1e-6):
    """Return a factory that builds ``nn.LayerNorm(num_features, eps=eps)``.

    ``use_fused_ln`` is accepted but has no effect: the standard LayerNorm is
    always used.
    """

    def build_norm(num_features):
        return nn.LayerNorm(num_features, eps=eps)

    return build_norm
def _build_vision_tower(vision_tower_path: str, embed_dim: int, vision_cfg: CLIPVisionCfg, **kwargs):
    """Build an EVA vision transformer and load pretrained weights into it.

    Args:
        vision_tower_path: Checkpoint location handed to ``load_clip_visual_state_dict``.
        embed_dim: Passed to the transformer as ``num_classes``.
        vision_cfg: A ``CLIPVisionCfg`` instance, or a dict of its fields.
        **kwargs: Ignored. Accepting them lets a caller expand a whole model config
            (which may hold additional keys) into this call.

    Returns:
        The ``EVAVisionTransformer`` after a non-strict ``load_state_dict``.

    Raises:
        ValueError: If ``vision_cfg.eva_model_name`` is not set. Only EVA models are
            built here; without this guard no model would exist to return.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    if not vision_cfg.eva_model_name:
        raise ValueError("_build_vision_tower only supports EVA models: vision_cfg.eva_model_name must be set")

    vision_heads = vision_cfg.width // vision_cfg.head_width
    # The fusedLN flag is forwarded, but the factory currently always yields nn.LayerNorm.
    norm_layer_factory = create_norm_layer_factory(vision_cfg.fusedLN, eps=1e-6)

    visual = EVAVisionTransformer(
        img_size=vision_cfg.image_size,
        patch_size=vision_cfg.patch_size,
        num_classes=embed_dim,
        use_mean_pooling=vision_cfg.global_average_pool,  # False
        init_values=vision_cfg.ls_init_value,
        patch_dropout=vision_cfg.patch_dropout,
        embed_dim=vision_cfg.width,
        depth=vision_cfg.layers,
        num_heads=vision_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        qkv_bias=vision_cfg.qkv_bias,
        drop_path_rate=vision_cfg.drop_path_rate,
        norm_layer=norm_layer_factory,
        xattn=vision_cfg.xattn,
        rope=vision_cfg.rope,
        postnorm=vision_cfg.postnorm,
        pt_hw_seq_len=vision_cfg.pt_hw_seq_len,  # 224/14
        intp_freq=vision_cfg.intp_freq,
        naiveswiglu=vision_cfg.naiveswiglu,
        subln=vision_cfg.subln,
    )

    # strict=False: mismatched keys are reported below instead of raising.
    state_dict = load_clip_visual_state_dict(vision_tower_path)
    incompatible_keys = visual.load_state_dict(state_dict, strict=False)
    rank0_print("EVA-CLIP incompatible_keys:", incompatible_keys)

    return visual
class EVAEncoderWrapper(nn.Module):
    """Module wrapper around the EVA vision tower built from a model config.

    The ``config`` mapping is kept by reference and is modified in place: the
    checkpoint location is stored under ``"vision_tower_path"`` before the whole
    mapping is expanded into ``_build_vision_tower``.
    """

    def __init__(self, vision_tower_pretrained, config):
        super().__init__()
        config["vision_tower_path"] = vision_tower_pretrained
        self.config = config
        self.model = _build_vision_tower(**config)

    def forward(self, image, **kwargs):
        """Run the tower with ``return_all_features=True`` and drop the first token.

        Extra keyword arguments are accepted and ignored.
        """
        features = self.model(image, return_all_features=True)
        return features[:, 1:, :]  # remove the CLS token

    @property
    def dtype(self):
        """dtype of the last parameter yielded by ``self.parameters()``."""
        params = list(self.parameters())
        return params[-1].dtype

    @property
    def device(self):
        """Device of the last parameter yielded by ``self.parameters()``."""
        params = list(self.parameters())
        return params[-1].device
import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union, Dict, Any
import torch
# Locations scanned for model config JSON files; starts with the "model_configs"
# directory that sits next to this module.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
# Registry mapping model_name -> config dict of model architecture configs.
_MODEL_CONFIGS = {}
def _natural_key(string_):
    """Build a sort key that orders embedded digit runs numerically.

    The lower-cased input is split on runs of digits (the runs are kept); pieces made
    up only of digits become ints, everything else stays a string.
    """
    key = []
    for piece in re.split(r"(\d+)", string_.lower()):
        if piece.isdigit():
            key.append(int(piece))
        else:
            key.append(piece)
    return key
def _rescan_model_configs():
    """Reload the model config registry from every registered path.

    Each entry of ``_MODEL_CONFIG_PATHS`` may be a single ``.json`` file or a directory
    whose ``*.json`` files are collected. A file is registered under its stem only when
    it defines ``embed_dim``, ``vision_cfg`` and ``text_cfg``. Existing entries are kept
    (or overwritten by name), and the registry is rebuilt in natural-sorted name order.
    """
    global _MODEL_CONFIGS

    suffixes = (".json",)
    found = []
    for location in _MODEL_CONFIG_PATHS:
        if location.is_file() and location.suffix in suffixes:
            found.append(location)
        elif location.is_dir():
            for suffix in suffixes:
                found.extend(location.glob(f"*{suffix}"))

    required_keys = ("embed_dim", "vision_cfg", "text_cfg")
    for cfg_file in found:
        with open(cfg_file, "r", encoding="utf8") as handle:
            candidate = json.load(handle)
        if all(key in candidate for key in required_keys):
            _MODEL_CONFIGS[cfg_file.stem] = candidate

    ordered_items = sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0]))
    _MODEL_CONFIGS = dict(ordered_items)


_rescan_model_configs()  # initial populate of model config registry
def list_models():
    """Return the names of all registered model architectures, in registry order."""
    model_names = list(_MODEL_CONFIGS)
    return model_names
def add_model_config(path):
    """Register an extra config file or directory, then rescan the registry.

    ``path`` may be a ``Path`` or anything ``Path(...)`` accepts.
    """
    config_path = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(config_path)
    _rescan_model_configs()
def get_model_config(model_name):
    """Return a deep copy of the registered config for ``model_name``, or None if unknown.

    A copy is returned so that callers can modify it without touching the registry.
    """
    if model_name not in _MODEL_CONFIGS:
        return None
    return deepcopy(_MODEL_CONFIGS[model_name])
{
"embed_dim": 1536,
"vision_cfg": {
"image_size": 224,
"layers": 48,
"width": 5120,
"head_width": 128,
"mlp_ratio": 5,
"patch_size": 14,
"eva_model_name": "eva-clip-18b-14-x",
"drop_path_rate": 0,
"qkv_bias": false,
"xattn": true,
"postnorm": true,
"fusedLN": false,
"use_rms_norm": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": false
}
}
\ No newline at end of file
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 448,
"layers": 32,
"width": 4096,
"head_width": 128,
"mlp_ratio": 5,
"patch_size": 14,
"eva_model_name": "eva-clip-8b-14-plus-x",
"drop_path_rate": 0,
"qkv_bias": false,
"xattn": true,
"postnorm": false,
"fusedLN": false,
"use_rms_norm": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": false
}
}
\ No newline at end of file
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 4096,
"head_width": 128,
"mlp_ratio": 5,
"patch_size": 14,
"eva_model_name": "eva-clip-8b-14-x",
"drop_path_rate": 0,
"qkv_bias": false,
"xattn": true,
"postnorm": false,
"fusedLN": false,
"use_rms_norm": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": false
}
}
\ No newline at end of file
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 16,
"eva_model_name": "eva-clip-b-16",
"ls_init_value": 0.1,
"drop_path_rate": 0.0
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
\ No newline at end of file
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14,
"eva_model_name": "eva-clip-g-14-x",
"drop_path_rate": 0,
"xattn": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24,
"xattn": false,
"fusedLN": true
}
}
\ No newline at end of file
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14,
"eva_model_name": "eva-clip-g-14-x",
"drop_path_rate": 0.4,
"xattn": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}
\ No newline at end of file
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"head_width": 64,
"patch_size": 16,
"mlp_ratio": 2.6667,
"eva_model_name": "eva-clip-b-16-X",
"drop_path_rate": 0.0,
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"xattn": true,
"fusedLN": true
}
}
\ No newline at end of file
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"drop_path_rate": 0,
"head_width": 64,
"mlp_ratio": 2.6667,
"patch_size": 14,
"eva_model_name": "eva-clip-l-14-336",
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment