# Hunyuan CLIP model initialization
import sys
import tqdm
import argparse
from PIL import Image
import copy
import json
import os
import numpy as np

sys.path.append(os.getcwd())

from transformers import BertTokenizer, BertModel  # BertModel is only used for the fp16 conversion check below
import torch
from torchvision.transforms import (
    Compose,
    ToTensor,
    Normalize,
    Resize,
    InterpolationMode,
)

from collections import OrderedDict
from typing import Tuple, Union
from itertools import repeat
import collections.abc
import math
import logging

import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
import importlib.util

# if importlib.util.find_spec('flash_attn'):
#     FlashMHA = importlib.import_module('flash_attn.flash_attention').FlashMHA


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        ("-1", nn.AvgPool2d(stride)),
                        (
                            "0",
                            nn.Conv2d(
                                inplanes,
                                planes * self.expansion,
                                1,
                                stride=1,
                                bias=False,
                            ),
                        ),
                        ("1", nn.BatchNorm2d(planes * self.expansion)),
                    ]
                )
            )

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
            2, 0, 1
        )  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
            ),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )

        return x[0]
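
# Illustrative sketch (demo helper, never called at import time): expected tensor shapes
# when AttentionPool2d is used as the final pooling layer of an RN50-style tower
# (7x7 grid, 2048 channels, 32 heads). These sizes are assumptions chosen for the demo.
def _demo_attention_pool_shapes():
    pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
    feats = torch.randn(2, 2048, 7, 7)  # NCHW feature map from the last ResNet stage
    pooled = pool(feats)
    assert pooled.shape == (2, 1024)  # one embedding per image
    return pooled
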
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(
            3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(
            width // 2, width // 2, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(
            input_resolution // 32, embed_dim, heads, output_dim
        )

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # FIXME support for non-transformer
        pass

    def forward(self, x):
        def stem(x):
            for conv, bn in [
                (self.conv1, self.bn1),
                (self.conv2, self.bn2),
                (self.conv3, self.bn3),
            ]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
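
# Illustrative sketch (demo helper, never called at import time): an RN50-like instantiation
# of ModifiedResNet. layers=(3, 4, 6, 3) with width=64 yields a 2048-d feature map on a 7x7
# grid, which AttentionPool2d reduces to one output_dim vector per image. The sizes are
# assumptions chosen for the demo only.
def _demo_modified_resnet_shapes():
    model = ModifiedResNet(
        layers=(3, 4, 6, 3), output_dim=1024, heads=32, input_resolution=224, width=64
    )
    model.eval()
    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        feats = model(images)
    assert feats.shape == (2, 1024)
    return feats
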
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        attn_mask: torch.Tensor = None,
        use_flash_attention: bool = False,
    ):
        super().__init__()

        # self.attn = nn.MultiheadAttention(d_model, n_head) if not use_flash_attention else FlashMHA(d_model, n_head)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 4)),
                    ("gelu", QuickGELU()),
                    ("c_proj", nn.Linear(d_model * 4, d_model)),
                ]
            )
        )
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.use_flash_attention = use_flash_attention

    def attention(self, x: torch.Tensor):
        self.attn_mask = (
            self.attn_mask.to(dtype=x.dtype, device=x.device)
            if self.attn_mask is not None
            else None
        )
        if self.use_flash_attention:
            # Batch first is needed for FlashAttention. See https://github.com/HazyResearch/flash-attention/issues/84 for more information.
            return self.attn(x.transpose(1, 0))[0].transpose(1, 0)
        else:
            return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: torch.Tensor = None,
        use_flash_attention: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        self.resblocks = nn.Sequential(
            *[
                ResidualAttentionBlock(width, heads, attn_mask, use_flash_attention)
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for r in self.resblocks:
                x = checkpoint(r, x)
            return x
        return self.resblocks(x)


class VisualTransformer(nn.Module):
    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
        use_flash_attention: bool = False,
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.grid_size = (
            self.input_resolution // patch_size,
            self.input_resolution // patch_size,
        )
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
        )
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(
            width, layers, heads, use_flash_attention=use_flash_attention
        )

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.grad_checkpointing = enable

    def random_masking(self, x, mask_ratio):
        N, L, D = x.shape  # batch, length, dim
        len_keep = int((L - 1) * (1 - mask_ratio))

        noise = torch.rand(N, L - 1, device=x.device)
        # indices are shifted by +1 so the class token at position 0 is never masked
        ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(
            N, L - 1, device=x.device, dtype=int
        )
        ids_keep = ids_shuffle[:, :len_keep]

        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        x0 = x[:, 0, :]
        x0 = x0.reshape(N, 1, D)
        x_masked_add = torch.cat([x0, x_masked], axis=1)
        return x_masked_add

    def forward(self, x: torch.Tensor, mask_ratio: float = 0.0):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [
                self.class_embedding.to(x.dtype)
                + torch.zeros(
                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
                ),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        if mask_ratio != 0:
            x = self.random_masking(x, mask_ratio)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
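
# Illustrative sketch (demo helper, never called at import time): a ViT-B/16-like
# instantiation of VisualTransformer. With mask_ratio=0 all 14*14 patch tokens are used;
# with mask_ratio=0.5 random_masking keeps the class token plus half of the patch tokens
# (the FLIP training trick). The sizes are assumptions chosen for the demo.
def _demo_visual_transformer_shapes():
    vit = VisualTransformer(
        input_resolution=224,
        patch_size=16,
        width=768,
        layers=12,
        heads=12,
        output_dim=512,
    )
    vit.eval()
    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        full = vit(images)                    # class token + 196 patch tokens
        masked = vit(images, mask_ratio=0.5)  # class token + 98 kept patch tokens
    assert full.shape == masked.shape == (2, 512)
    return full, masked
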
class CLIP(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        # vision
        image_resolution: int,
        vision_layers: Union[Tuple[int, int, int, int], int],
        vision_width: int,
        vision_patch_size: int,
        # text
        vocab_size: int,
        text_attention_probs_dropout_prob: float,
        text_hidden_act: str,
        text_hidden_dropout_prob: float,
        text_hidden_size: int,
        text_initializer_range: float,
        text_intermediate_size: int,
        text_max_position_embeddings: int,
        text_num_attention_heads: int,
        text_num_hidden_layers: int,
        text_type_vocab_size: int,
        tokenizer=None,  # unused in this image-only variant; kept for config compatibility
        # vision head width, added this param for ViT-H
        vision_head_width: int = 64,
        use_flash_attention: bool = False,
        args={},
    ):
        super().__init__()
        print("use_flash_attention", use_flash_attention)

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // vision_head_width
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width,
            )
        else:
            vision_heads = vision_width // vision_head_width
            self.visual = VisualTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim,
                use_flash_attention=use_flash_attention,
            )

        # load isogclr loss tempNetwork (TempGenerator must be provided elsewhere in the project)
        if (
            "contrastive_loss_type" in args
            and args.contrastive_loss_type == "isogclr_loss"
        ):
            feat_dim = args.isogclr_loss_feat_dim
            squeeze_dim = args.isogclr_loss_squeeze_dim
            tau_min, tau_max = args.isogclr_loss_tau_min, args.isogclr_loss_tau_max
            if args.isogclr_loss_temp_input == "unimodal":
                self.image_temp_gen = TempGenerator(
                    feature_dim=feat_dim,
                    M=squeeze_dim,
                    tau_min=tau_min,
                    tau_max=tau_max,
                )
                self.text_temp_gen = TempGenerator(
                    feature_dim=feat_dim,
                    M=squeeze_dim,
                    tau_min=tau_min,
                    tau_max=tau_max,
                )
            else:
                self.temp_gen = TempGenerator(
                    feature_dim=feat_dim,
                    M=squeeze_dim,
                    tau_min=tau_min,
                    tau_max=tau_max,
                )
        else:
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.args = args

        self.initialize_parameters()

    def initialize_parameters(self):
        # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features**-0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [
                self.visual.layer1,
                self.visual.layer2,
                self.visual.layer3,
                self.visual.layer4,
            ]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        if hasattr(self, "bert"):  # no text tower is built in this image-only variant
            self.bert.set_grad_checkpointing(enable)

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, mask_ratio=0):
        if isinstance(self.visual, ModifiedResNet):
            # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
            return self.visual(image.type(self.dtype))
        return self.visual(image.type(self.dtype), mask_ratio)
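
# Illustrative sketch (demo helper, never called at import time): constructing the
# image-only CLIP above and calling encode_image. The vision settings mimic a ViT-B/16;
# the text_* values are placeholders, since this trimmed class never builds a text tower.
# All numbers are assumptions for demonstration, not the real ViT-H-14 / RoBERTa config.
def _demo_clip_encode_image():
    model = CLIP(
        embed_dim=512,
        image_resolution=224,
        vision_layers=12,  # an int selects VisualTransformer; a tuple/list selects ModifiedResNet
        vision_width=768,
        vision_patch_size=16,
        vocab_size=21128,
        text_attention_probs_dropout_prob=0.1,
        text_hidden_act="gelu",
        text_hidden_dropout_prob=0.1,
        text_hidden_size=768,
        text_initializer_range=0.02,
        text_intermediate_size=3072,
        text_max_position_embeddings=512,
        text_num_attention_heads=12,
        text_num_hidden_layers=12,
        text_type_vocab_size=2,
    )
    model.eval()
    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        emb = model.encode_image(images)
    assert emb.shape == (2, 512)
    return emb
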
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        if isinstance(l, BertModel):
            l.to(torch.half)

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def convert_state_dict(state_dict):
    """Adapt to Flash Attention"""
    if not state_dict:
        return state_dict

    prefix = "module." if list(state_dict.keys())[0].startswith("module") else ""

    if f"{prefix}visual.transformer.resblocks.0.attn.in_proj_weight" in state_dict:
        for k in list(state_dict.keys()):
            if "attn.in_proj_weight" in k:
                state_dict[k.replace("attn.in_proj_weight", "attn.Wqkv.weight")] = (
                    state_dict.pop(k)
                )
            elif "attn.in_proj_bias" in k:
                state_dict[k.replace("attn.in_proj_bias", "attn.Wqkv.bias")] = (
                    state_dict.pop(k)
                )
    elif f"{prefix}visual.transformer.resblocks.0.attn.Wqkv.weight" in state_dict:
        for k in list(state_dict.keys()):
            if "attn.Wqkv.weight" in k:
                state_dict[k.replace("attn.Wqkv.weight", "attn.in_proj_weight")] = (
                    state_dict.pop(k)
                )
            elif "attn.Wqkv.bias" in k:
                state_dict[k.replace("attn.Wqkv.bias", "attn.in_proj_bias")] = (
                    state_dict.pop(k)
                )

    if f"{prefix}bert.encoder.layer.0.attention.self.query.weight" in state_dict:
        i = 0
        while (
            f"{prefix}bert.encoder.layer.{i}.attention.self.query.weight" in state_dict
        ):
            state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight"] = (
                torch.cat(
                    (
                        state_dict.pop(
                            f"{prefix}bert.encoder.layer.{i}.attention.self.query.weight"
                        ),
                        state_dict.pop(
                            f"{prefix}bert.encoder.layer.{i}.attention.self.key.weight"
                        ),
                        state_dict.pop(
                            f"{prefix}bert.encoder.layer.{i}.attention.self.value.weight"
                        ),
                    )
                )
            )
            state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias"] = (
                torch.cat(
                    (
                        state_dict.pop(
                            f"{prefix}bert.encoder.layer.{i}.attention.self.query.bias"
                        ),
                        state_dict.pop(
                            f"{prefix}bert.encoder.layer.{i}.attention.self.key.bias"
                        ),
                        state_dict.pop(
                            f"{prefix}bert.encoder.layer.{i}.attention.self.value.bias"
                        ),
                    )
                )
            )
            state_dict[
                f"{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight"
            ] = state_dict.pop(
                f"{prefix}bert.encoder.layer.{i}.attention.output.dense.weight"
            )
            state_dict[
                f"{prefix}bert.encoder.layer.{i}.attention.self.out_proj.bias"
            ] = state_dict.pop(
                f"{prefix}bert.encoder.layer.{i}.attention.output.dense.bias"
            )
            i += 1
    elif f"{prefix}bert.encoder.layer.0.attention.self.Wqkv.weight" in state_dict:
        i = 0
        while (
            f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight" in state_dict
        ):
            (
                state_dict[
                    f"{prefix}bert.encoder.layer.{i}.attention.self.query.weight"
                ],
                state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.key.weight"],
                state_dict[
                    f"{prefix}bert.encoder.layer.{i}.attention.self.value.weight"
                ],
            ) = torch.chunk(
                state_dict.pop(
                    f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight"
                ),
                chunks=3,
            )
            (
                state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.query.bias"],
                state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.key.bias"],
                state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.value.bias"],
            ) = torch.chunk(
                state_dict.pop(
                    f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias"
                ),
                chunks=3,
            )
            state_dict[
                f"{prefix}bert.encoder.layer.{i}.attention.output.dense.weight"
            ] = state_dict.pop(
                f"{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight"
            )
            state_dict[
                f"{prefix}bert.encoder.layer.{i}.attention.output.dense.bias"
            ] = state_dict.pop(
                f"{prefix}bert.encoder.layer.{i}.attention.self.out_proj.bias"
            )
            i += 1

    return state_dict
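
# Illustrative sketch (demo helper, never called at import time): convert_state_dict renames
# attention projection weights so a checkpoint saved with nn.MultiheadAttention's fused
# in_proj_* layout can be loaded into the FlashMHA Wqkv layout, and vice versa. The tiny
# fake state dict below is an assumption used purely for demonstration.
def _demo_convert_state_dict():
    sd = {
        "visual.transformer.resblocks.0.attn.in_proj_weight": torch.zeros(9, 3),
        "visual.transformer.resblocks.0.attn.in_proj_bias": torch.zeros(9),
    }
    converted = convert_state_dict(sd)
    assert "visual.transformer.resblocks.0.attn.Wqkv.weight" in converted
    assert "visual.transformer.resblocks.0.attn.Wqkv.bias" in converted
    return converted
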
state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.query.bias"], state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.key.bias"], state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.value.bias"], ) = torch.chunk( state_dict.pop( f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias" ), chunks=3, ) state_dict[ f"{prefix}bert.encoder.layer.{i}.attention.output.dense.weight" ] = state_dict.pop( f"{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight" ) state_dict[ f"{prefix}bert.encoder.layer.{i}.attention.output.dense.bias" ] = state_dict.pop( f"module.bert.encoder.layer.{i}.attention.self.out_proj.bias" ) i += 1 return state_dict # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) return parse to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) to_3tuple = _ntuple(3) to_4tuple = _ntuple(4) to_ntuple = lambda n, x: _ntuple(n)(x) def _convert_to_rgb(image): return image.convert("RGB") def image_transform(image_size=224): transform = Compose( [ Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), _convert_to_rgb, ToTensor(), Normalize( (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711), ), ] ) return transform image_preprocess = image_transform(224) def convert_models_to_fp32(model): for p in model.parameters(): p.data = p.data.float() if p.grad: p.grad.data = p.grad.data.float() class ImgClipEmbDetector: def __init__(self): # 用于预定于的多分类 self.image_preprocess = image_preprocess config = dict() config["vision_config"] = "ipadapter/model_configs/ViT-H-14.json" config["gpu"] = torch.device("cuda" if torch.cuda.is_available() else "cpu") config["text_config"] = ( "ipadapter/model_configs/RoBERTa-wwm-ext-large-cn-en.json" ) config["precision"] = "amp" config["context_length"] = 77 config["resume"] = "ckpts/t2i/clip_img_encoder/clip_img_encoder.pt" self.cfg = config self.model = self.build_model() def build_model(self): with open(self.cfg["vision_config"], "r") as fv, open( self.cfg["text_config"], "r" ) as ft: model_info = json.load(fv) if isinstance(model_info["vision_layers"], str): model_info["vision_layers"] = eval(model_info["vision_layers"]) for k, v in json.load(ft).items(): model_info[k] = v model = CLIP(**model_info) if self.cfg["precision"] == "fp16": convert_weights(model) # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 if self.cfg["precision"] == "amp" or self.cfg["precision"] == "fp32": convert_models_to_fp32(model) if self.cfg["precision"] == "fp16": convert_weights(model) checkpoint = torch.load(self.cfg["resume"], map_location="cpu") sd = checkpoint["state_dict"] if next(iter(sd.items()))[0].startswith("module"): sd = { k[len("module.") :]: v for k, v in sd.items() if "bert.pooler" not in k } model.load_state_dict(sd, strict=False) print( f"=> loaded checkpoint {self.cfg['resume']} (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)" ) model.eval() loc = self.cfg["gpu"] model.to(loc) self.cfg["device"] = loc return model def __call__(self, images): # texts: must be list-of-str with torch.no_grad(): image_features = self.model.encode_image(images) image_features /= image_features.norm(dim=-1, keepdim=True) return image_features