Commit 6f43e8fa authored by mashun1's avatar mashun1
Browse files

open_clip

parents
Pipeline #1689 canceled with stages
pytest-split==0.8.0
pytest==7.2.0
transformers[sentencepiece]
# timm>=1.0.7
# torch>=1.9.0
# torchvision
webdataset>=0.2.5
regex
ftfy
tqdm
pandas
braceexpand
huggingface_hub
transformers[sentencepiece]
# timm>=1.0.7
fsspec
# eval on a single gpu
CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m open_clip_train.main \
--model ViT-L-16-CL32-GAP \
--pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \
--seed 0 \
--imagenet-val '/path/to/ImageNet/val'
\ No newline at end of file
CUDA_VISIBLE_DEVICES=1 python3 -m open_clip_train.main \
--model ViT-H-14-CL32-GAP-BigVision \
--pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \
--force-image-size 336 \
--square-resize-only \
--interpolation 'bilinear' \
--image-mean 0.485 0.456 0.406 \
--image-std 0.229 0.224 0.225 \
--seed 0 \
--imagenet-val '/path/to/ImageNet/val'
#!/bin/bash
export HIP_VISIBLE_DEVICES=4
csv_path="datasets/train2014.csv"
python -m open_clip_train.main \
--dataset-type "csv" \
--train-data ${csv_path} \
--warmup 1000 \
--batch-size 32 \
--lr 1e-5 \
--wd 0.1 \
--epochs 1 \
--workers 3 \
--model "coca_ViT-L-14" \
--coca-contrastive-loss-weight 0 \
--coca-caption-loss-weight 1 \
--log-every-n-steps 100
\ No newline at end of file
# 64k batchsize for 2.048e-3 lr
TORCH_CUDNN_V8_API_ENABLED=1 torchrun --nproc_per_node 8 -m open_clip_train.main \
--save-frequency 1 \
--save-most-recent \
--zeroshot-frequency 1 \
--train-data '/path/to/laion' \
--dataset-type webdataset \
--lr "2.048e-3" \
--beta1 0.9 \
--beta2 0.95 \
--warmup 782 \
--wd 0.2 \
--batch-size 4096 \
--aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
--epochs=7 \
--workers=6 \
--model ViT-H-14-CL32-GAP \
--precision 'amp_bf16' \
--local-loss \
--gather-with-grad \
--force-image-size 224 \
--grad-checkpointing \
--log-every-n-steps 32 \
--seed 0 \
--logs ./logs/ \
--imagenet-val '/path/to/ImageNet/val' \
--name 'name' \
--report-to "wandb" \
--wandb-project-name "project_name"
# 64k batchsize for 2.048e-3 lr
TORCH_CUDNN_V8_API_ENABLED=1 torchrun --nproc_per_node 8 -m open_clip_train.main \
--save-frequency 1 \
--save-most-recent \
--zeroshot-frequency 1 \
--train-data '/path/to/laion' \
--dataset-type webdataset \
--lr "2.048e-3" \
--beta1 0.9 \
--beta2 0.95 \
--warmup 782 \
--wd 0.2 \
--batch-size 4096 \
--aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \
--epochs=7 \
--workers=6 \
--model ViT-H-14-CL8-SyntaxMask-GAP \
--precision 'amp_bf16' \
--local-loss \
--gather-with-grad \
--force-image-size 84 \
--grad-checkpointing \
--log-every-n-steps 32 \
--seed 0 \
--logs ./logs/ \
--imagenet-val '/path/to/ImageNet/val' \
--name 'name' \
--report-to "wandb" \
--wandb-project-name "project_name"
#!/bin/bash
export HIP_VISIBLE_DEVICES=1
torchrun --nproc_per_node 4 -m open_clip_train.main \
--train-data '/home/datasets/clip/mscoco/{00000..00059}.tar' \
--train-num-samples 100000 \
--dataset-type webdataset \
--batch-size 320 \
--precision amp_bf16 \
--workers 4 \
--imagenet-val "/home/datasets/imagenet2012/val"
\ No newline at end of file
from .version import __version__
from .coca_model import CoCa
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
from .tokenizer import SimpleTokenizer, tokenize, decode
from .transform import image_transform, AugmentationCfg
from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass
from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
try:
from transformers import (
BeamSearchScorer,
LogitsProcessorList,
TopPLogitsWarper,
TopKLogitsWarper,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MaxLengthCriteria,
StopStringCriteria,
EosTokenCriteria,
StoppingCriteriaList
)
GENERATION_TYPES = {
"top_k": TopKLogitsWarper,
"top_p": TopPLogitsWarper,
"beam_search": "beam_search"
}
_has_transformers = True
except ImportError as e:
GENERATION_TYPES = {
"top_k": None,
"top_p": None,
"beam_search": "beam_search"
}
_has_transformers = False
@dataclass
class MultimodalCfg(CLIPTextCfg):
mlp_ratio: int = 4
dim_head: int = 64
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8
def _build_text_decoder_tower(
embed_dim,
multimodal_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)
decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
layers=multimodal_cfg.layers,
ls_init_value=multimodal_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return decoder
def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
if not isinstance(token_id, torch.Tensor):
if isinstance(token_id, int):
token_id = [token_id]
token_id = torch.tensor(token_id, device=device)
return token_id
class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
super().__init__()
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
self.text = _build_text_tower(
embed_dim=embed_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
vocab_size = (
text_cfg.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else text_cfg.vocab_size
)
self.visual = _build_vision_tower(
embed_dim=embed_dim,
vision_cfg=vision_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.text_decoder = _build_text_decoder_tower(
vocab_size,
multimodal_cfg=multimodal_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
else:
self.logit_bias = None
self.pad_id = pad_id
self.context_length = multimodal_cfg.context_length
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
self.text_decoder.set_grad_checkpointing(enable)
def _encode_image(self, images, normalize: bool = True):
image_latent, tokens_embs = self.visual(images)
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs
def _encode_text(self, text, normalize: bool = True):
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb
def encode_image(self, images, normalize: bool = True):
image_latent, _ = self._encode_image(images, normalize=normalize)
return image_latent
def encode_text(self, text, normalize: bool = True):
text_latent, _ = self._encode_text(text, normalize=normalize)
return text_latent
def forward(
self,
image,
text: Optional[torch.Tensor] = None,
image_latent: Optional[torch.Tensor] = None,
image_embs: Optional[torch.Tensor] = None,
output_labels: bool = True,
):
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)
if text is None:
return {"image_features": image_latent, "image_embs": image_embs}
text_latent, token_embs = self._encode_text(text)
# FIXME this isn't an ideal solution, would like to improve -RW
labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
if output_labels:
# align text_embs and thus logits with labels for teacher-forcing caption loss
token_embs = token_embs[:, :-1]
logits = self.text_decoder(image_embs, token_embs)
out_dict = {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"logit_scale": self.logit_scale.exp()
}
if labels is not None:
out_dict["labels"] = labels
if self.logit_bias is not None:
out_dict["logit_bias"] = self.logit_bias
return out_dict
def generate(
self,
image,
text=None,
seq_len=30,
max_seq_len=77,
temperature=1.,
generation_type="beam_search",
top_p=0.1, # keep tokens in the 1 - top_p quantile
top_k=1, # keeps the top_k most probable tokens
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
repetition_penalty=1.0,
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
):
# taking many ideas and components from HuggingFace GenerationMixin
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
device = image.device
with torch.no_grad():
sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
logit_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
]
)
if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
stopping_criteria = StoppingCriteriaList(stopping_criteria)
if generation_type == "beam_search":
output = self._generate_beamsearch(
image_inputs=image,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
sot_token_id=sot_token_id,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
min_seq_len=min_seq_len,
stopping_criteria=stopping_criteria,
logit_processor=logit_processor,
)
if fixed_output_length and output.shape[1] < seq_len:
pad_len = seq_len - output.shape[1]
return torch.cat((
output,
torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
),
dim=1
)
return output
elif generation_type == "top_p":
logit_warper = GENERATION_TYPES[generation_type](top_p)
elif generation_type == "top_k":
logit_warper = GENERATION_TYPES[generation_type](top_k)
else:
raise ValueError(
f"generation_type has to be one of "
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
)
image_latent, image_embs = self._encode_image(image)
if text is None:
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
was_training = self.training
num_dims = len(text.shape)
if num_dims == 1:
text = text[None, :]
self.eval()
out = text
while True:
x = out[:, -max_seq_len:]
cur_len = x.shape[1]
logits = self(
image,
x,
image_latent=image_latent,
image_embs=image_embs,
output_labels=False,
)["logits"][:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
if mask.all():
if not fixed_output_length:
break
else:
logits = logits[~mask, :]
filtered_logits = logit_processor(x[~mask, :], logits)
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
probs = F.softmax(filtered_logits / temperature, dim=-1)
if (cur_len + 1 == seq_len):
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
else:
sample[~mask, :] = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
cur_len += 1
if all(stopping_criteria(out, None)):
break
if num_dims == 1:
out = out.squeeze(0)
self.train(was_training)
return out
def _generate_beamsearch(
self,
image_inputs,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
logit_processor=None,
logit_warper=None,
):
device = image_inputs.device
batch_size = image_inputs.shape[0]
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
image_latent, image_embs = self._encode_image(image_inputs)
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
input_ids = input_ids * sot_token_id
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=device,
num_beam_groups=num_beam_groups,
)
# instantiate logits processors
logits_processor = (
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
if logit_processor is None
else logit_processor
)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
batch_beam_size, cur_len = input_ids.shape
beam_indices = None
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
# the same group don't produce same tokens everytime.
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))
while True:
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
# do one decoder step on all beams of all sentences in batch
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
outputs = self(
model_inputs['images'],
model_inputs['text'],
image_latent=image_latent,
image_embs=image_embs,
output_labels=False,
)
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx
# indices of beams of current group among all sentences in batch
batch_group_indices = []
for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]
# select outputs of beams of currentg group only
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
vocab_size = next_token_logits.shape[-1]
next_token_scores_processed = logits_processor(
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size
# stateless
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
group_index=beam_group_idx,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]
# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
)
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
break
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
return sequence_outputs['sequences']
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
return {
"text": input_ids,
"images": image_inputs,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
""" Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
"""
from typing import Union
import torch
import numpy as np
from .model import CLIP, CustomTextCLIP
from .transformer import TextTransformer, Transformer
@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
""" Load weights from .npz checkpoints for official Google big_vision image-text models
Currently the SigLIP source models are supported and a CustomTextCLIP destination model
w/ timm image encoder.
"""
from timm.layers import resample_patch_embed, resample_abs_pos_embed
def _n2p(w, t=True):
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)
w = np.load(checkpoint_path)
interpolation = 'bilinear'
antialias = False
def _convert_timm_img(module, prefix):
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
embed_conv_w = resample_patch_embed(
embed_conv_w,
module.patch_embed.proj.weight.shape[-2:],
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
module.patch_embed.proj.weight.copy_(embed_conv_w)
module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
if module.cls_token is not None:
module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
if pos_embed_w.shape != module.pos_embed.shape:
assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w,
new_size=module.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
module.pos_embed.copy_(pos_embed_w)
mha_sub, b_sub, ln1_sub = (0, 0, 1)
for i, block in enumerate(module.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
if module.attn_pool is not None:
block_prefix = f'{prefix}MAPHead_0/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
module.attn_pool.kv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
module.attn_pool.kv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
for r in range(2):
getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
def _convert_openclip_transformer(module: Transformer, prefix):
for i, block in enumerate(module.resblocks.children()):
block_prefix = f'{prefix}encoderblock_{i}/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.in_proj_weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.in_proj_bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))
def _convert_openclip_txt(module: TextTransformer, prefix):
module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
module.positional_embedding.copy_(pos_embed_w)
_convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
_convert_timm_img(model.visual.trunk, 'params/img/')
_convert_openclip_txt(model.text, 'params/txt/')
model.logit_bias.copy_(_n2p(w['params/b'])[0])
model.logit_scale.copy_(_n2p(w['params/t'])[0])
@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
def _convert_timm_img(state_dict):
if fastvit:
from timm.models.fastvit import checkpoint_filter_fn
else:
from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
return timm_state_dict
def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
text_dict = {}
for k, v in state_dict.items():
if not k.startswith(prefix):
continue
k = k.replace(prefix, '')
k = k.replace('projection_layer', 'text_projection')
k = k.replace('embedding_layer', 'token_embedding')
if k.startswith('positional_embedding.pos_embed.pos_embed'):
k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
v = v.squeeze()
k = k.replace('final_layer_norm', 'ln_final')
k = k.replace('pre_norm_mha.0', 'ln_1')
k = k.replace('pre_norm_mha.1', 'attn')
k = k.replace('pre_norm_ffn.0', 'ln_2')
k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
k = k.replace('qkv_proj.weight', 'in_proj_weight')
k = k.replace('qkv_proj.bias', 'in_proj_bias')
k = k.replace('transformer.', 'transformer.resblocks.')
text_dict['text.' + k] = v
return text_dict
image_dict = _convert_timm_img(state_dict)
text_dict = _convert_openclip_txt(state_dict)
out_dict = {**image_dict, **text_dict}
out_dict['logit_scale'] = state_dict['logit_scale']
return out_dict
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
# Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
state_dict = convert_mobile_clip_state_dict(model, state_dict)
if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
# convert b model
state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
return state_dict
import json
import logging
import os
import re
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .convert import convert_state_dict
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = ('.json',)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f'*{ext}'))
for cf in config_files:
with open(cf, 'r') as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
_rescan_model_configs() # initial populate of model config registry
def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())
def add_model_config(path):
""" add model config path or file and update registry """
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()
def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
def _get_hf_config(model_id, cache_dir=None):
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
def get_tokenizer(
model_name: str = '',
context_length: Optional[int] = None,
**kwargs,
):
if model_name.startswith(HF_HUB_PREFIX):
model_name = model_name[len(HF_HUB_PREFIX):]
try:
config = _get_hf_config(model_name)['model_cfg']
except Exception:
tokenizer = HFTokenizer(
model_name,
context_length=context_length or DEFAULT_CONTEXT_LENGTH,
**kwargs,
)
return tokenizer
else:
config = get_model_config(model_name)
assert config is not None, f"No valid model config found for {model_name}."
text_config = config.get('text_cfg', {})
if 'tokenizer_kwargs' in text_config:
tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
else:
tokenizer_kwargs = kwargs
if context_length is None:
context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
if 'hf_tokenizer_name' in text_config:
tokenizer = HFTokenizer(
text_config['hf_tokenizer_name'],
context_length=context_length,
**tokenizer_kwargs,
)
else:
tokenizer = SimpleTokenizer(
context_length=context_length,
**tokenizer_kwargs,
)
return tokenizer
def load_state_dict(checkpoint_path: str, map_location='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif isinstance(checkpoint, torch.jit.ScriptModule):
state_dict = checkpoint.state_dict()
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith('module'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
return state_dict
def load_checkpoint(
model: Union[CLIP, CustomTextCLIP],
checkpoint_path: str,
strict: bool = True,
):
if Path(checkpoint_path).suffix in ('.npz', '.npy'):
# Separate path loading numpy big_vision (SigLIP) weights
from open_clip.convert import load_big_vision_weights
load_big_vision_weights(model, checkpoint_path)
return {}
state_dict = load_state_dict(checkpoint_path)
# Detect & convert 3rd party state_dicts -> open_clip
state_dict = convert_state_dict(model, state_dict)
# Detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
# If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
if 'logit_bias' not in state_dict and model.logit_bias is not None:
state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
# Certain text transformers no longer expect position_ids after transformers==4.31
position_id_key = 'text.transformer.embeddings.position_ids'
if position_id_key in state_dict and not hasattr(model, position_id_key):
del state_dict[position_id_key]
resize_pos_embed(state_dict, model)
resize_text_pos_embed(state_dict, model)
# Finally, load the massaged state_dict into model
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
force_preprocess_cfg: Optional[Dict[str, Any]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
**model_kwargs,
):
force_preprocess_cfg = force_preprocess_cfg or {}
preprocess_cfg = asdict(PreprocessCfg())
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
model_id = model_name[len(HF_HUB_PREFIX):]
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
config = _get_hf_config(model_id, cache_dir)
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
model_cfg = config['model_cfg']
pretrained_hf = False # override, no need to load original HF text weights
else:
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
checkpoint_path = None
model_cfg = None
if isinstance(device, str):
device = torch.device(device)
if pretrained and pretrained.lower() == 'openai':
logging.info(f'Loading pretrained {model_name} from OpenAI.')
model = load_openai_model(
model_name,
precision=precision,
device=device,
cache_dir=cache_dir,
)
else:
model_cfg = model_cfg or get_model_config(model_name)
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
if force_image_size is not None:
# override model config's image size
model_cfg["vision_cfg"]["image_size"] = force_image_size
is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
if pretrained_image:
if is_timm_model:
# pretrained weight loading for timm models set via vision_cfg
model_cfg['vision_cfg']['timm_model_pretrained'] = True
else:
assert False, 'pretrained image towers currently only supported for timm models'
# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
cast_dtype = get_cast_dtype(precision)
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
if is_hf_model:
# load pretrained weights for HF text model IFF no CLIP weights being loaded
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)
if custom_text:
if "multimodal_cfg" in model_cfg:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
if precision in ("fp16", "bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
# manual mixed precision that matches original OpenAI behaviour
if is_timm_model:
# FIXME this is a bit janky, create timm based model in low-precision and
# then cast only LayerNormFp32 instances back to float32 so they don't break.
# Why? The convert_weights_to_lp fn only works with native models.
model.to(device=device, dtype=dtype)
from .transformer import LayerNormFp32
def _convert_ln(m):
if isinstance(m, LayerNormFp32):
m.weight.data = m.weight.data.to(torch.float32)
m.bias.data = m.bias.data.to(torch.float32)
model.apply(_convert_ln)
else:
model.to(device=device)
convert_weights_to_lp(model, dtype=dtype)
elif precision in ("pure_fp16", "pure_bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
model.to(device=device, dtype=dtype)
else:
model.to(device=device)
pretrained_loaded = False
if pretrained:
checkpoint_path = ''
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
else:
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
logging.warning(error_str)
raise RuntimeError(error_str)
pretrained_loaded = True
elif has_hf_hub_prefix:
logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
load_checkpoint(model, checkpoint_path)
pretrained_loaded = True
if require_pretrained and not pretrained_loaded:
# callers of create_model_from_pretrained always expect pretrained weights
raise RuntimeError(
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
if jit:
model = torch.jit.script(model)
# set image preprocessing configuration in model attributes for convenience
if getattr(model.visual, 'image_size', None) is not None:
# use image_size set on model creation (via config or force_image_size arg)
force_preprocess_cfg['size'] = model.visual.image_size
set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
return model
def create_loss(args):
if args.distill:
return DistillClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
elif "coca" in args.model.lower():
return CoCaLoss(
caption_loss_weight=args.coca_caption_loss_weight,
clip_loss_weight=args.coca_contrastive_loss_weight,
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
elif args.siglip:
assert not args.horovod, "Horovod not currently supported for SigLip"
return SigLipLoss(
rank=args.rank,
world_size=args.world_size,
)
return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
image_interpolation: Optional[str] = None,
image_resize_mode: Optional[str] = None, # only effective for inference
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
**model_kwargs,
):
force_preprocess_cfg = merge_preprocess_kwargs(
{}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_patch_dropout=force_patch_dropout,
force_image_size=force_image_size,
force_preprocess_cfg=force_preprocess_cfg,
pretrained_image=pretrained_image,
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
**model_kwargs,
)
pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
preprocess_train = image_transform_v2(
pp_cfg,
is_train=True,
aug_cfg=aug_cfg,
)
preprocess_val = image_transform_v2(
pp_cfg,
is_train=False,
)
return model, preprocess_train, preprocess_val
def create_model_from_pretrained(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
image_interpolation: Optional[str] = None,
image_resize_mode: Optional[str] = None, # only effective for inference
return_transform: bool = True,
cache_dir: Optional[str] = None,
**model_kwargs,
):
force_preprocess_cfg = merge_preprocess_kwargs(
{}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_image_size=force_image_size,
force_preprocess_cfg=force_preprocess_cfg,
cache_dir=cache_dir,
require_pretrained=True,
**model_kwargs,
)
if not return_transform:
return model
preprocess = image_transform_v2(
PreprocessCfg(**model.visual.preprocess_cfg),
is_train=False,
)
return model, preprocess
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/bert
"bert": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
},
"pooler": "cls_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/m2m_100
"m2m_100": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "encoder_attention_heads",
"layers": "encoder_layers",
},
"pooler": "cls_pooler",
},
}
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
transformers = None
class BaseModelOutput:
pass
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
# TODO: ?last - for gpt-like models
_POOLERS = {}
def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
return cls
@register_pooler
class MeanPooler(nn.Module):
"""Mean pooling"""
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
@register_pooler
class MaxPooler(nn.Module):
"""Max pooling"""
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
"""CLS token pooling"""
def __init__(self, use_pooler_output=True):
super().__init__()
self.cls_token_position = 0
self.use_pooler_output = use_pooler_output
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
if (self.use_pooler_output and
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
(x.pooler_output is not None)
):
return x.pooler_output
return x.last_hidden_state[:, self.cls_token_position, :]
@register_pooler
class ClsLastHiddenStatePooler(nn.Module):
"""CLS token pooling
NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
"""
def __init__(self):
super().__init__()
self.cls_token_position = 0
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
return x.last_hidden_state[:, self.cls_token_position, :]
class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""
output_tokens: torch.jit.Final[bool]
def __init__(
self,
model_name_or_path: str,
output_dim: int,
config: PretrainedConfig = None,
pooler_type: str = None,
proj_type: str = None,
pretrained: bool = True,
output_tokens: bool = False,
):
super().__init__()
self.output_tokens = output_tokens
self.output_dim = output_dim
# TODO: find better way to get this information
uses_transformer_pooler = (pooler_type == "cls_pooler")
if transformers is None:
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
if config is None:
self.config = AutoConfig.from_pretrained(model_name_or_path)
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
AutoModel.from_config, self.config)
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
else:
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
self.config = config
self.transformer = AutoModel.from_config(config)
if pooler_type is None: # get default arch pooler
pooler_type = (arch_dict[self.config.model_type]["pooler"])
# FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
self.vocab_size = getattr(self.config, 'vocab_size', 0)
self.context_length = getattr(self.config, 'max_position_embeddings', 0)
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
self.proj = nn.Identity()
elif proj_type == 'linear':
self.proj = nn.Linear(d_model, output_dim, bias=False)
elif proj_type == 'mlp':
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=False),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=False),
)
def forward(self, x: TensorType):
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
projected = self.proj(pooled_out)
seq_len = out.last_hidden_state.shape[1]
tokens = (
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
if type(self.pooler) == ClsPooler
else out.last_hidden_state
)
if self.output_tokens:
return projected, tokens
return projected
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
if not unlocked_layers: # full freezing
for n, p in self.transformer.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
return
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
embeddings = getattr(
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
modules = [embeddings, *layer_list][:-unlocked_layers]
# freeze layers
for module in modules:
for n, p in module.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.gradient_checkpointing_enable()
def init_parameters(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment