Unverified Commit 37a27712 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter
parents c1d6fc84 760ab022
"""CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
try:
from .hf_model import HFTextEncoder
except ImportError:
HFTextEncoder = None
from .eva_vit_model import EVAVisionTransformer
from .modified_resnet import ModifiedResNet
from .transformer import LayerNorm, QuickGELU, TextTransformer, VisionTransformer
try:
from apex.normalization import FusedLayerNorm
except ImportError:
FusedLayerNorm = LayerNorm
print("Please 'pip install apex'")
try:
import xformers.ops as xops
except ImportError:
xops = None
print("Please 'pip install xformers'")
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0.0
global_average_pool: bool = (
False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
)
drop_path_rate: Optional[float] = None # drop path rate
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
qkv_bias: bool = True
fusedLN: bool = False
xattn: bool = False
postnorm: bool = False
rope: bool = False
pt_hw_seq_len: int = 16 # 224/14
intp_freq: bool = False
naiveswiglu: bool = False
subln: bool = False
@dataclass
class CLIPTextCfg:
context_length: int = 77
vocab_size: int = 49408
width: int = 512
heads: int = 8
layers: int = 12
ls_init_value: Optional[float] = None # layer scale initial value
hf_model_name: str = None
hf_tokenizer_name: str = None
hf_model_pretrained: bool = True
proj: str = "mlp"
pooler_type: str = "mean_pooler"
masked_language_modeling: bool = False
fusedLN: bool = False
xattn: bool = False
attn_mask: bool = True
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == "bf16":
cast_dtype = torch.bfloat16
elif precision == "fp16":
cast_dtype = torch.float16
return cast_dtype
def _build_vision_tower(
embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.eva_model_name:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNorm
visual = EVAVisionTransformer(
img_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
num_classes=embed_dim,
use_mean_pooling=vision_cfg.global_average_pool, # False
init_values=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
embed_dim=vision_cfg.width,
depth=vision_cfg.layers,
num_heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
qkv_bias=vision_cfg.qkv_bias,
drop_path_rate=vision_cfg.drop_path_rate,
norm_layer=partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
xattn=vision_cfg.xattn,
rope=vision_cfg.rope,
postnorm=vision_cfg.postnorm,
pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
intp_freq=vision_cfg.intp_freq,
naiveswiglu=vision_cfg.naiveswiglu,
subln=vision_cfg.subln,
)
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
heads=vision_heads,
image_size=vision_cfg.image_size,
width=vision_cfg.width,
)
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
global_average_pool=vision_cfg.global_average_pool,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return visual
def _build_text_tower(
embed_dim: int,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
if text_cfg.hf_model_name:
text = HFTextEncoder(
text_cfg.hf_model_name,
output_dim=embed_dim,
tokenizer_name=text_cfg.hf_tokenizer_name,
proj=text_cfg.proj,
pooler_type=text_cfg.pooler_type,
masked_language_modeling=text_cfg.masked_language_modeling,
)
else:
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNorm
text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
width=text_cfg.width,
heads=text_cfg.heads,
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=FusedLayerNorm if text_cfg.fusedLN else norm_layer,
xattn=text_cfg.xattn,
attn_mask=text_cfg.attn_mask,
)
return text
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
return image_features, text_features, self.logit_scale.exp()
class CustomCLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
itm_task: bool = False,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
return image_features, text_features, self.logit_scale.exp()
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
if "text_projection" in state_dict:
# old format state_dict, move text tower -> .text
new_state_dict = {}
for k, v in state_dict.items():
if any(
k.startswith(p)
for p in (
"text_projection",
"positional_embedding",
"token_embedding",
"transformer",
"ln_final",
"logit_scale",
)
):
k = "text." + k
new_state_dict[k] = v
return new_state_dict
return state_dict
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"drop_path_rate": 0,
"head_width": 64,
"mlp_ratio": 2.6667,
"patch_size": 14,
"eva_model_name": "eva-clip-l-14-336",
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.act2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.act3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict(
[
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion)),
]
)
)
def forward(self, x: torch.Tensor):
identity = x
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.act3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0.0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def stem(self, x):
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
import hashlib
import os
import urllib
import warnings
from typing import Dict, Union
from tqdm import tqdm
try:
from huggingface_hub import hf_hub_download
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
def _pcfg(url="", hf_hub="", filename="", mean=None, std=None):
return dict(
url=url,
hf_hub=hf_hub,
mean=mean,
std=std,
)
_VITB32 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"
),
laion2b_e16=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"
),
laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"),
)
_VITB32_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"
),
)
_VITB16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"
),
laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"),
)
_EVAB16 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
)
_VITB16_PLUS_240 = dict(
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
),
)
_VITL14 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"
),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"
),
laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)
_EVAL14 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
)
_VITL14_336 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
),
)
_EVAL14_336 = dict(
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
eva_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
eva02_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
)
_VITH14 = dict(
laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"),
)
_VITg14 = dict(
laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"),
laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"),
)
_EVAg14 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
)
_EVAg14_PLUS = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
)
_VITbigG14 = dict(
laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"),
)
_EVAbigE14 = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
)
_EVAbigE14_PLUS = dict(
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
)
_PRETRAINED = {
# "ViT-B-32": _VITB32,
"OpenaiCLIP-B-32": _VITB32,
"OpenCLIP-B-32": _VITB32,
# "ViT-B-32-quickgelu": _VITB32_quickgelu,
"OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
"OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
# "ViT-B-16": _VITB16,
"OpenaiCLIP-B-16": _VITB16,
"OpenCLIP-B-16": _VITB16,
"EVA02-B-16": _EVAB16,
"EVA02-CLIP-B-16": _EVAB16,
# "ViT-B-16-plus-240": _VITB16_PLUS_240,
"OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
# "ViT-L-14": _VITL14,
"OpenaiCLIP-L-14": _VITL14,
"OpenCLIP-L-14": _VITL14,
"EVA02-L-14": _EVAL14,
"EVA02-CLIP-L-14": _EVAL14,
# "ViT-L-14-336": _VITL14_336,
"OpenaiCLIP-L-14-336": _VITL14_336,
"EVA02-CLIP-L-14-336": _EVAL14_336,
# "ViT-H-14": _VITH14,
# "ViT-g-14": _VITg14,
"OpenCLIP-H-14": _VITH14,
"OpenCLIP-g-14": _VITg14,
"EVA01-CLIP-g-14": _EVAg14,
"EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
# "ViT-bigG-14": _VITbigG14,
"OpenCLIP-bigG-14": _VITbigG14,
"EVA02-CLIP-bigE-14": _EVAbigE14,
"EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
}
def _clean_tag(tag: str):
# normalize pretrained tags
return tag.lower().replace("-", "_")
def list_pretrained_tags_by_model(model: str):
"""return all pretrain tags for the specified model architecture"""
tags = []
if model in _PRETRAINED:
tags.extend(_PRETRAINED[model].keys())
return tags
def get_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return {}
model_pretrained = _PRETRAINED[model]
return model_pretrained.get(_clean_tag(tag), {})
def download_pretrained_from_url(
url: str,
cache_dir: Union[str, None] = None,
):
if not cache_dir:
cache_dir = os.path.expanduser("~/.cache/clip")
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.basename(url)
if "openaipublic" in url:
expected_sha256 = url.split("/")[-2]
elif "mlfoundations" in url:
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
else:
expected_sha256 = ""
download_target = os.path.join(cache_dir, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if expected_sha256:
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
return download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
else:
return download_target
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(
expected_sha256
):
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
"Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`."
)
return _has_hf_hub
def download_pretrained_from_hf(
model_id: str,
filename: str = "open_clip_pytorch_model.bin",
revision=None,
cache_dir: Union[str, None] = None,
):
has_hf_hub(True)
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
return cached_file
def download_pretrained(
cfg: Dict,
force_hf_hub: bool = False,
cache_dir: Union[str, None] = None,
):
target = ""
if not cfg:
return target
download_url = cfg.get("url", "")
download_hf_hub = cfg.get("hf_hub", "")
if download_hf_hub and force_hf_hub:
# use HF hub even if url exists
download_url = ""
if download_url:
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
elif download_hf_hub:
has_hf_hub(True)
# we assume the hf_hub entries in pretrained config combine model_id + filename in
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
model_id, filename = os.path.split(download_hf_hub)
if filename:
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
else:
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
return target
import logging
from math import pi
import torch
from einops import rearrange, repeat
from torch import nn
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class VisionRotaryEmbeddingFast(nn.Module):
def __init__(
self,
dim,
pt_seq_len,
ft_seq_len=None,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
patch_dropout=0.0,
):
super().__init__()
if custom_freqs:
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
if ft_seq_len is None:
ft_seq_len = pt_seq_len
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
freqs = torch.einsum("..., f -> ... f", t, freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
self.patch_dropout = patch_dropout
self.register_buffer("freqs_cos", freqs_cos)
self.register_buffer("freqs_sin", freqs_sin)
logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
def forward(self, t, patch_indices_keep=None):
if patch_indices_keep is not None:
batch = t.size()[0]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
return t * freqs_cos + rotate_half(t) * freqs_sin
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomResizedCrop,
Resize,
ToTensor,
)
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
class ResizeMaxSize(nn.Module):
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0):
super().__init__()
if not isinstance(max_size, int):
raise TypeError(f"Size should be int. Got {type(max_size)}")
self.max_size = max_size
self.interpolation = interpolation
self.fn = min if fn == "min" else min
self.fill = fill
def forward(self, img):
if isinstance(img, torch.Tensor):
height, width = img.shape[:2]
else:
width, height = img.size
scale = self.max_size / float(max(height, width))
if scale != 1.0:
new_size = tuple(round(dim * scale) for dim in (height, width))
img = F.resize(img, new_size, self.interpolation)
pad_h = self.max_size - new_size[0]
pad_w = self.max_size - new_size[1]
img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
return img
def _convert_to_rgb(image):
return image.convert("RGB")
def image_transform(
image_size: int,
is_train: bool,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_longest_max: bool = False,
fill_color: int = 0,
):
mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
std = std or OPENAI_DATASET_STD
if not isinstance(std, (list, tuple)):
std = (std,) * 3
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
image_size = image_size[0]
normalize = Normalize(mean=mean, std=std)
if is_train:
return Compose(
[
RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
_convert_to_rgb,
ToTensor(),
normalize,
]
)
else:
if resize_longest_max:
transforms = [ResizeMaxSize(image_size, fill=fill_color)]
else:
transforms = [
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
]
transforms.extend(
[
_convert_to_rgb,
ToTensor(),
normalize,
]
)
return Compose(transforms)
This diff is collapsed.
import collections.abc
import logging
import math
from itertools import repeat
import torch
import torch.nn.functional as F
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
# open CLIP
def resize_clip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get("visual.positional_embedding", None)
if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
align_corners=True,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict["visual.positional_embedding"] = new_pos_embed
def resize_visual_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get("positional_embedding", None)
if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
align_corners=True,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict["positional_embedding"] = new_pos_embed
def resize_evaclip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
# interpolate position embedding
if "visual.pos_embed" in state_dict:
pos_embed_checkpoint = state_dict["visual.pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.visual.patch_embed.num_patches
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict["visual.pos_embed"] = new_pos_embed
patch_embed_proj = state_dict["visual.patch_embed.proj.weight"]
patch_size = model.visual.patch_embed.patch_size
state_dict["visual.patch_embed.proj.weight"] = torch.nn.functional.interpolate(
patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False
)
def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
# interpolate position embedding
if "pos_embed" in state_dict:
pos_embed_checkpoint = state_dict["pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.visual.patch_embed.num_patches
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict["pos_embed"] = new_pos_embed
patch_embed_proj = state_dict["patch_embed.proj.weight"]
patch_size = model.visual.patch_embed.patch_size
state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(
patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False
)
def freeze_batch_norm_2d(module, module_match={}, name=""):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
module_match (dict): Dictionary of full module names to freeze (all if empty)
name (str): Full module name (prefix)
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
is_match = True
if module_match:
is_match = name in module_match
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for child_name, child in module.named_children():
full_child_name = ".".join([name, child_name]) if name else child_name
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
# Adapted from https://github.com/ToTheBeginning/PuLID
import logging
from typing import Any, Dict, Optional, Union
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
logger = logging.getLogger(__name__)
def pulid_forward(
self,
hidden_states: torch.Tensor,
id_embeddings=None,
id_weight=None,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Copied from diffusers.models.flux.transformer_flux.py
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
nunchaku_block = self.transformer_blocks[0]
encoder_hidden_states, hidden_states = nunchaku_block(
hidden_states=hidden_states,
id_embeddings=id_embeddings,
id_weight=id_weight,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
This diff is collapsed.
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
from ..._C.ops import gemm_awq, gemv_awq from ..._C.ops import gemm_awq, gemv_awq
from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
__all__ = ["W4Linear"] __all__ = ["W4Linear"]
......
from .transformer_flux import NunchakuFluxTransformer2dModel from .transformer_flux import NunchakuFluxTransformer2dModel
from .transformer_sana import NunchakuSanaTransformer2DModel from .transformer_sana import NunchakuSanaTransformer2DModel
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel"]
This diff is collapsed.
from .pipeline_flux_pulid import PuLIDFluxPipeline
__all__ = ["PuLIDFluxPipeline"]
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment