# Make sure we are able to load base models as well as derived models (with heads)
# NOTE(review): this is the tail of a checkpoint-loading routine whose header is
# not in view; `model`, `load`, `state_dict`, `error_msgs`, `unexpected_keys`,
# `missing_keys`, `mismatched_keys` and `pretrained_model_path` are defined
# earlier in the enclosing function. Whitespace was restored from a mangled paste.
start_prefix = ""
model_to_load = model
load(model_to_load, prefix=start_prefix)
# Release the (potentially large) raw state dict now that weights are copied in.
del state_dict

# Hard failure: loading reported errors (e.g. shape/size mismatches).
if len(error_msgs) > 0:
    error_msg = "\n\t".join(error_msgs)
    if "size mismatch" in error_msg:
        error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

# Soft diagnostics: checkpoint keys that the model did not consume.
if len(unexpected_keys) > 0:
    logging.warning(
        f"Some weights of the model checkpoint at {pretrained_model_path} were not used when"
        f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
        f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
        " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
        " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
        f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
        " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
    )
else:
    logging.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

# Model parameters that the checkpoint did not provide.
if len(missing_keys) > 0:
    logging.warning(
        f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
        f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably"
        " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
    )
elif len(mismatched_keys) == 0:
    logging.info(
        f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
        f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint"
        f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
        " training."
    )

# Keys whose tensors existed in both but with incompatible shapes.
if len(mismatched_keys) > 0:
    mismatched_warning = "\n".join(
        [
            f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
            for key, shape1, shape2 in mismatched_keys
        ]
    )
    logging.warning(
        f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
        f" {pretrained_model_path} and are newly initialized because the shapes did not"
        f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
        # NOTE(review): the source paste was truncated mid-statement here; the
        # closing text below is reconstructed from the upstream transformers message.
        " to use it for predictions and inference."
    )
error_str=f"Pretrained weights ({pretrained}) not found for model {model_name}."f"Available pretrained tags ({list_pretrained_tags_by_model(model_name)}."
logging.warning(error_str)
raiseRuntimeError(error_str)
else:
visual_checkpoint_path=""
text_checkpoint_path=""
ifpretrained_image:
pretrained_visual_model=pretrained_visual_model.replace("/","-")# for callers using old naming with / in ViT names
# NOTE(review): these look like field declarations of a vision-tower config
# dataclass (the enclosing `@dataclass class ...Cfg:` header is not in view —
# confirm against the original file). Whitespace restored; `str = None`
# annotations corrected to `Optional[str]` (defaults unchanged).
ls_init_value: Optional[float] = None  # layer scale initial value
patch_dropout: float = 0.0  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
drop_path_rate: Optional[float] = None  # drop path rate
timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
timm_pool: str = "avg"  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = "linear"  # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False  # enable bias final projection
eva_model_name: Optional[str] = None  # a valid eva model name overrides layers, width, patch_size
qkv_bias: bool = True
fusedLN: bool = False
xattn: bool = False
postnorm: bool = False
rope: bool = False
pt_hw_seq_len: int = 16  # 224/14
intp_freq: bool = False
naiveswiglu: bool = False
subln: bool = False
use_rms_norm: bool = False
@dataclass
class CLIPTextCfg:
    """Configuration for the CLIP text tower (transformer text encoder)."""

    context_length: int = 77  # maximum token sequence length
    vocab_size: int = 49408  # BPE vocabulary size
    width: int = 512  # transformer hidden dimension
    heads: int = 8  # attention heads per layer
    layers: int = 12  # number of transformer layers
    ls_init_value: Optional[float] = None  # layer scale initial value
# Disabled alternative (kept for reference): for "slicefour_*" feature-select
# modes, concatenate every k-th hidden layer (k = num_layers // 4) along the
# channel dimension instead of using a single layer's features.
# if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
#     select_every_k_layer = len(image_features) // 4
#     image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1)