Commit d5878167 by mashun1: llava-next
import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union, Dict, Any
import torch
try:
import deepspeed
except ImportError:
deepspeed = None
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict, get_cast_dtype
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
from .transform import image_transform
from .tokenizer import HFTokenizer, tokenize
from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary of model architecture configs, keyed by model name
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = (".json",)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f"*{ext}"))
for cf in config_files:
with open(cf, "r", encoding="utf8") as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
_rescan_model_configs()  # initial population of the model config registry
def list_models():
"""enumerate available model architectures based on config files"""
return list(_MODEL_CONFIGS.keys())
def add_model_config(path):
"""add model config path or file and update registry"""
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()
def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
def get_tokenizer(model_name):
config = get_model_config(model_name)
tokenizer = HFTokenizer(config["text_cfg"]["hf_tokenizer_name"]) if "hf_tokenizer_name" in config["text_cfg"] else tokenize
return tokenizer
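# Illustrative usage sketch (editor's addition, not part of the original commit): how the
# config registry and tokenizer lookup above might be exercised. The model name
# "EVA02-CLIP-B-16" is an assumption; any name returned by list_models() works.
def _example_registry_usage():
    available = list_models()                          # config names discovered under model_configs/
    print(f"{len(available)} model configs registered")
    name = "EVA02-CLIP-B-16"                           # hypothetical config name
    if name in available:
        cfg = get_model_config(name)                   # deep copy of the parsed JSON config
        tokenizer = get_tokenizer(name)                # HFTokenizer or the default `tokenize`
        tokens = tokenizer(["a photo of a cat"])       # LongTensor of token ids
        print(cfg["embed_dim"], tokens.shape)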
# loads OpenAI CLIP (TorchScript) weights when is_openai=True
def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []):
if is_openai:
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
state_dict = model.state_dict()
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
else:
checkpoint = torch.load(checkpoint_path, map_location=map_location)
for mk in model_key.split("|"):
if isinstance(checkpoint, dict) and mk in checkpoint:
state_dict = checkpoint[mk]
break
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith("module"):
state_dict = {k[7:]: v for k, v in state_dict.items()}
for k in skip_list:
if k in list(state_dict.keys()):
logging.info(f"Removing key {k} from pretrained checkpoint")
del state_dict[k]
if os.getenv("RoPE") == "1":
for k in list(state_dict.keys()):
if "freqs_cos" in k or "freqs_sin" in k:
del state_dict[k]
return state_dict
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
# detect old format and make compatible with new format
if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"):
state_dict = convert_to_custom_text_state_dict(state_dict)
if "text.logit_scale" in state_dict and hasattr(model, "logit_scale"):
state_dict["logit_scale"] = state_dict["text.logit_scale"]
del state_dict["text.logit_scale"]
# resize_clip_pos_embed for CLIP and open CLIP
if "visual.positional_embedding" in state_dict:
resize_clip_pos_embed(state_dict, model)
# specific to eva_vit_model
elif "visual.pos_embed" in state_dict:
resize_evaclip_pos_embed(state_dict, model)
# resize_clip_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
return incompatible_keys
def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if not k.startswith("visual."):
del state_dict[k]
for k in list(state_dict.keys()):
if k.startswith("visual."):
new_k = k[7:]
state_dict[new_k] = state_dict[k]
del state_dict[k]
return state_dict
def load_clip_text_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if k.startswith("visual."):
del state_dict[k]
return state_dict
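# Illustrative sketch (editor's addition): the two helpers above split a joint CLIP
# checkpoint into its visual or text half. The toy checkpoint written to a temporary
# file here is an assumption, purely for demonstration.
def _example_split_checkpoint():
    import tempfile
    toy = {
        "visual.conv1.weight": torch.zeros(1),
        "visual.ln_post.weight": torch.zeros(1),
        "text.token_embedding.weight": torch.zeros(1),
        "logit_scale": torch.zeros(1),
    }
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "toy.pt")
        torch.save(toy, path)
        visual_sd = load_clip_visual_state_dict(path)  # {"conv1.weight", "ln_post.weight"}: "visual." prefix stripped
        text_sd = load_clip_text_state_dict(path)      # all non-"visual." keys, unchanged
    print(sorted(visual_sd), sorted(text_sd))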
def get_pretrained_tag(pretrained_model):
pretrained_model = pretrained_model.lower()
if "laion" in pretrained_model or "open_clip" in pretrained_model:
return "open_clip"
elif "openai" in pretrained_model:
return "clip"
elif "eva" in pretrained_model and "clip" in pretrained_model:
return "eva_clip"
else:
return "other"
def load_zero_partitions(model, state_dict, is_deepspeed_zero3_enabled, pretrained_model_path, ignore_mismatched_sizes=False):
"""
adapted from PyTorch Lightning and HuggingFace Transformers

Example:
    with deepspeed.zero.Init():
        model = MyModel()
    state_dict = torch.load(model_path, map_location="cpu")
    load_zero_partitions(model, state_dict, is_deepspeed_zero3_enabled=True, pretrained_model_path=model_path)
"""
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
loaded_keys = list(state_dict.keys())
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
# Mismatched keys contain tuples (key, checkpoint_shape, model_shape) for weights in the checkpoint
# whose shape does not match the corresponding weight in the model.
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append((checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape))
del state_dict[checkpoint_key]
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
if is_deepspeed_zero3_enabled:
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
load(model_to_load, prefix=start_prefix)
del state_dict
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0:
logging.warning(
f"Some weights of the model checkpoint at {pretrained_model_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
logging.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logging.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
elif len(mismatched_keys) == 0:
logging.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
" training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join([f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys])
logging.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
def load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=True, visual_model=None, text_model=None, model_key="model|module|state_dict", skip_list=[]):
visual_tag = get_pretrained_tag(visual_model)
text_tag = get_pretrained_tag(text_model)
logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
visual_incompatible_keys, text_incompatible_keys = None, None
if visual_checkpoint_path:
if visual_tag == "eva_clip" or visual_tag == "open_clip":
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
elif visual_tag == "clip":
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
# resize_clip_pos_embed for CLIP and open CLIP
if "positional_embedding" in visual_state_dict:
resize_visual_pos_embed(visual_state_dict, model)
# specific to the EVA model
elif "pos_embed" in visual_state_dict:
resize_eva_pos_embed(visual_state_dict, model)
visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
if text_checkpoint_path:
if text_tag == "eva_clip" or text_tag == "open_clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
elif text_tag == "clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
text_state_dict = load_state_dict(text_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
return visual_incompatible_keys, text_incompatible_keys
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = "fp32",
device: Union[str, torch.device] = "cpu",
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = "",
pretrained_text: str = "",
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
):
model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names
if isinstance(device, str):
device = torch.device(device)
if pretrained and pretrained.lower() == "openai":
logging.info(f"Loading pretrained {model_name} from OpenAI.")
model = load_openai_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)
else:
model_cfg = get_model_config(model_name)
if model_cfg is not None:
logging.info(f"Loaded {model_name} model config.")
else:
logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
raise RuntimeError(f"Model config for {model_name} not found.")
if "rope" in model_cfg.get("vision_cfg", {}):
if model_cfg["vision_cfg"]["rope"]:
os.environ["RoPE"] = "1"
else:
os.environ["RoPE"] = "0"
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
cast_dtype = get_cast_dtype(precision)
custom_clip = model_cfg.pop("custom_text", False) or force_custom_clip or ("hf_model_name" in model_cfg["text_cfg"])
if custom_clip:
if "hf_model_name" in model_cfg.get("text_cfg", {}):
model_cfg["text_cfg"]["hf_model_pretrained"] = pretrained_hf
model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
pretrained_cfg = {}
if pretrained:
checkpoint_path = ""
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=False)
else:
error_str = f"Pretrained weights ({pretrained}) not found for model {model_name}." f"Available pretrained tags ({list_pretrained_tags_by_model(model_name)}."
logging.warning(error_str)
raise RuntimeError(error_str)
else:
visual_checkpoint_path = ""
text_checkpoint_path = ""
if pretrained_image:
pretrained_visual_model = pretrained_visual_model.replace("/", "-") # for callers using old naming with / in ViT names
pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
if "timm_model_name" in model_cfg.get("vision_cfg", {}):
# pretrained weight loading for timm models set via vision_cfg
model_cfg["vision_cfg"]["timm_model_pretrained"] = True
elif pretrained_image_cfg:
visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_image):
visual_checkpoint_path = pretrained_image
else:
logging.warning(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.")
raise RuntimeError(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.")
if pretrained_text:
pretrained_text_model = pretrained_text_model.replace("/", "-") # for callers using old naming with / in ViT names
pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
if pretrained_text_cfg:
text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_text):
text_checkpoint_path = pretrained_text
else:
logging.warning(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.")
raise RuntimeError(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.")
if visual_checkpoint_path:
logging.info(f"Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).")
if text_checkpoint_path:
logging.info(f"Loading pretrained {model_name}.text weights ({text_checkpoint_path}).")
if visual_checkpoint_path or text_checkpoint_path:
load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=False, visual_model=pretrained_visual_model, text_model=pretrained_text_model, model_key="model|module|state_dict", skip_list=skip_list)
if "fp16" in precision or "bf16" in precision:
logging.info(f"convert precision to {precision}")
model = model.to(torch.bfloat16) if "bf16" in precision else model.to(torch.float16)
# model.to(device=device)
# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get("mean", None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get("std", None) or OPENAI_DATASET_STD
if jit:
model = torch.jit.script(model)
return model
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = "fp32",
device: Union[str, torch.device] = "cpu",
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = "",
pretrained_text: str = "",
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
pretrained_image=pretrained_image,
pretrained_text=pretrained_text,
pretrained_hf=pretrained_hf,
pretrained_visual_model=pretrained_visual_model,
pretrained_text_model=pretrained_text_model,
cache_dir=cache_dir,
skip_list=skip_list,
)
image_mean = image_mean or getattr(model.visual, "image_mean", None)
image_std = image_std or getattr(model.visual, "image_std", None)
preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
return model, preprocess_train, preprocess_val
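# Illustrative usage sketch (editor's addition): a typical training-time call to the
# factory above. The config name and checkpoint path are assumptions; substitute any
# entry from list_models() and a real checkpoint file.
def _example_create_for_training():
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        "EVA02-CLIP-B-16",                             # hypothetical config name
        pretrained="/path/to/eva02_clip_b16.pt",       # hypothetical checkpoint path
        precision="fp16",
        device="cuda",
        force_patch_dropout=0.5,                       # override patch dropout for training
        skip_list=[],
    )
    return model, preprocess_train, preprocess_val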
def create_model_from_pretrained(
model_name: str,
pretrained: str,
precision: str = "fp32",
device: Union[str, torch.device] = "cpu",
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
return_transform: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
is_frozen: bool = False,
):
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
raise RuntimeError(f"{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}." f" Use open_clip.list_pretrained() to find one.")
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
cache_dir=cache_dir,
)
if is_frozen:
for param in model.parameters():
param.requires_grad = False
if not return_transform:
return model
image_mean = image_mean or getattr(model.visual, "image_mean", None)
image_std = image_std or getattr(model.visual, "image_std", None)
preprocess = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
return model, preprocess
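# Illustrative inference sketch (editor's addition): create_model_from_pretrained with
# is_frozen=True plus the matching tokenizer and transform. The config name, checkpoint
# path and image file are assumptions.
def _example_frozen_inference():
    from PIL import Image
    model, preprocess = create_model_from_pretrained(
        "EVA02-CLIP-B-16",                             # hypothetical config name
        pretrained="/path/to/eva02_clip_b16.pt",       # hypothetical checkpoint path
        is_frozen=True,                                # requires_grad=False on every parameter
    )
    tokenizer = get_tokenizer("EVA02-CLIP-B-16")
    image = preprocess(Image.open("cat.jpg")).unsqueeze(0)   # hypothetical image file
    text = tokenizer(["a cat", "a dog"])
    with torch.no_grad():
        image_features, text_features, logit_scale = model(image, text)
        probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)
    return probs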
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings",
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings",
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens",
},
"pooler": "mean_pooler",
},
"bert": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings",
},
"pooler": "mean_pooler",
},
}
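# Illustrative sketch (editor's addition): arch_dict maps this codebase's generic
# text-tower field names onto the per-architecture attribute names of a HuggingFace
# config. Requires `transformers`; the checkpoint id "xlm-roberta-base" is an assumption.
def _example_arch_dict_lookup():
    from transformers import AutoConfig
    cfg = AutoConfig.from_pretrained("xlm-roberta-base")
    names = arch_dict[cfg.model_type]["config_names"]
    width = getattr(cfg, names["width"])               # hidden_size
    layers = getattr(cfg, names["layers"])             # num_hidden_layers
    print(cfg.model_type, width, layers)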
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as the text tower in a CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
transformers = None
class BaseModelOutput:
pass
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower()
# TODO: ?last - for gpt-like models
_POOLERS = {}
def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
return cls
@register_pooler
class MeanPooler(nn.Module):
"""Mean pooling"""
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
@register_pooler
class MaxPooler(nn.Module):
"""Max pooling"""
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
masked_output = x.last_hidden_state.masked_fill((attention_mask == 0).unsqueeze(-1), -torch.inf)  # mask out padding positions
return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
"""CLS token pooling"""
def __init__(self, use_pooler_output=True):
super().__init__()
self.cls_token_position = 0
self.use_pooler_output = use_pooler_output
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
if self.use_pooler_output and isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and (x.pooler_output is not None):
return x.pooler_output
return x.last_hidden_state[:, self.cls_token_position, :]
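# Illustrative sketch (editor's addition, addressing the TODO above): the register_pooler
# decorator exposes new poolers under their snake_case class name, so this one would be
# selectable as pooler_type="last_token_pooler" in a text_cfg. It is an assumption, not
# part of the original commit.
@register_pooler
class LastTokenPooler(nn.Module):
    """Pools the hidden state of the last non-padded token of each sequence"""
    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        last_idx = attention_mask.long().sum(dim=1) - 1                     # index of last real token per sequence
        batch_idx = torch.arange(last_idx.shape[0], device=last_idx.device)
        return x.last_hidden_state[batch_idx, last_idx]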
class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""
def __init__(self, model_name_or_path: str, output_dim: int, tokenizer_name: str = None, config: PretrainedConfig = None, pooler_type: str = None, proj: str = None, pretrained: bool = True, masked_language_modeling: bool = False):
super().__init__()
self.output_dim = output_dim
# TODO: find better way to get this information
uses_transformer_pooler = pooler_type == "cls_pooler"
if transformers is None:
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
if config is None:
self.config = AutoConfig.from_pretrained(model_name_or_path)
if masked_language_modeling:
create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (AutoModelForMaskedLM.from_config, self.config)
else:
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (AutoModel.from_config, self.config)
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
else:
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
self.config = config
if masked_language_modeling:
self.transformer = AutoModelForMaskedLM.from_config(config)
else:
self.transformer = AutoModel.from_config(config)
if pooler_type is None: # get default arch pooler
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
else:
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
self.proj = nn.Identity()
elif proj == "linear":
self.proj = nn.Linear(d_model, output_dim, bias=False)
elif proj == "mlp":
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=False),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=False),
)
# self.itm_proj = nn.Linear(d_model, 2, bias=False)
# self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
# image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
# attn_mask = (x != self.config.pad_token_id).long()
# out = self.transformer(
# input_ids=x,
# attention_mask=attn_mask,
# encoder_hidden_states = image_embeds,
# encoder_attention_mask = image_atts,
# )
# pooled_out = self.pooler(out, attn_mask)
# return self.itm_proj(pooled_out)
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
if masked_indices is None:
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
if targets is not None:
targets[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
input_ids[indices_replaced] = self.tokenizer.mask_token_id
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
input_ids[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
if targets is not None:
return input_ids, targets
else:
return input_ids
def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
labels = input_ids.clone()
attn_mask = (input_ids != self.config.pad_token_id).long()
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device)
vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
probability_matrix = torch.full(labels.shape, mlm_probability)
input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix)
mlm_output = self.transformer(
input_ids,
attention_mask=attn_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
labels=labels,
)
return mlm_output.loss
# mlm_output = self.transformer(input_ids,
# attention_mask = attn_mask,
# encoder_hidden_states = image_embeds,
# encoder_attention_mask = image_atts,
# return_dict = True,
# ).last_hidden_state
# logits = self.mlm_proj(mlm_output)
# # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
# logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
# labels = labels[:, 1:].contiguous().view(-1)
# mlm_loss = F.cross_entropy(
# logits,
# labels,
# # label_smoothing=0.1,
# )
# return mlm_loss
def forward(self, x: TensorType) -> TensorType:
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
return self.proj(pooled_out)
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
if not unlocked_layers: # full freezing
for n, p in self.transformer.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
return
encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
embeddings = getattr(self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
modules = [embeddings, *layer_list][:-unlocked_layers]
# freeze layers
for module in modules:
for n, p in module.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.gradient_checkpointing_enable()
def get_num_layers(self):
encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
return len(layer_list)
def init_parameters(self):
pass
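# Illustrative usage sketch (editor's addition): building the adapter around a HuggingFace
# checkpoint and partially locking it. The model/tokenizer name "xlm-roberta-base" is an
# assumption, and downloading its weights requires network access.
def _example_hf_text_encoder():
    encoder = HFTextEncoder(
        "xlm-roberta-base",                            # hypothetical HF checkpoint
        output_dim=512,
        tokenizer_name="xlm-roberta-base",
        pooler_type="mean_pooler",
        proj="mlp",
        pretrained=True,
    )
    encoder.lock(unlocked_layers=2)                    # freeze everything except the last 2 layers
    batch = encoder.tokenizer(["a photo of a cat"], return_tensors="pt", padding=True)
    features = encoder(batch["input_ids"])             # [1, 512]
    return features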
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from timm.loss import LabelSmoothingCrossEntropy
def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False):
assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support."
if use_horovod:
assert hvd is not None, "Please install horovod"
if gather_with_grad:
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
else:
with torch.no_grad():
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
else:
# We gather tensors from all gpus
if gather_with_grad:
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
# all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
# all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
else:
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
dist.all_gather(gathered_image_features, image_features)
dist.all_gather(gathered_text_features, text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
return all_image_features, all_text_features
class ClipLoss(nn.Module):
def __init__(
self,
local_loss=False,
gather_with_grad=False,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
smoothing=0.0,
):
super().__init__()
self.local_loss = local_loss
self.gather_with_grad = gather_with_grad
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
# cache state
self.prev_num_logits = 0
self.labels = {}
def forward(self, image_features, text_features, logit_scale=1.0):
device = image_features.device
if self.world_size > 1:
all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
else:
logits_per_image = logit_scale * all_image_features @ all_text_features.T
logits_per_text = logits_per_image.T
else:
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T
# calculate ground-truth labels and cache them if enabled
num_logits = logits_per_image.shape[0]
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1 and self.local_loss:
labels = labels + num_logits * self.rank
if self.cache_labels:
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
if self.label_smoothing_cross_entropy:
total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2
else:
total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2
acc = None
i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
acc = {"i2t": i2t_acc, "t2i": t2i_acc}
return total_loss, acc
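# Illustrative sketch (editor's addition): single-process use of ClipLoss on random,
# L2-normalized features. The batch size, feature dim and logit scale value are assumptions.
def _example_clip_loss():
    batch, dim = 8, 512
    image_features = F.normalize(torch.randn(batch, dim), dim=-1)
    text_features = F.normalize(torch.randn(batch, dim), dim=-1)
    loss_fn = ClipLoss(local_loss=False, world_size=1)
    logit_scale = torch.tensor(100.0)                  # roughly exp(logit_scale) of a trained CLIP
    total_loss, acc = loss_fn(image_features, text_features, logit_scale)
    print(total_loss.item(), acc["i2t"].item(), acc["t2i"].item())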
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
try:
from .hf_model import HFTextEncoder
except ImportError:
HFTextEncoder = None
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .eva_vit_model import EVAVisionTransformer
from .transformer import LayerNorm, LayerNormFp32, QuickGELU, Attention, VisionTransformer, TextTransformer
try:
from apex.normalization import FusedLayerNorm
except ImportError:
FusedLayerNorm = LayerNorm
# print("Please 'pip install apex'")
try:
import xformers.ops as xops
except ImportError:
xops = None
# print("Please 'pip install xformers'")
class RMSnorm(nn.Module):
"""
adapted from transformers' T5LayerNorm
"""
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization (https://arxiv.org/abs/1910.07467), thus the variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
# half-precision inputs is done in fp32
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
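# Illustrative numerical check (editor's addition): RMSnorm only rescales each hidden
# vector by its root mean square; unlike LayerNorm it subtracts no mean and adds no bias.
# The tensor shape is arbitrary.
def _example_rmsnorm():
    norm = RMSnorm(hidden_size=8)
    x = torch.randn(2, 4, 8)
    expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(norm(x), expected, atol=1e-6)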
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
drop_path_rate: Optional[float] = None # drop path rate
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
qkv_bias: bool = True
fusedLN: bool = False
xattn: bool = False
postnorm: bool = False
rope: bool = False
pt_hw_seq_len: int = 16 # 224/14
intp_freq: bool = False
naiveswiglu: bool = False
subln: bool = False
use_rms_norm: bool = False
@dataclass
class CLIPTextCfg:
context_length: int = 77
vocab_size: int = 49408
width: int = 512
heads: int = 8
layers: int = 12
ls_init_value: Optional[float] = None # layer scale initial value
hf_model_name: str = None
hf_tokenizer_name: str = None
hf_model_pretrained: bool = True
proj: str = "mlp"
pooler_type: str = "mean_pooler"
masked_language_modeling: bool = False
fusedLN: bool = False
xattn: bool = False
attn_mask: bool = True
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == "bf16":
cast_dtype = torch.bfloat16
elif precision == "fp16":
cast_dtype = torch.float16
return cast_dtype
def _build_vision_tower(embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.eva_model_name:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = RMSnorm if vision_cfg.use_rms_norm else LayerNorm
visual = EVAVisionTransformer(
img_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
num_classes=embed_dim,
use_mean_pooling=vision_cfg.global_average_pool, # False
init_values=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
embed_dim=vision_cfg.width,
depth=vision_cfg.layers,
num_heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
qkv_bias=vision_cfg.qkv_bias,
drop_path_rate=vision_cfg.drop_path_rate,
norm_layer=partial(norm_layer, eps=1e-6),
xattn=vision_cfg.xattn,
rope=vision_cfg.rope,
postnorm=vision_cfg.postnorm,
pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
intp_freq=vision_cfg.intp_freq,
naiveswiglu=vision_cfg.naiveswiglu,
subln=vision_cfg.subln,
)
elif vision_cfg.timm_model_name:
visual = TimmModel(
vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, embed_dim=embed_dim, image_size=vision_cfg.image_size
)
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width)
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
global_average_pool=vision_cfg.global_average_pool,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return visual
def _build_text_tower(
embed_dim: int,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
if text_cfg.hf_model_name:
text = HFTextEncoder(text_cfg.hf_model_name, output_dim=embed_dim, tokenizer_name=text_cfg.hf_tokenizer_name, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, masked_language_modeling=text_cfg.masked_language_modeling)
else:
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNorm
text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
width=text_cfg.width,
heads=text_cfg.heads,
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=FusedLayerNorm if text_cfg.fusedLN else norm_layer,
xattn=text_cfg.xattn,
attn_mask=text_cfg.attn_mask,
)
return text
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
return {"logit_scale"}
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
return image_features, text_features, self.logit_scale.exp()
class CustomCLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
itm_task: bool = False,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
self.text.lock(unlocked_layers, freeze_layer_norm)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
@torch.jit.ignore
def no_weight_decay(self):
return {"logit_scale"}
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
def _convert_weights(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.to(dtype)
if l.bias is not None:
l.bias.data = l.bias.data.to(dtype)
if isinstance(l, (nn.MultiheadAttention, Attention)):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr, None)
if tensor is not None:
tensor.data = tensor.data.to(dtype)
if isinstance(l, nn.Parameter):
l.data = l.data.to(dtype)
for name in ["text_projection", "proj"]:
    # convert nn.Parameter projection attributes (e.g. text_projection on CLIP, proj on VisionTransformer)
    attr = getattr(l, name, None)
    if isinstance(attr, nn.Parameter):
        attr.data = attr.data.to(dtype)
model.apply(_convert_weights)
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
if "text_projection" in state_dict:
# old format state_dict, move text tower -> .text
new_state_dict = {}
for k, v in state_dict.items():
if any(k.startswith(p) for p in ("text_projection", "positional_embedding", "token_embedding", "transformer", "ln_final", "logit_scale")):
k = "text." + k
new_state_dict[k] = v
return new_state_dict
return state_dict
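# Illustrative sketch (editor's addition): how the converter above re-prefixes an "old
# format" single-tower text state dict for the custom-text models. The toy keys are
# assumptions; only the prefixes matter.
def _example_convert_to_custom_text_state_dict():
    old = {
        "text_projection": torch.zeros(1),
        "token_embedding.weight": torch.zeros(1),
        "ln_final.weight": torch.zeros(1),
        "visual.conv1.weight": torch.zeros(1),         # visual keys are left untouched
    }
    new = convert_to_custom_text_state_dict(old)
    assert "text.text_projection" in new and "visual.conv1.weight" in new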
def build_model_from_openai_state_dict(
state_dict: dict,
quick_gelu=True,
cast_dtype=torch.float16,
):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_size = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width**2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_size = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
vision_cfg = CLIPVisionCfg(
layers=vision_layers,
width=vision_width,
patch_size=vision_patch_size,
image_size=image_size,
)
text_cfg = CLIPTextCfg(context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers)
model = CLIP(
embed_dim,
vision_cfg=vision_cfg,
text_cfg=text_cfg,
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
cast_dtype=cast_dtype,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
model.load_state_dict(state_dict)
return model.eval()
def trace_model(model, batch_size=256, device=torch.device("cpu")):
model.eval()
image_size = model.visual.image_size
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
model = torch.jit.trace_module(model, inputs=dict(forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,)))
model.visual.image_size = image_size
return model
{
"embed_dim": 1536,
"vision_cfg": {
"image_size": 224,
"layers": 48,
"width": 5120,
"head_width": 128,
"mlp_ratio": 5,
"patch_size": 14,
"eva_model_name": "eva-clip-18b-14-x",
"drop_path_rate": 0,
"qkv_bias": false,
"xattn": true,
"postnorm": true,
"fusedLN": false,
"use_rms_norm": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": false
}
}
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 448,
"layers": 32,
"width": 4096,
"head_width": 128,
"mlp_ratio": 5,
"patch_size": 14,
"eva_model_name": "eva-clip-8b-14-plus-x",
"drop_path_rate": 0,
"qkv_bias": false,
"xattn": true,
"postnorm": false,
"fusedLN": false,
"use_rms_norm": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": false
}
}
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 4096,
"head_width": 128,
"mlp_ratio": 5,
"patch_size": 14,
"eva_model_name": "eva-clip-8b-14-x",
"drop_path_rate": 0,
"qkv_bias": false,
"xattn": true,
"postnorm": false,
"fusedLN": false,
"use_rms_norm": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": false
}
}
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 16,
"eva_model_name": "eva-clip-b-16",
"ls_init_value": 0.1,
"drop_path_rate": 0.0
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14,
"eva_model_name": "eva-clip-g-14-x",
"drop_path_rate": 0,
"xattn": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14,
"eva_model_name": "eva-clip-g-14-x",
"drop_path_rate": 0.4,
"xattn": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"head_width": 64,
"patch_size": 16,
"mlp_ratio": 2.6667,
"eva_model_name": "eva-clip-b-16-X",
"drop_path_rate": 0.0,
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"xattn": true,
"fusedLN": true
}
}
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"drop_path_rate": 0,
"head_width": 64,
"mlp_ratio": 2.6667,
"patch_size": 14,
"eva_model_name": "eva-clip-l-14-336",
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"drop_path_rate": 0,
"head_width": 64,
"mlp_ratio": 2.6667,
"patch_size": 14,
"eva_model_name": "eva-clip-l-14",
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 64,
"width": 1792,
"head_width": 112,
"mlp_ratio": 8.571428571428571,
"patch_size": 14,
"eva_model_name": "eva-clip-4b-14-x",
"drop_path_rate": 0,
"xattn": true,
"postnorm": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 64,
"width": 1792,
"head_width": 112,
"mlp_ratio": 8.571428571428571,
"patch_size": 14,
"eva_model_name": "eva-clip-4b-14-x",
"drop_path_rate": 0,
"xattn": true,
"postnorm": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 448,
"layers": 77,
"width": 2304,
"head_width": 144,
"mlp_ratio": 10.9722,
"patch_size": 14,
"eva_model_name": "eva-clip-10b-14-x",
"drop_path_rate": 0,
"xattn": true,
"postnorm": false,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 77,
"width": 2304,
"head_width": 144,
"mlp_ratio": 10.9722,
"patch_size": 14,
"eva_model_name": "eva-clip-10b-14-x",
"drop_path_rate": 0,
"xattn": true,
"postnorm": false,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": true
}
}
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
from .utils import freeze_batch_norm_2d
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.act2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.act3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))]))
def forward(self, x: torch.Tensor):
identity = x
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.act3(out)
return out
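# Sketch (not part of the original file): a shape check for Bottleneck under assumed
# ResNet-style inputs. With stride=2 the avgpool halves H and W, and the output
# channel count is planes * Bottleneck.expansion.
def _bottleneck_shape_demo():
    block = Bottleneck(inplanes=64, planes=64, stride=2)
    x = torch.randn(1, 64, 56, 56)
    out = block(x)  # expected shape: (1, 256, 28, 28)
    return out.shape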
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0.0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x[0]
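# Sketch (not part of the original file): AttentionPool2d collapses an NCHW feature
# map into one pooled token per image by attending from the mean token over HW+1
# positions; the sizes below assume a 224px RN50-style final stage.
def _attnpool_shape_demo():
    pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
    feats = torch.randn(2, 2048, 7, 7)  # NCHW features from the last ResNet stage
    pooled = pool(feats)                # expected shape: (2, 1024)
    return pooled.shape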
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, "partial locking not currently supported for this model"
for param in self.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def stem(self, x):
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
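# Sketch (not part of the original file): wiring ModifiedResNet as an RN50-style
# image tower. layers=[3, 4, 6, 3] and width=64 are assumptions matching the classic
# ResNet-50 layout; the pre-attnpool feature width is width * 32 = 2048.
def _modified_resnet_demo():
    visual = ModifiedResNet(layers=[3, 4, 6, 3], output_dim=1024, heads=32,
                            image_size=224, width=64)
    images = torch.randn(2, 3, 224, 224)
    embeddings = visual(images)  # expected shape: (2, 1024)
    return embeddings.shape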
""" OpenAI pretrained model functions
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
import warnings
from typing import List, Optional, Union
import torch
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
__all__ = ["list_openai_models", "load_openai_model"]
def list_openai_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list_pretrained_models_by_tag("openai")
def load_openai_model(
name: str,
precision: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
jit: bool = True,
cache_dir: Optional[str] = None,
):
"""Load a CLIP model
Parameters
----------
name : str
        A model name listed by `list_openai_models()`, or the path to a model checkpoint containing the state_dict
precision: str
        Model precision; if None, defaults to 'fp32' when device == 'cpu', else 'fp16'.
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
        Whether to load the optimized JIT model (default) or the more hackable non-JIT model.
cache_dir : Optional[str]
The directory to cache the downloaded model weights
    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    """
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if precision is None:
precision = "fp32" if device == "cpu" else "fp16"
if get_pretrained_url(name, "openai"):
model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir)
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(model_path, map_location="cpu")
if not jit:
# Build a non-jit model from the OpenAI jitted model state dict
cast_dtype = get_cast_dtype(precision)
try:
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
except KeyError:
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
model = model.to(device)
if precision.startswith("amp") or precision == "fp32":
model.float()
elif precision == "bf16":
convert_weights_to_lp(model, dtype=torch.bfloat16)
return model
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 (typically for CPU)
if precision == "fp32":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to(); 5 is the ScalarType enum value for torch.float16 (Half)
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
# ensure image_size attr available at consistent location for both jit and non-jit
model.visual.image_size = model.input_resolution.item()
return model
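# Usage sketch (not part of the original file): the "ViT-B-32" tag and the
# non-JIT, fp32-on-CPU settings are illustrative assumptions, not defaults.
if __name__ == "__main__":
    print(list_openai_models())
    model = load_openai_model("ViT-B-32", precision="fp32", device="cpu", jit=False)
    print(model.visual.image_size)  # available on both the JIT and non-JIT paths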