diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 86b4486..85fe0d4 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -14,7 +14,7 @@ from .convert import convert_state_dict
 from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
     resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
 from .coca_model import CoCa
-from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
+from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss, DRClipLoss
 from .openai import load_openai_model
 from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
     list_pretrained_tags_by_model, download_pretrained_from_hf
@@ -344,6 +344,10 @@ def create_loss(args):
             rank=args.rank,
             world_size=args.world_size,
             use_horovod=args.horovod,
+            dist_logit_scale=args.distill_logit_scale,
+            teacher_dimension=args.distill_teacher_dimension,
+            distill_loss_weights=args.distill_loss_weights,
+            average_after_softmax=args.distill_average_after_softmax,
         )
     elif "coca" in args.model.lower():
         return CoCaLoss(
@@ -362,6 +366,19 @@ def create_loss(args):
             rank=args.rank,
             world_size=args.world_size,
         )
+    elif args.dataset_reinforcement:
+        return DRClipLoss(
+            local_loss=args.local_loss,
+            gather_with_grad=args.gather_with_grad,
+            cache_labels=True,
+            rank=args.rank,
+            world_size=args.world_size,
+            use_horovod=args.horovod,
+            dist_logit_scale=args.distill_logit_scale,
+            teacher_dimension=args.distill_teacher_dimension,
+            distill_loss_weights=args.distill_loss_weights,
+            average_after_softmax=args.distill_average_after_softmax,
+        )
     return ClipLoss(
         local_loss=args.local_loss,
         gather_with_grad=args.gather_with_grad,
diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py
index 5beaab1..b40fdb9 100644
--- a/src/open_clip/loss.py
+++ b/src/open_clip/loss.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+import numpy as np
 
 try:
     import torch.distributed.nn
@@ -177,10 +178,59 @@ class CoCaLoss(ClipLoss):
         return clip_loss, caption_loss
 
 
+def dot_ensemble_features(feat_a, feat_b, logit_scale, dims):
+    """Average Softmax(a_t @ b_t) over the members t of an ensemble of teachers."""
+    num_members = len(dims)
+    dims = np.cumsum([0] + dims)
+    logits = [
+        logit_scale * (feat_a[:, dims[i]:dims[i+1]] @ feat_b[dims[i]:dims[i+1], :])
+        for i in range(num_members)
+    ]
+    probs = sum([F.softmax(logit, dim=1) for logit in logits]) / num_members
+    return probs
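+
+# Illustrative call (shapes and values hypothetical, not part of the patch):
+# with two 512-d teachers concatenated into 1024-d features, dims=[512, 512]
+# splits feat_a of shape (N, 1024) and feat_b of shape (1024, M) into
+# per-teacher blocks, and the result is the mean of two (N, M) softmax maps.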
+
+
 class DistillClipLoss(ClipLoss):
+    def __init__(
+            self,
+            *args,
+            teacher_dimension=[-1],
+            distill_loss_weights=[1.0, 1.0],
+            average_after_softmax=False,
+            dist_logit_scale=None,
+            **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.dist_logit_scale = dist_logit_scale
+        self.teacher_dimension = teacher_dimension
+        self.distill_loss_weights = distill_loss_weights
+        self.average_after_softmax = average_after_softmax
+
+    def get_logits_dist(self, image_features, text_features, logit_scale):
+        dims = self.teacher_dimension
+        if self.world_size > 1:
+            all_image_features, all_text_features = gather_features(
+                image_features, text_features,
+                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+
+            if self.local_loss:
+                logits_per_image = dot_ensemble_features(image_features, all_text_features.T, logit_scale, dims)
+                logits_per_text = dot_ensemble_features(text_features, all_image_features.T, logit_scale, dims)
+            else:
+                logits_per_image = dot_ensemble_features(all_image_features, all_text_features.T, logit_scale, dims)
+                logits_per_text = logits_per_image.T
+        else:
+            logits_per_image = dot_ensemble_features(image_features, text_features.T, logit_scale, dims)
+            logits_per_text = dot_ensemble_features(text_features, image_features.T, logit_scale, dims)
+
+        return logits_per_image, logits_per_text
+
     def dist_loss(self, teacher_logits, student_logits):
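+        # Soft cross-entropy H(teacher, student): minimizing it is equivalent
+        # to minimizing KL(teacher || student), since the teacher's entropy is
+        # constant w.r.t. the student. With average_after_softmax, the teacher
+        # tensor already holds averaged probabilities, so no softmax is applied.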
-        return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
+        if self.average_after_softmax:
+            return -(teacher_logits * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
+        else:
+            return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
 
     def forward(
             self,
@@ -189,30 +239,82 @@ class DistillClipLoss(ClipLoss):
             logit_scale,
             dist_image_features,
             dist_text_features,
-            dist_logit_scale,
+            dist_logit_scale=None,
             output_dict=False,
     ):
         logits_per_image, logits_per_text = \
             self.get_logits(image_features, text_features, logit_scale)
 
-        dist_logits_per_image, dist_logits_per_text = \
-            self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
+        if self.dist_logit_scale is not None:
+            dist_logit_scale = self.dist_logit_scale
+
+        if self.average_after_softmax:
+            dist_logits_per_image, dist_logits_per_text = \
+                self.get_logits_dist(dist_image_features, dist_text_features, dist_logit_scale)
+        else:
+            dist_logits_per_image, dist_logits_per_text = \
+                self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
 
         labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
 
         contrastive_loss = (
             F.cross_entropy(logits_per_image, labels) +
             F.cross_entropy(logits_per_text, labels)
-        ) / 2
+        ) / 2 * self.distill_loss_weights[0]
 
         distill_loss = (
             self.dist_loss(dist_logits_per_image, logits_per_image) +
             self.dist_loss(dist_logits_per_text, logits_per_text)
-        ) / 2
+        ) / 2 * self.distill_loss_weights[1]
+
+        if output_dict:
+            return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
+        return contrastive_loss, distill_loss
+
+
+class DRClipLoss(DistillClipLoss):
+    def forward(
+            self,
+            image_features,
+            text_features,
+            logit_scale,
+            dist_image_features,
+            dist_text_features,
+            syn_text_features=None,
+            dist_syn_text_features=None,
+            output_dict=False,
+    ):
+        loss_gt = super().forward(
+            image_features,
+            text_features,
+            logit_scale,
+            dist_image_features,
+            dist_text_features,
+            output_dict=output_dict,
+        )
+        if syn_text_features is None:
+            return loss_gt
+
+        loss_syn = super().forward(
+            image_features,
+            syn_text_features,
+            logit_scale,
+            dist_image_features,
+            dist_syn_text_features,
+            output_dict=output_dict,
+        )
 
         if output_dict:
+            contrastive_loss = (
+                loss_gt["contrastive_loss"] + loss_syn["contrastive_loss"]
+            )
+            distill_loss = (
+                loss_gt["distill_loss"] + loss_syn["distill_loss"]
+            )
             return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
 
+        contrastive_loss = loss_gt[0] + loss_syn[0]
+        distill_loss = loss_gt[1] + loss_syn[1]
         return contrastive_loss, distill_loss
diff --git a/src/training/data.py b/src/training/data.py
index 07b9fee..41f868c 100644
--- a/src/training/data.py
+++ b/src/training/data.py
@@ -19,6 +19,7 @@ from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableD
 from torch.utils.data.distributed import DistributedSampler
 from webdataset.filters import _shuffle
 from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
+from training.dr.transforms import compose_from_config
 
 try:
     import horovod.torch as hvd
@@ -183,6 +184,63 @@ def log_and_continue(exn):
     return True
 
 
+def preprocess_dr(sample, dr_transforms, tokenizer, mix_synthetic, mix_synthetic_ratio):
+    """Preprocess image, text, synthetic captions, and DR CLIP embeddings."""
+    # Preprocess image
+    # Sample one of the stored image augmentations
+    aug_idx = np.random.randint(0, len(sample["paug.json"]["param_aug"]))
+    params = sample["paug.json"]["param_aug"][aug_idx]
+    params = dr_transforms.decompress(params)
+    image = sample["image"].convert('RGB')
+    image, _ = dr_transforms.reapply(image, params)
+
+    # Preprocess text: sample one ground-truth caption
+    texts = sample["text"]
+    texts = [texts] if not isinstance(texts, list) else texts
+    capi = np.random.randint(0, len(texts))
+    text = texts[capi]
+    text = tokenizer(text)[0]
+
+    # Preprocess synthetic text: sample one synthetic caption
+    scapi = np.random.randint(0, len(sample["syn.json"]["syn_text"]))
+    syn_text = sample["syn.json"]["syn_text"][scapi]
+    syn_text = tokenizer(syn_text)[0]
+
+    # Preprocess embeddings
+    if "npz" in sample:
+        image_emb = sample["npz"]["image_emb"][aug_idx]
+        text_emb_all = sample["npz"]["text_emb"]
+    elif "pth.gz" in sample:
+        image_emb = sample["pth.gz"]["image_emb"][aug_idx]
+        text_emb_all = sample["pth.gz"]["text_emb"]
+    text_emb = text_emb_all[capi]
+    syn_text_emb = text_emb_all[len(texts) + scapi]
+    if not isinstance(image_emb, torch.Tensor):
+        image_emb = torch.tensor(image_emb)
+        text_emb = torch.tensor(text_emb)
+        syn_text_emb = torch.tensor(syn_text_emb)
+    image_emb = image_emb.type(torch.float32)
+    text_emb = text_emb.type(torch.float32)
+    syn_text_emb = syn_text_emb.type(torch.float32)
+
+    if mix_synthetic:
+        if np.random.rand() < mix_synthetic_ratio:
+            text = syn_text
+            text_emb = syn_text_emb
+        # No double loss on gt/syn captions
+        syn_text = []
+        syn_text_emb = []
+
+    return {
+        'image': image,
+        'text': text,
+        'image_emb': image_emb,
+        'text_emb': text_emb,
+        'syn_text': syn_text,
+        'syn_text_emb': syn_text_emb,
+    }
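+# Reinforced-sample layout assumed above (per the DR webdataset):
+#   sample["paug.json"]["param_aug"] -> compressed augmentation parameters
+#   sample["syn.json"]["syn_text"]   -> list of synthetic captions
+#   sample["npz"] / sample["pth.gz"] -> teacher embeddings; "text_emb" stacks
+#   the ground-truth caption embeddings first, then the synthetic ones.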
+
+
 def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
     """Return function over iterator that groups key, value pairs into samples.
@@ -342,13 +400,13 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
             'Please specify it via `--train-num-samples` if no dataset length info is present.')
     else:
         # Eval will just exhaust the iterator if the size is not specified.
-        num_samples = args.val_num_samples or 0 
+        num_samples = args.val_num_samples or 0
 
     shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
 
     if is_train and args.train_data_upsampling_factors is not None:
         assert resampled, "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
-    
+
     if resampled:
         pipeline = [ResampledShards2(
             input_shards,
@@ -386,14 +444,53 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
         # at this point, we have an iterator over the shards assigned to each worker
         wds.tarfile_to_samples(handler=log_and_continue),
     ])
-    pipeline.extend([
-        wds.select(filter_no_caption_or_no_image),
-        wds.decode("pilrgb", handler=log_and_continue),
-        wds.rename(image="jpg;png;jpeg;webp", text="txt"),
-        wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
-        wds.to_tuple("image", "text"),
-        wds.batched(args.batch_size, partial=not is_train)
-    ])
+
+    if not args.dataset_reinforcement:
+        pipeline.extend([
+            wds.select(filter_no_caption_or_no_image),
+            wds.decode("pilrgb", handler=log_and_continue),
+            wds.rename(image="jpg;png;jpeg;webp", text="txt"),
+            wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
+            wds.to_tuple("image", "text"),
+            wds.batched(args.batch_size, partial=not is_train)
+        ])
+    else:
+        # Setup DR transformations
+        with open(args.dataset_reinforcement_config, 'r') as f:
+            rconfig = json.load(f)
+        rconfig_aug = rconfig["reinforce"]["image_augmentation"]
+        # Replace augmentation parameters where DR is flexible.
+        # DR also supports additional augmentations via --aug-cfg, which is
+        # not implemented here.
+        from torchvision.transforms import RandomResizedCrop, Normalize
+        del rconfig_aug["normalize"]
+        for t in preprocess_img.transforms:
+            if isinstance(t, Normalize):
+                rconfig_aug["normalize"] = {"mean": t.mean, "std": t.std}
+            if isinstance(t, RandomResizedCrop):
+                # One can also pass image_size to dr_transforms.reapply to
+                # support variable-resolution training.
+                rconfig_aug["random_resized_crop"].update({
+                    "size": t.size,
+                    "interpolation": t.interpolation,
+                })
+        dr_transforms = compose_from_config(rconfig_aug)
+
+        pipeline.extend([
+            wds.select(filter_no_caption_or_no_image),
+            wds.decode("pilrgb", handler=log_and_continue),
+            wds.rename(image="jpg;png;jpeg;webp", text="txt"),
+            wds.map(
+                lambda sample: preprocess_dr(
+                    sample, dr_transforms,
+                    tokenizer,
+                    args.dataset_reinforcement_mix_synthetic,
+                    args.dataset_reinforcement_mix_synthetic_ratio,
+                )
+            ),
+            wds.to_tuple("image", "text", "image_emb", "text_emb", "syn_text", "syn_text_emb"),
+            wds.batched(args.batch_size, partial=not is_train)
+        ])
 
     dataset = wds.DataPipeline(*pipeline)
@@ -541,7 +638,7 @@ def get_dataset_fn(data_path, dataset_type):
             f"Tried to figure out dataset type, but failed for extension {ext}.")
     else:
         raise ValueError(f"Unsupported dataset type: {dataset_type}")
-    
+
 
 def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
     preprocess_train, preprocess_val = preprocess_fns
diff --git a/src/training/params.py b/src/training/params.py
index c3d1930..d7d7023 100644
--- a/src/training/params.py
+++ b/src/training/params.py
@@ -440,6 +440,26 @@ def parse_args(args):
         default=None,
         help='Which pre-trained weights to distill from, if any.'
     )
+    parser.add_argument(
+        "--distill-loss-weights",
+        type=float,
+        default=[1.0, 1.0],
+        nargs="+",
+        help='Pair of [contrastive, distillation] loss weights used when distillation is enabled.'
+    )
+    parser.add_argument(
+        "--distill-teacher-dimension",
+        type=int,
+        default=[-1],
+        nargs="+",
+        help='Feature dimension of each teacher in the ensemble. Default: [-1].'
+    )
+    parser.add_argument(
+        "--distill-average-after-softmax",
+        default=False,
+        action="store_true",
+        help='For ensemble teachers in distillation, average the logits after Softmax.'
+    )
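+    # Illustrative flag combination (values hypothetical): distilling from two
+    # concatenated 768-d teachers, averaging their logits after Softmax:
+    #   --distill-teacher-dimension 768 768 --distill-average-after-softmax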
     parser.add_argument(
         "--use-bnb-linear",
         default=None,
@@ -452,6 +472,37 @@ def parse_args(args):
         action="store_true",
         help='Use SigLip (sigmoid) loss.'
     )
+    parser.add_argument(
+        "--dataset-reinforcement",
+        default=False,
+        action="store_true",
+        help="If true, load image/text embeddings and synthetic captions from the webdataset."
+    )
+    parser.add_argument(
+        "--dataset-reinforcement-config",
+        type=str,
+        default=None,
+        help="Path to the config file for dataset reinforcement."
+    )
+    parser.add_argument(
+        "--distill-logit-scale",
+        type=float,
+        default=None,
+        help="Logit scale for the distillation loss. None means fetch it from model_out."
+    )
+    parser.add_argument(
+        "--dataset-reinforcement-mix-synthetic",
+        default=False,
+        action="store_true",
+        help="Mix synthetic captions with ground-truth captions and randomly sample one per example."
+    )
+    parser.add_argument(
+        "--dataset-reinforcement-mix-synthetic-ratio",
+        type=float,
+        default=0.0,
+        help="Mixing ratio for synthetic vs. ground-truth captions: "
+             "0.0 means all ground-truth and 1.0 means all synthetic."
+    )
 
     args = parser.parse_args(args)
diff --git a/src/training/train.py b/src/training/train.py
index a48a345..3a52abc 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -89,9 +89,12 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
         if not args.skip_scheduler:
             scheduler(step)
 
-        images, texts = batch
+        images, texts = batch[:2]
         images = images.to(device=device, dtype=input_dtype, non_blocking=True)
         texts = texts.to(device=device, non_blocking=True)
+        if args.dataset_reinforcement and not args.dataset_reinforcement_mix_synthetic:
+            syn_texts = batch[4].to(device=device, non_blocking=True)
+            texts = torch.cat([texts, syn_texts[:, :texts.shape[-1]]], dim=0)
 
         data_time_m.update(time.time() - end)
         optimizer.zero_grad()
@@ -104,6 +107,18 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
                     with torch.no_grad():
                         dist_model_out = dist_model(images, texts)
                     model_out.update({f'dist_{k}': v for k, v in dist_model_out.items()})
+                if args.dataset_reinforcement:
+                    batch_size = images.shape[0]
+                    model_out.update({
+                        'dist_image_features': batch[2].to(device=device, non_blocking=True),
+                        'dist_text_features': batch[3].to(device=device, non_blocking=True),
+                    })
+                    if not args.dataset_reinforcement_mix_synthetic:
+                        model_out.update({
+                            'text_features': model_out["text_features"][:batch_size],
+                            'syn_text_features': model_out["text_features"][batch_size:],
+                            'dist_syn_text_features': batch[5].to(device=device, non_blocking=True)
+                        })
                 losses = loss(**model_out, output_dict=True)
 
                 total_loss = sum(losses.values())
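+                # With --dataset-reinforcement, `loss` is DRClipLoss: `losses`
+                # holds the contrastive and distillation terms summed over
+                # ground-truth and (if present) synthetic captions, and
+                # total_loss adds them all together.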