Unverified Commit b737368d authored by K's avatar K Committed by GitHub
Browse files

feat: PuLID support (#274)



* add pulid

* Add the feature that allows the mixed use of pulid and non-pulid after loading pulid to generate the pipeline.

* Added the feature to load LoRA at any time.

* Organized the directory structure.

* Organized the code.

* Removed unused related code from eva-clip.

* style: apply Ruff formatting

* Refactored code and verified pulid works.

* add pulid tests

* auto detect precision in test

* Updated requirements.txt

* update requirements

* style: reformat the example

* style: reformat the example

* style: rename cb to call_back

* style: format the codes

* style: format the codes

* reformated the codes

* fix the repo forward

* clean some dead codes

* wrap up for pulid

---------
Co-authored-by: default avatarkkkxue <kkkxue@tencent.com>
Co-authored-by: default avatarmuyangli <lmxyy1999@foxmail.com>
parent b4d3f50b
......@@ -202,4 +202,9 @@ cython_debug/
*.log
*.pt
*.nsys-rep
*.ncu-rep
\ No newline at end of file
*.ncu-rep
# Pulid
*.safetensors
*.onnx
.gitattributes
\ No newline at end of file
from types import MethodType
import torch
from diffusers.utils import load_image
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = PuLIDFluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)
id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
image = pipeline(
"A woman holding a sign that says 'SVDQuant is fast!",
id_image=id_image,
id_weight=1,
num_inference_steps=12,
guidance_scale=3.5,
).images[0]
image.save("flux.1-dev-pulid.png")
......@@ -27,6 +27,30 @@ public:
checkModel();
return net->dtype == Tensor::BF16;
}
pybind11::function residual_callback;
void set_residual_callback(pybind11::function callback) {
pybind11::gil_scoped_acquire gil;
if (!callback || callback.is_none()) {
residual_callback = pybind11::function();
if (net){
net->set_residual_callback(nullptr);
}
return;
}
residual_callback = std::move(callback);
if (net) {
pybind11::object cb = residual_callback;
net->set_residual_callback([cb](const Tensor &x) -> Tensor {
pybind11::gil_scoped_acquire gil;
torch::Tensor torch_x = to_torch(x);
pybind11::object result = cb(torch_x);
torch::Tensor torch_y = result.cast<torch::Tensor>();
Tensor y = from_torch(torch_y);
return y;
});
} else {
}
}
torch::Tensor forward(
torch::Tensor hidden_states,
......
......@@ -5,7 +5,7 @@
#include "ops.h"
#include "utils.h"
#include <torch/extension.h>
#include "interop/torch.h"
#include <pybind11/pybind11.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......@@ -17,6 +17,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("bf16"),
py::arg("deviceId")
)
.def("set_residual_callback", [](QuantizedFluxModel &self, pybind11::object call_back) {
if (call_back.is_none()) {
self.set_residual_callback(pybind11::function());
} else {
self.set_residual_callback(call_back);
}
})
.def("reset", &QuantizedFluxModel::reset)
.def("load", &QuantizedFluxModel::load,
py::arg("path"),
......@@ -91,6 +98,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("stopDebug", &QuantizedGEMM::stopDebug)
.def("getDebugResults", &QuantizedGEMM::getDebugResults)
;
py::class_<Tensor>(m, "Tensor");
py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
.def(py::init<>())
.def("init", &QuantizedGEMM88::init)
......
# Adapted from https://github.com/ToTheBeginning/PuLID
import math
import torch
from torch import nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttentionCA(nn.Module):
def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, seq_len, _ = latents.shape
q = self.to_q(latents)
k, v = self.to_kv(x).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
return self.to_out(out)
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, seq_len, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
return self.to_out(out)
class IDFormer(nn.Module):
"""
- perceiver resampler like arch (compared with previous MLP-like arch)
- we concat id embedding (generated by arcface) and query tokens as latents
- latents will attend each other and interact with vit features through cross-attention
- vit features are multi-scaled and inserted into IDFormer in order, currently, each scale corresponds to two
IDFormer layers
"""
def __init__(
self,
dim=1024,
depth=10,
dim_head=64,
heads=16,
num_id_token=5,
num_queries=32,
output_dim=2048,
ff_mult=4,
):
super().__init__()
self.num_id_token = num_id_token
self.dim = dim
self.num_queries = num_queries
assert depth % 5 == 0
self.depth = depth // 5
scale = dim**-0.5
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
for i in range(5):
setattr(
self,
f"mapping_{i}",
nn.Sequential(
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, dim),
),
)
self.id_embedding_mapping = nn.Sequential(
nn.Linear(1280, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, dim * num_id_token),
)
def forward(self, x, y):
latents = self.latents.repeat(x.size(0), 1, 1)
num_duotu = x.shape[1] if x.ndim == 3 else 1
x = self.id_embedding_mapping(x)
x = x.reshape(-1, self.num_id_token * num_duotu, self.dim)
latents = torch.cat((latents, x), dim=1)
for i in range(5):
vit_feature = getattr(self, f"mapping_{i}")(y[i])
ctx_feature = torch.cat((x, vit_feature), dim=1)
for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
latents = attn(ctx_feature, latents) + latents
latents = ff(latents) + latents
latents = latents[:, : self.num_queries]
latents = latents @ self.proj_out
return latents
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model_and_transforms
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
This diff is collapsed.
import json
import logging
import os
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, convert_to_custom_text_state_dict, CustomCLIP, get_cast_dtype
from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model
from .transform import image_transform
from .utils import resize_clip_pos_embed, resize_eva_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = (".json",)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f"*{ext}"))
for cf in config_files:
with open(cf, "r", encoding="utf8") as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
_rescan_model_configs() # initial populate of model config registry
def list_models():
"""enumerate available model architectures based on config files"""
return list(_MODEL_CONFIGS.keys())
def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
# loading openai CLIP weights when is_openai=True for training
def load_state_dict(
checkpoint_path: str,
map_location: str = "cpu",
model_key: str = "model|module|state_dict",
is_openai: bool = False,
skip_list: list = [],
):
if is_openai:
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
state_dict = model.state_dict()
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
else:
checkpoint = torch.load(checkpoint_path, map_location=map_location)
for mk in model_key.split("|"):
if isinstance(checkpoint, dict) and mk in checkpoint:
state_dict = checkpoint[mk]
break
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith("module"):
state_dict = {k[7:]: v for k, v in state_dict.items()}
for k in skip_list:
if k in list(state_dict.keys()):
logging.info(f"Removing key {k} from pretrained checkpoint")
del state_dict[k]
if os.getenv("RoPE") == "1":
for k in list(state_dict.keys()):
if "freqs_cos" in k or "freqs_sin" in k:
del state_dict[k]
return state_dict
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
# detect old format and make compatible with new format
if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"):
state_dict = convert_to_custom_text_state_dict(state_dict)
if "text.logit_scale" in state_dict and hasattr(model, "logit_scale"):
state_dict["logit_scale"] = state_dict["text.logit_scale"]
del state_dict["text.logit_scale"]
# resize_clip_pos_embed for CLIP and open CLIP
if "visual.positional_embedding" in state_dict:
resize_clip_pos_embed(state_dict, model)
# specified to eva_vit_model
elif "visual.pos_embed" in state_dict:
resize_evaclip_pos_embed(state_dict, model)
# resize_clip_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
# logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
return incompatible_keys
def load_clip_visual_state_dict(
checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []
):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if not k.startswith("visual."):
del state_dict[k]
for k in list(state_dict.keys()):
if k.startswith("visual."):
new_k = k[7:]
state_dict[new_k] = state_dict[k]
del state_dict[k]
return state_dict
def load_clip_text_state_dict(
checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []
):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if k.startswith("visual."):
del state_dict[k]
return state_dict
def get_pretrained_tag(pretrained_model):
pretrained_model = pretrained_model.lower()
if "laion" in pretrained_model or "open_clip" in pretrained_model:
return "open_clip"
elif "openai" in pretrained_model:
return "clip"
elif "eva" in pretrained_model and "clip" in pretrained_model:
return "eva_clip"
else:
return "other"
def load_pretrained_checkpoint(
model,
visual_checkpoint_path,
text_checkpoint_path,
strict=True,
visual_model=None,
text_model=None,
model_key="model|module|state_dict",
skip_list=[],
):
visual_tag = get_pretrained_tag(visual_model)
text_tag = get_pretrained_tag(text_model)
logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
visual_incompatible_keys, text_incompatible_keys = None, None
if visual_checkpoint_path:
if visual_tag == "eva_clip" or visual_tag == "open_clip":
visual_state_dict = load_clip_visual_state_dict(
visual_checkpoint_path, is_openai=False, skip_list=skip_list
)
elif visual_tag == "clip":
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
visual_state_dict = load_state_dict(
visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list
)
# resize_clip_pos_embed for CLIP and open CLIP
if "positional_embedding" in visual_state_dict:
resize_visual_pos_embed(visual_state_dict, model)
# specified to EVA model
elif "pos_embed" in visual_state_dict:
resize_eva_pos_embed(visual_state_dict, model)
visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
if text_checkpoint_path:
if text_tag == "eva_clip" or text_tag == "open_clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
elif text_tag == "clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
text_state_dict = load_state_dict(
visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list
)
text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
return visual_incompatible_keys, text_incompatible_keys
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = "fp32",
device: Union[str, torch.device] = "cpu",
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = "",
pretrained_text: str = "",
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
):
model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names
if isinstance(device, str):
device = torch.device(device)
if pretrained and pretrained.lower() == "openai":
pass
else:
model_cfg = get_model_config(model_name)
if model_cfg is not None:
logging.info(f"Loaded {model_name} model config.")
else:
logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
raise RuntimeError(f"Model config for {model_name} not found.")
if "rope" in model_cfg.get("vision_cfg", {}):
if model_cfg["vision_cfg"]["rope"]:
os.environ["RoPE"] = "1"
else:
os.environ["RoPE"] = "0"
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
cast_dtype = get_cast_dtype(precision)
custom_clip = (
model_cfg.pop("custom_text", False) or force_custom_clip or ("hf_model_name" in model_cfg["text_cfg"])
)
if custom_clip:
if "hf_model_name" in model_cfg.get("text_cfg", {}):
model_cfg["text_cfg"]["hf_model_pretrained"] = pretrained_hf
model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
pretrained_cfg = {}
if pretrained:
checkpoint_path = ""
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=False)
else:
error_str = (
f"Pretrained weights ({pretrained}) not found for model {model_name}."
f"Available pretrained tags ({list_pretrained_tags_by_model(model_name)}."
)
logging.warning(error_str)
raise RuntimeError(error_str)
else:
visual_checkpoint_path = ""
text_checkpoint_path = ""
if pretrained_image:
pretrained_visual_model = pretrained_visual_model.replace(
"/", "-"
) # for callers using old naming with / in ViT names
pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
if "timm_model_name" in model_cfg.get("vision_cfg", {}):
# pretrained weight loading for timm models set via vision_cfg
model_cfg["vision_cfg"]["timm_model_pretrained"] = True
elif pretrained_image_cfg:
visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_image):
visual_checkpoint_path = pretrained_image
else:
logging.warning(
f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual."
)
raise RuntimeError(
f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual."
)
if pretrained_text:
pretrained_text_model = pretrained_text_model.replace(
"/", "-"
) # for callers using old naming with / in ViT names
pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
if pretrained_image_cfg:
text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_text):
text_checkpoint_path = pretrained_text
else:
logging.warning(
f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text."
)
raise RuntimeError(
f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text."
)
if visual_checkpoint_path:
logging.info(f"Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).")
if text_checkpoint_path:
logging.info(f"Loading pretrained {model_name}.text weights ({text_checkpoint_path}).")
if visual_checkpoint_path or text_checkpoint_path:
load_pretrained_checkpoint(
model,
visual_checkpoint_path,
text_checkpoint_path,
strict=False,
visual_model=pretrained_visual_model,
text_model=pretrained_text_model,
model_key="model|module|state_dict",
skip_list=skip_list,
)
if "fp16" in precision or "bf16" in precision:
logging.info(f"convert precision to {precision}")
model = model.to(torch.bfloat16) if "bf16" in precision else model.to(torch.float16)
model.to(device=device)
# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get("mean", None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get("std", None) or OPENAI_DATASET_STD
if jit:
model = torch.jit.script(model)
return model
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = "fp32",
device: Union[str, torch.device] = "cpu",
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = "",
pretrained_text: str = "",
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
pretrained_image=pretrained_image,
pretrained_text=pretrained_text,
pretrained_hf=pretrained_hf,
pretrained_visual_model=pretrained_visual_model,
pretrained_text_model=pretrained_text_model,
cache_dir=cache_dir,
skip_list=skip_list,
)
image_mean = image_mean or getattr(model.visual, "image_mean", None)
image_std = image_std or getattr(model.visual, "image_std", None)
preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
return model, preprocess_train, preprocess_val
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings",
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings",
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens",
},
"pooler": "mean_pooler",
},
"bert": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings",
},
"pooler": "mean_pooler",
},
}
"""huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
except ImportError:
transformers = None
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower()
# TODO: ?last - for gpt-like models
_POOLERS = {}
class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""
def __init__(
self,
model_name_or_path: str,
output_dim: int,
tokenizer_name: str = None,
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
pretrained: bool = True,
masked_language_modeling: bool = False,
):
super().__init__()
self.output_dim = output_dim
# TODO: find better way to get this information
uses_transformer_pooler = pooler_type == "cls_pooler"
if transformers is None:
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
if config is None:
self.config = AutoConfig.from_pretrained(model_name_or_path)
if masked_language_modeling:
create_func, model_args = (
(AutoModelForMaskedLM.from_pretrained, model_name_or_path)
if pretrained
else (AutoModelForMaskedLM.from_config, self.config)
)
else:
create_func, model_args = (
(AutoModel.from_pretrained, model_name_or_path)
if pretrained
else (AutoModel.from_config, self.config)
)
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
else:
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
self.config = config
if masked_language_modeling:
self.transformer = AutoModelForMaskedLM.from_config(config)
else:
self.transformer = AutoModel.from_config(config)
if pooler_type is None: # get default arch pooler
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
else:
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
self.proj = nn.Identity()
elif proj == "linear":
self.proj = nn.Linear(d_model, output_dim, bias=False)
elif proj == "mlp":
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=False),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=False),
)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
if masked_indices is None:
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
if targets is not None:
targets[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
input_ids[indices_replaced] = self.tokenizer.mask_token_id
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
input_ids[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
if targets is not None:
return input_ids, targets
else:
return input_ids
def forward(self, x: TensorType) -> TensorType:
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
return self.proj(pooled_out)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.gradient_checkpointing_enable()
def init_parameters(self):
pass
"""CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
try:
from .hf_model import HFTextEncoder
except ImportError:
HFTextEncoder = None
from .modified_resnet import ModifiedResNet
from .eva_vit_model import EVAVisionTransformer
from .transformer import LayerNorm, QuickGELU, VisionTransformer, TextTransformer
try:
from apex.normalization import FusedLayerNorm
except ImportError:
FusedLayerNorm = LayerNorm
print("Please 'pip install apex'")
try:
import xformers.ops as xops
except ImportError:
xops = None
print("Please 'pip install xformers'")
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0.0
global_average_pool: bool = (
False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
)
drop_path_rate: Optional[float] = None # drop path rate
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
qkv_bias: bool = True
fusedLN: bool = False
xattn: bool = False
postnorm: bool = False
rope: bool = False
pt_hw_seq_len: int = 16 # 224/14
intp_freq: bool = False
naiveswiglu: bool = False
subln: bool = False
@dataclass
class CLIPTextCfg:
context_length: int = 77
vocab_size: int = 49408
width: int = 512
heads: int = 8
layers: int = 12
ls_init_value: Optional[float] = None # layer scale initial value
hf_model_name: str = None
hf_tokenizer_name: str = None
hf_model_pretrained: bool = True
proj: str = "mlp"
pooler_type: str = "mean_pooler"
masked_language_modeling: bool = False
fusedLN: bool = False
xattn: bool = False
attn_mask: bool = True
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == "bf16":
cast_dtype = torch.bfloat16
elif precision == "fp16":
cast_dtype = torch.float16
return cast_dtype
def _build_vision_tower(
embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.eva_model_name:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNorm
visual = EVAVisionTransformer(
img_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
num_classes=embed_dim,
use_mean_pooling=vision_cfg.global_average_pool, # False
init_values=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
embed_dim=vision_cfg.width,
depth=vision_cfg.layers,
num_heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
qkv_bias=vision_cfg.qkv_bias,
drop_path_rate=vision_cfg.drop_path_rate,
norm_layer=partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
xattn=vision_cfg.xattn,
rope=vision_cfg.rope,
postnorm=vision_cfg.postnorm,
pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
intp_freq=vision_cfg.intp_freq,
naiveswiglu=vision_cfg.naiveswiglu,
subln=vision_cfg.subln,
)
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
heads=vision_heads,
image_size=vision_cfg.image_size,
width=vision_cfg.width,
)
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
global_average_pool=vision_cfg.global_average_pool,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return visual
def _build_text_tower(
embed_dim: int,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
if text_cfg.hf_model_name:
text = HFTextEncoder(
text_cfg.hf_model_name,
output_dim=embed_dim,
tokenizer_name=text_cfg.hf_tokenizer_name,
proj=text_cfg.proj,
pooler_type=text_cfg.pooler_type,
masked_language_modeling=text_cfg.masked_language_modeling,
)
else:
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNorm
text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
width=text_cfg.width,
heads=text_cfg.heads,
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=FusedLayerNorm if text_cfg.fusedLN else norm_layer,
xattn=text_cfg.xattn,
attn_mask=text_cfg.attn_mask,
)
return text
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
return image_features, text_features, self.logit_scale.exp()
class CustomCLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
itm_task: bool = False,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
return image_features, text_features, self.logit_scale.exp()
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
if "text_projection" in state_dict:
# old format state_dict, move text tower -> .text
new_state_dict = {}
for k, v in state_dict.items():
if any(
k.startswith(p)
for p in (
"text_projection",
"positional_embedding",
"token_embedding",
"transformer",
"ln_final",
"logit_scale",
)
):
k = "text." + k
new_state_dict[k] = v
return new_state_dict
return state_dict
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"drop_path_rate": 0,
"head_width": 64,
"mlp_ratio": 2.6667,
"patch_size": 14,
"eva_model_name": "eva-clip-l-14-336",
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}
\ No newline at end of file
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
from .utils import freeze_batch_norm_2d
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.act2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.act3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict(
[
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion)),
]
)
)
def forward(self, x: torch.Tensor):
identity = x
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.act3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0.0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def stem(self, x):
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
import hashlib
import os
import urllib
import warnings
from typing import Dict, Union
from tqdm import tqdm
try:
from huggingface_hub import hf_hub_download
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
def _pcfg(url="", hf_hub="", filename="", mean=None, std=None):
return dict(
url=url,
hf_hub=hf_hub,
mean=mean,
std=std,
)
_VITB32 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"
),
laion2b_e16=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"
),
laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"),
)
_VITB32_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"
),
)
_VITB16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"
),
laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"),
)
_EVAB16 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
)
_VITB16_PLUS_240 = dict(
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
),
)
_VITL14 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"
),
laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)
_EVAL14 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
)
_VITL14_336 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
),
)
_EVAL14_336 = dict(
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
eva_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
eva02_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
)
_VITH14 = dict(
laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"),
)
_VITg14 = dict(
laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"),
laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"),
)
_EVAg14 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
)
_EVAg14_PLUS = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
)
_VITbigG14 = dict(
laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"),
)
_EVAbigE14 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
)
_EVAbigE14_PLUS = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
)
_PRETRAINED = {
# "ViT-B-32": _VITB32,
"OpenaiCLIP-B-32": _VITB32,
"OpenCLIP-B-32": _VITB32,
# "ViT-B-32-quickgelu": _VITB32_quickgelu,
"OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
"OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
# "ViT-B-16": _VITB16,
"OpenaiCLIP-B-16": _VITB16,
"OpenCLIP-B-16": _VITB16,
"EVA02-B-16": _EVAB16,
"EVA02-CLIP-B-16": _EVAB16,
# "ViT-B-16-plus-240": _VITB16_PLUS_240,
"OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
# "ViT-L-14": _VITL14,
"OpenaiCLIP-L-14": _VITL14,
"OpenCLIP-L-14": _VITL14,
"EVA02-L-14": _EVAL14,
"EVA02-CLIP-L-14": _EVAL14,
# "ViT-L-14-336": _VITL14_336,
"OpenaiCLIP-L-14-336": _VITL14_336,
"EVA02-CLIP-L-14-336": _EVAL14_336,
# "ViT-H-14": _VITH14,
# "ViT-g-14": _VITg14,
"OpenCLIP-H-14": _VITH14,
"OpenCLIP-g-14": _VITg14,
"EVA01-CLIP-g-14": _EVAg14,
"EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
# "ViT-bigG-14": _VITbigG14,
"OpenCLIP-bigG-14": _VITbigG14,
"EVA02-CLIP-bigE-14": _EVAbigE14,
"EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
}
def _clean_tag(tag: str):
# normalize pretrained tags
return tag.lower().replace("-", "_")
def list_pretrained_tags_by_model(model: str):
"""return all pretrain tags for the specified model architecture"""
tags = []
if model in _PRETRAINED:
tags.extend(_PRETRAINED[model].keys())
return tags
def get_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return {}
model_pretrained = _PRETRAINED[model]
return model_pretrained.get(_clean_tag(tag), {})
def download_pretrained_from_url(
url: str,
cache_dir: Union[str, None] = None,
):
if not cache_dir:
cache_dir = os.path.expanduser("~/.cache/clip")
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.basename(url)
if "openaipublic" in url:
expected_sha256 = url.split("/")[-2]
elif "mlfoundations" in url:
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
else:
expected_sha256 = ""
download_target = os.path.join(cache_dir, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if expected_sha256:
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
return download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
else:
return download_target
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(
expected_sha256
):
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
"Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`."
)
return _has_hf_hub
def download_pretrained_from_hf(
model_id: str,
filename: str = "open_clip_pytorch_model.bin",
revision=None,
cache_dir: Union[str, None] = None,
):
has_hf_hub(True)
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
return cached_file
def download_pretrained(
cfg: Dict,
force_hf_hub: bool = False,
cache_dir: Union[str, None] = None,
):
target = ""
if not cfg:
return target
download_url = cfg.get("url", "")
download_hf_hub = cfg.get("hf_hub", "")
if download_hf_hub and force_hf_hub:
# use HF hub even if url exists
download_url = ""
if download_url:
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
elif download_hf_hub:
has_hf_hub(True)
# we assume the hf_hub entries in pretrained config combine model_id + filename in
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
model_id, filename = os.path.split(download_hf_hub)
if filename:
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
else:
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
return target
import logging
from math import pi
import torch
from einops import rearrange, repeat
from torch import nn
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class VisionRotaryEmbeddingFast(nn.Module):
def __init__(
self,
dim,
pt_seq_len,
ft_seq_len=None,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
patch_dropout=0.0,
):
super().__init__()
if custom_freqs:
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
if ft_seq_len is None:
ft_seq_len = pt_seq_len
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
freqs = torch.einsum("..., f -> ... f", t, freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
self.patch_dropout = patch_dropout
self.register_buffer("freqs_cos", freqs_cos)
self.register_buffer("freqs_sin", freqs_sin)
logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
def forward(self, t, patch_indices_keep=None):
if patch_indices_keep is not None:
batch = t.size()[0]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
return t * freqs_cos + rotate_half(t) * freqs_sin
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomResizedCrop,
Resize,
ToTensor,
)
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
class ResizeMaxSize(nn.Module):
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0):
super().__init__()
if not isinstance(max_size, int):
raise TypeError(f"Size should be int. Got {type(max_size)}")
self.max_size = max_size
self.interpolation = interpolation
self.fn = min if fn == "min" else min
self.fill = fill
def forward(self, img):
if isinstance(img, torch.Tensor):
height, width = img.shape[:2]
else:
width, height = img.size
scale = self.max_size / float(max(height, width))
if scale != 1.0:
new_size = tuple(round(dim * scale) for dim in (height, width))
img = F.resize(img, new_size, self.interpolation)
pad_h = self.max_size - new_size[0]
pad_w = self.max_size - new_size[1]
img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
return img
def _convert_to_rgb(image):
return image.convert("RGB")
def image_transform(
image_size: int,
is_train: bool,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_longest_max: bool = False,
fill_color: int = 0,
):
mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
std = std or OPENAI_DATASET_STD
if not isinstance(std, (list, tuple)):
std = (std,) * 3
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
image_size = image_size[0]
normalize = Normalize(mean=mean, std=std)
if is_train:
return Compose(
[
RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
_convert_to_rgb,
ToTensor(),
normalize,
]
)
else:
if resize_longest_max:
transforms = [ResizeMaxSize(image_size, fill=fill_color)]
else:
transforms = [
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
]
transforms.extend(
[
_convert_to_rgb,
ToTensor(),
normalize,
]
)
return Compose(transforms)
import logging
import math
import os
from collections import OrderedDict
from typing import Callable, Optional, Sequence
import torch
from torch import nn
from torch.nn import functional as F
try:
from timm.models.layers import trunc_normal_
except ImportError:
from timm.layers import trunc_normal_
from .utils import to_2tuple
if os.getenv("ENV_TYPE") == "deepspeed":
try:
import deepspeed
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
except ImportError:
print("Please 'pip install deepspeed'")
deepspeed = None
from torch.utils.checkpoint import checkpoint
else:
from torch.utils.checkpoint import checkpoint
try:
import xformers.ops as xops
except ImportError:
xops = None
print("Please 'pip install xformers'")
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class QuickGELU(nn.Module):
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(self, prob, exclude_first_token=True):
super().__init__()
assert 0 <= prob < 1.0
self.prob = prob
self.exclude_first_token = exclude_first_token # exclude CLS token
logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
def forward(self, x):
if not self.training or self.prob == 0.0:
return x
if self.exclude_first_token:
cls_tokens, x = x[:, :1], x[:, 1:]
else:
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
batch = x.size()[0]
num_tokens = x.size()[1]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
keep_prob = 1 - self.prob
num_patches_keep = max(1, int(num_tokens * keep_prob))
rand = torch.randn(batch, num_tokens)
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
x = x[batch_indices, patch_indices_keep]
if self.exclude_first_token:
x = torch.cat((cls_tokens, x), dim=1)
if self.training and os.getenv("RoPE") == "1":
return x, patch_indices_keep
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
scaled_cosine=False,
scale_heads=False,
logit_scale_max=math.log(1.0 / 0.01),
attn_drop=0.0,
proj_drop=0.0,
xattn=False,
rope=False,
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.logit_scale_max = logit_scale_max
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
if qkv_bias:
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
else:
self.in_proj_bias = None
if self.scaled_cosine:
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
else:
self.logit_scale = None
self.attn_drop = nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)
self.xattn = xattn
self.xattn_drop = attn_drop
self.rope = rope
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
L, N, C = x.shape
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
if self.xattn:
q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
x = xops.memory_efficient_attention(
q,
k,
v,
p=self.xattn_drop,
scale=self.scale if self.logit_scale is None else None,
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
)
else:
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
if self.logit_scale is not None:
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
attn = attn.view(N, self.num_heads, L, L) * logit_scale
attn = attn.view(-1, L, L)
else:
q = q * self.scale
attn = torch.bmm(q, k.transpose(-1, -2))
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
if self.head_scale is not None:
x = x.view(N, self.num_heads, L, C) * self.head_scale
x = x.view(-1, L, C)
x = x.transpose(0, 1).reshape(L, N, C)
x = self.out_proj(x)
x = self.out_drop(x)
return x
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
xattn: bool = False,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
if xattn:
self.attn = Attention(d_model, n_head, xattn=True)
else:
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model)),
]
)
)
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.xattn = xattn
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
if self.xattn:
return self.attn(x, attn_mask=attn_mask)
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
xattn: bool = False,
):
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = False
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
xattn=xattn,
)
for _ in range(layers)
]
)
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
class VisionTransformer(nn.Module):
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
ls_init_value: float = None,
patch_dropout: float = 0.0,
global_average_pool: bool = False,
output_dim: int = 512,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
xattn: bool = False,
):
super().__init__()
self.image_size = to_2tuple(image_size)
self.patch_size = to_2tuple(patch_size)
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
self.ln_pre = norm_layer(width)
self.transformer = Transformer(
width,
layers,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
xattn=xattn,
)
self.global_average_pool = global_average_pool
self.ln_post = norm_layer(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def forward(self, x: torch.Tensor, return_all_features: bool = False):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[
self.class_embedding.to(x.dtype)
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.patch_dropout(x)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
if not return_all_features:
if self.global_average_pool:
x = x.mean(dim=1) # x = x[:,1:,:].mean(dim=1)
else:
x = x[:, 0]
x = self.ln_post(x)
if self.proj is not None:
x = x @ self.proj
return x
class TextTransformer(nn.Module):
def __init__(
self,
context_length: int = 77,
vocab_size: int = 49408,
width: int = 512,
heads: int = 8,
layers: int = 12,
ls_init_value: float = None,
output_dim: int = 512,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
xattn: bool = False,
attn_mask: bool = True,
):
super().__init__()
self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim
self.token_embedding = nn.Embedding(vocab_size, width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
self.transformer = Transformer(
width=width,
layers=layers,
heads=heads,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
xattn=xattn,
)
self.xattn = xattn
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
if attn_mask:
self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
else:
self.attn_mask = None
self.init_parameters()
def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def forward(self, text, return_all_features: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
# x = self.transformer(x) # no attention mask is applied
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
if not return_all_features:
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
import collections.abc
import logging
import math
from itertools import repeat
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
# open CLIP
def resize_clip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get("visual.positional_embedding", None)
if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
align_corners=True,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict["visual.positional_embedding"] = new_pos_embed
def resize_visual_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get("positional_embedding", None)
if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
align_corners=True,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict["positional_embedding"] = new_pos_embed
def resize_evaclip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
# interpolate position embedding
if "visual.pos_embed" in state_dict:
pos_embed_checkpoint = state_dict["visual.pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.visual.patch_embed.num_patches
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict["visual.pos_embed"] = new_pos_embed
patch_embed_proj = state_dict["visual.patch_embed.proj.weight"]
patch_size = model.visual.patch_embed.patch_size
state_dict["visual.patch_embed.proj.weight"] = torch.nn.functional.interpolate(
patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False
)
def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
# interpolate position embedding
if "pos_embed" in state_dict:
pos_embed_checkpoint = state_dict["pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.visual.patch_embed.num_patches
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict["pos_embed"] = new_pos_embed
patch_embed_proj = state_dict["patch_embed.proj.weight"]
patch_size = model.visual.patch_embed.patch_size
state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(
patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False
)
def freeze_batch_norm_2d(module, module_match={}, name=""):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
module_match (dict): Dictionary of full module names to freeze (all if empty)
name (str): Full module name (prefix)
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
is_match = True
if module_match:
is_match = name in module_match
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for child_name, child in module.named_children():
full_child_name = ".".join([name, child_name]) if name else child_name
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
# Adapted from https://github.com/ToTheBeginning/PuLID
import torch
from typing import Any, Dict, Optional, Union
from diffusers.models.modeling_outputs import Transformer2DModelOutput
import logging
logger = logging.getLogger(__name__)
def pulid_forward(
self,
hidden_states: torch.Tensor,
id_embeddings=None,
id_weight=None,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Copied from diffusers.models.flux.transformer_flux.py
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
nunchaku_block = self.transformer_blocks[0]
encoder_hidden_states, hidden_states = nunchaku_block(
hidden_states=hidden_states,
id_embeddings=id_embeddings,
id_weight=id_weight,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment