Commit 6f43e8fa authored by mashun1

open_clip
import logging
def setup_logging(log_file, level, include_host=False):
if include_host:
import socket
hostname = socket.gethostname()
formatter = logging.Formatter(
f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
else:
formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
logging.root.setLevel(level)
loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
for logger in loggers:
logger.setLevel(level)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logging.root.addHandler(stream_handler)
if log_file:
file_handler = logging.FileHandler(filename=log_file)
file_handler.setFormatter(formatter)
logging.root.addHandler(file_handler)
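# Illustrative usage (sketch), assuming a writable './out.log':
#   setup_logging('./out.log', logging.INFO)
#   logging.info('hello')  # -> "2024-01-01,12:00:00 | INFO | hello"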
import glob
import logging
import os
import re
import subprocess
import sys
import random
from datetime import datetime
from functools import partial
import numpy as np
import torch
from torch import optim
from torch.cuda.amp import GradScaler
try:
import wandb
except ImportError:
wandb = None
try:
import torch.utils.tensorboard as tensorboard
except ImportError:
tensorboard = None
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss
from open_clip_train.data import get_data
from open_clip_train.distributed import is_master, init_distributed_device, broadcast_object
from open_clip_train.logger import setup_logging
from open_clip_train.params import parse_args
from open_clip_train.scheduler import cosine_lr, const_lr, const_lr_cooldown
from open_clip_train.train import train_one_epoch, evaluate
from open_clip_train.file_utils import pt_load, check_exists, start_sync_process, remote_sync
LATEST_CHECKPOINT_NAME = "epoch_latest.pt"
def random_seed(seed=42, rank=0):
torch.manual_seed(seed + rank)
np.random.seed(seed + rank)
random.seed(seed + rank)
def natural_key(string_):
"""See http://www.codinghorror.com/blog/archives/001018.html"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
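# Example: natural_key orders numeric runs by value, so
# sorted(['epoch_2.pt', 'epoch_10.pt'], key=natural_key) keeps epoch_2 before epoch_10,
# whereas plain lexicographic sorting would place epoch_10 first.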
def get_latest_checkpoint(path: str, remote : bool):
# as written, this glob recurses, so it can pick up checkpoints across multiple sub-folders
if remote:
result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
print(result)
if result.returncode == 1:
return None
checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]]
else:
checkpoints = glob.glob(path + '**/*.pt', recursive=True)
if checkpoints:
checkpoints = sorted(checkpoints, key=natural_key)
return checkpoints[-1]
return None
def main(args):
args = parse_args(args)
if torch.cuda.is_available():
# This enables tf32 on Ampere GPUs which is only 8% slower than
# float16 and almost as accurate as float32
# This was a default in pytorch until 1.12
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
# fully initialize distributed device environment
device = init_distributed_device(args)
# get the name of the experiments
if args.name is None:
# sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
model_name_safe = args.model.replace('/', '-')
date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
if args.distributed:
# sync date_str from master to all ranks
date_str = broadcast_object(args, date_str)
args.name = '-'.join([
date_str,
f"model_{model_name_safe}",
f"lr_{args.lr}",
f"b_{args.batch_size}",
f"j_{args.workers}",
f"p_{args.precision}",
])
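# Illustrative auto-generated name (hypothetical values):
# "2024_01_01-12_00_00-model_ViT-B-32-lr_0.0005-b_64-j_4-p_amp"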
resume_latest = args.resume == 'latest'
log_base_path = os.path.join(args.logs, args.name)
args.log_path = None
if is_master(args, local=args.log_local):
os.makedirs(log_base_path, exist_ok=True)
log_filename = f'out-{args.rank}' if args.log_local else 'out.log'
args.log_path = os.path.join(log_base_path, log_filename)
if os.path.exists(args.log_path) and not resume_latest:
print(
f"Error. Experiment {args.name} already exists. Use --name to specify a new experiment."
)
return -1
# Setup text logger
args.log_level = logging.DEBUG if args.debug else logging.INFO
setup_logging(args.log_path, args.log_level)
# Setup wandb, tensorboard, checkpoint logging
args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
if is_master(args):
args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ''
for dirname in [args.tensorboard_path, args.checkpoint_path]:
if dirname:
os.makedirs(dirname, exist_ok=True)
else:
args.tensorboard_path = ''
if resume_latest:
resume_from = None
checkpoint_path = args.checkpoint_path
# If using remote_sync, need to check the remote instead of the local checkpoints folder.
if args.remote_sync is not None:
checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints")
if args.save_most_recent:
print('Error. Cannot use save-most-recent with remote_sync and resume latest.')
return -1
if args.remote_sync_protocol != 's3':
print('Error. Sync protocol not supported when using resume latest.')
return -1
if is_master(args):
# Checking for existing checkpoint via master rank only. It is possible for
# different rank processes to see different files if a shared file-system is under
# stress, however it's very difficult to fully work around such situations.
if args.save_most_recent:
# if --save-most-recent flag is set, look for latest at a fixed filename
resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME)
if not os.path.exists(resume_from):
# If no latest checkpoint has been saved yet, don't try to resume
resume_from = None
else:
# otherwise, list checkpoint dir contents and pick the newest checkpoint
resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None)
if resume_from:
logging.info(f'Found latest resume checkpoint at {resume_from}.')
else:
logging.info(f'No latest resume checkpoint found in {checkpoint_path}.')
if args.distributed:
# sync found checkpoint path to all ranks
resume_from = broadcast_object(args, resume_from)
args.resume = resume_from
if args.copy_codebase:
copy_codebase(args)
# start the sync process if remote-sync is not None
remote_sync_process = None
if is_master(args) and args.remote_sync is not None:
# first make sure it works
result = remote_sync(
os.path.join(args.logs, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol
)
if result:
logging.info('remote sync successful.')
else:
logging.info('Error: remote sync failed. Exiting.')
return -1
# if all looks good, start a process to do this every args.remote_sync_frequency seconds
remote_sync_process = start_sync_process(
args.remote_sync_frequency,
os.path.join(args.logs, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol
)
remote_sync_process.start()
if args.precision == 'fp16':
logging.warning(
'It is recommended to use AMP mixed-precision instead of FP16. '
'FP16 support needs further verification and tuning, especially for train.')
if args.horovod:
logging.info(
f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.'
f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
elif args.distributed:
logging.info(
f'Running in distributed mode with multiple processes. Device: {args.device}.'
f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
else:
logging.info(f'Running with a single process. Device {args.device}.')
dist_model = None
args.distill = args.distill_model is not None and args.distill_pretrained is not None
if args.distill:
#FIXME: support distillation with grad accum.
assert args.accum_freq == 1
#FIXME: support distillation with coca.
assert 'coca' not in args.model.lower()
if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1:
# arg is nargs, single (square) image size list -> int
args.force_image_size = args.force_image_size[0]
random_seed(args.seed, 0)
model_kwargs = {}
if args.siglip:
model_kwargs['init_logit_scale'] = np.log(10) # different from CLIP
model_kwargs['init_logit_bias'] = -10
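# SigLIP uses a sigmoid pairwise loss rather than CLIP's softmax contrastive loss,
# hence the larger initial logit scale (log 10) and a learnable logit bias initialized to -10.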
model, preprocess_train, preprocess_val = create_model_and_transforms(
args.model,
args.pretrained,
precision=args.precision,
device=device,
jit=args.torchscript,
force_quick_gelu=args.force_quick_gelu,
force_custom_text=args.force_custom_text,
force_patch_dropout=args.force_patch_dropout,
force_image_size=args.force_image_size,
image_mean=args.image_mean,
image_std=args.image_std,
image_interpolation=args.image_interpolation,
image_resize_mode=args.image_resize_mode, # only effective for inference
aug_cfg=args.aug_cfg,
pretrained_image=args.pretrained_image,
output_dict=True,
**model_kwargs,
)
if args.distill:
# FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms.
dist_model, _, _ = create_model_and_transforms(
args.distill_model,
args.distill_pretrained,
device=device,
precision=args.precision,
output_dict=True,
)
if args.use_bnb_linear is not None:
print('=> using a layer from bitsandbytes.\n'
' this is an experimental feature which requires two extra pip installs\n'
' pip install bitsandbytes triton\n'
' please make sure to use triton 2.0.0')
import bitsandbytes as bnb
from open_clip.utils import replace_linear
print(f'=> replacing linear layers with {args.use_bnb_linear}')
linear_replacement_cls = getattr(bnb.nn.triton_based_modules, args.use_bnb_linear)
replace_linear(model, linear_replacement_cls)
model = model.to(device)
random_seed(args.seed, args.rank)
if args.trace:
model = trace_model(model, batch_size=args.batch_size, device=device)
if args.lock_image:
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
model.lock_image_tower(
unlocked_groups=args.lock_image_unlocked_groups,
freeze_bn_stats=args.lock_image_freeze_bn_stats)
if args.lock_text:
model.lock_text_tower(
unlocked_layers=args.lock_text_unlocked_layers,
freeze_layer_norm=args.lock_text_freeze_layer_norm)
if args.grad_checkpointing:
model.set_grad_checkpointing()
if is_master(args):
logging.info("Model:")
logging.info(f"{str(model)}")
logging.info("Params:")
params_file = os.path.join(args.logs, args.name, "params.txt")
with open(params_file, "w") as f:
for name in sorted(vars(args)):
val = getattr(args, name)
logging.info(f" {name}: {val}")
f.write(f"{name}: {val}\n")
if args.distributed and not args.horovod:
if args.use_bn_sync:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
ddp_args = {}
if args.ddp_static_graph:
# this doesn't exist in older PyTorch, arg only added if enabled
ddp_args['static_graph'] = True
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)
if args.distill:
dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args)
# create optimizer and scaler
optimizer = None
scaler = None
if args.train_data or args.dataset_type == "synthetic":
assert not args.trace, 'Cannot train with traced model'
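# Parameters are split into two groups: norm gains, biases, and logit_scale receive no
# weight decay, while all other (>= 2-dim) weights are decayed with args.wd.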
exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
include = lambda n, p: not exclude(n, p)
named_parameters = list(model.named_parameters())
gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
optimizer = optim.AdamW(
[
{"params": gain_or_bias_params, "weight_decay": 0.},
{"params": rest_params, "weight_decay": args.wd},
],
lr=args.lr,
betas=(args.beta1, args.beta2),
eps=args.eps,
)
if args.horovod:
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
scaler = GradScaler() if args.precision == "amp" else None
# optionally resume from a checkpoint
start_epoch = 0
if args.resume is not None:
checkpoint = pt_load(args.resume, map_location='cpu')
if 'epoch' in checkpoint:
# resuming a train checkpoint w/ epoch and optimizer state
start_epoch = checkpoint["epoch"]
sd = checkpoint["state_dict"]
if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items()}
model.load_state_dict(sd)
if optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
if scaler is not None and 'scaler' in checkpoint:
scaler.load_state_dict(checkpoint['scaler'])
logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
else:
# loading a bare (model only) checkpoint for fine-tune or evaluation
model.load_state_dict(checkpoint)
logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
# initialize datasets
tokenizer = get_tokenizer(args.model)
data = get_data(
args,
(preprocess_train, preprocess_val),
epoch=start_epoch,
tokenizer=tokenizer,
)
assert len(data), 'At least one train or eval dataset must be specified.'
# create scheduler if train
scheduler = None
if 'train' in data and optimizer is not None:
total_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs
if args.lr_scheduler == "cosine":
scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
elif args.lr_scheduler == "const":
scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps)
elif args.lr_scheduler == "const-cooldown":
assert args.epochs_cooldown is not None,\
"Please specify the number of cooldown epochs for this lr schedule."
cooldown_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs_cooldown
scheduler = const_lr_cooldown(
optimizer, args.lr, args.warmup, total_steps,
cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end)
else:
logging.error(
f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.')
exit(1)
# determine if this worker should save logs and checkpoints. only do so if it is rank == 0
args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args)
writer = None
if args.save_logs and args.tensorboard:
assert tensorboard is not None, "Please install tensorboard."
writer = tensorboard.SummaryWriter(args.tensorboard_path)
if args.wandb and is_master(args):
assert wandb is not None, 'Please install wandb.'
logging.debug('Starting wandb.')
args.train_sz = data["train"].dataloader.num_samples
if args.val_data is not None:
args.val_sz = data["val"].dataloader.num_samples
# you will have to configure this for your project!
wandb.init(
project=args.wandb_project_name,
name=args.name,
id=args.name,
notes=args.wandb_notes,
tags=[],
resume='auto' if args.resume == "latest" else None,
config=vars(args),
)
if args.debug:
wandb.watch(model, log='all')
wandb.save(params_file)
logging.debug('Finished loading wandb.')
# Pytorch 2.0 adds '_orig_mod.' prefix to keys of state_dict() of compiled models.
# For compatibility, we save state_dict() of the original model, which shares the
# weights without the prefix.
original_model = model
if args.torchcompile:
logging.info('Compiling model...')
model = torch.compile(original_model)
if 'train' not in data:
# If using int8, convert to inference mode.
if args.use_bnb_linear is not None:
from open_clip.utils import convert_int8_model_to_inference_mode
convert_int8_model_to_inference_mode(model)
# Evaluate.
evaluate(model, data, start_epoch, args, tb_writer=writer, tokenizer=tokenizer)
return
loss = create_loss(args)
for epoch in range(start_epoch, args.epochs):
if is_master(args):
logging.info(f'Start epoch {epoch}')
train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer)
completed_epoch = epoch + 1
if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
evaluate(model, data, completed_epoch, args, tb_writer=writer, tokenizer=tokenizer)
# Saving checkpoints.
if args.save_logs:
checkpoint_dict = {
"epoch": completed_epoch,
"name": args.name,
"state_dict": original_model.state_dict(),
"optimizer": optimizer.state_dict(),
}
if scaler is not None:
checkpoint_dict["scaler"] = scaler.state_dict()
if completed_epoch == args.epochs or (
args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
):
torch.save(
checkpoint_dict,
os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
)
if args.delete_previous_checkpoint:
previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt")
if os.path.exists(previous_checkpoint):
os.remove(previous_checkpoint)
if args.save_most_recent:
# try not to corrupt the latest checkpoint if save fails
tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt")
latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME)
torch.save(checkpoint_dict, tmp_save_path)
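# os.replace below is an atomic rename on the same filesystem, so a crash mid-save
# leaves the previous epoch_latest.pt intact.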
os.replace(tmp_save_path, latest_save_path)
if args.wandb and is_master(args):
wandb.finish()
# run a final sync.
if remote_sync_process is not None:
logging.info('Final remote sync.')
remote_sync_process.terminate()
result = remote_sync(
os.path.join(args.logs, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol
)
if result:
logging.info('Final remote sync successful.')
else:
logging.info('Final remote sync failed.')
def copy_codebase(args):
from shutil import copytree, ignore_patterns
new_code_path = os.path.join(args.logs, args.name, "code")
if os.path.exists(new_code_path):
print(
f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
)
return -1
print(f"Copying codebase to {new_code_path}")
current_code_path = os.path.realpath(__file__)
for _ in range(3):
current_code_path = os.path.dirname(current_code_path)
copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
print("Done copying code.")
return 1
if __name__ == "__main__":
main(sys.argv[1:])
import argparse
import ast
def get_default_params(model_name):
# Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
model_name = model_name.lower()
if "vit" in model_name:
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
else:
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
class ParseKwargs(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
kw = {}
for value in values:
key, value = value.split('=')
try:
kw[key] = ast.literal_eval(value)
except ValueError:
kw[key] = str(value) # fallback to string (avoid need to escape on command line)
setattr(namespace, self.dest, kw)
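# Illustrative (hypothetical keys): `--aug-cfg scale='(0.8, 1.0)' re_prob=0.1` parses to
# {'scale': (0.8, 1.0), 're_prob': 0.1}; values that raise ValueError in ast.literal_eval
# fall back to plain strings.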
def parse_args(args):
parser = argparse.ArgumentParser()
parser.add_argument(
"--train-data",
type=str,
default=None,
help="Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.",
)
parser.add_argument(
"--train-data-upsampling-factors",
type=str,
default=None,
help=(
"When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. "
"Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) "
"By default, datapoints are sampled uniformly regardless of the dataset sizes."
)
)
parser.add_argument(
"--val-data",
type=str,
default=None,
help="Path to file(s) with validation data",
)
parser.add_argument(
"--train-num-samples",
type=int,
default=None,
help="Number of samples in dataset. Required for webdataset if not available in info file.",
)
parser.add_argument(
"--val-num-samples",
type=int,
default=None,
help="Number of samples in dataset. Useful for webdataset if not available in info file.",
)
parser.add_argument(
"--dataset-type",
choices=["webdataset", "csv", "synthetic", "auto"],
default="auto",
help="Which type of dataset to process."
)
parser.add_argument(
"--dataset-resampled",
default=False,
action="store_true",
help="Whether to use sampling with replacement for webdataset shard selection."
)
parser.add_argument(
"--csv-separator",
type=str,
default="\t",
help="For csv-like datasets, which separator to use."
)
parser.add_argument(
"--csv-img-key",
type=str,
default="filepath",
help="For csv-like datasets, the name of the key for the image paths."
)
parser.add_argument(
"--csv-caption-key",
type=str,
default="title",
help="For csv-like datasets, the name of the key for the captions."
)
parser.add_argument(
"--imagenet-val",
type=str,
default=None,
help="Path to imagenet val set for conducting zero shot evaluation.",
)
parser.add_argument(
"--imagenet-v2",
type=str,
default=None,
help="Path to imagenet v2 for conducting zero shot evaluation.",
)
parser.add_argument(
"--logs",
type=str,
default="./logs/",
help="Where to store tensorboard logs. Use None to avoid storing logs.",
)
parser.add_argument(
"--log-local",
action="store_true",
default=False,
help="log files on local master, otherwise global master only.",
)
parser.add_argument(
"--name",
type=str,
default=None,
help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
)
parser.add_argument(
"--workers", type=int, default=4, help="Number of dataloader workers per GPU."
)
parser.add_argument(
"--batch-size", type=int, default=64, help="Batch size per GPU."
)
parser.add_argument(
"--epochs", type=int, default=32, help="Number of epochs to train for."
)
parser.add_argument(
"--epochs-cooldown", type=int, default=None,
help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards."
)
parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
parser.add_argument(
"--warmup", type=int, default=10000, help="Number of steps to warmup for."
)
parser.add_argument(
"--use-bn-sync",
default=False,
action="store_true",
help="Whether to use batch norm sync.")
parser.add_argument(
"--skip-scheduler",
action="store_true",
default=False,
help="Use this flag to skip the learning rate decay.",
)
parser.add_argument(
"--lr-scheduler",
type=str,
default='cosine',
help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine",
)
parser.add_argument(
"--lr-cooldown-end", type=float, default=0.0,
help="End learning rate for cooldown schedule. Default: 0"
)
parser.add_argument(
"--lr-cooldown-power", type=float, default=1.0,
help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)"
)
parser.add_argument(
"--save-frequency", type=int, default=1, help="How often to save checkpoints."
)
parser.add_argument(
"--save-most-recent",
action="store_true",
default=False,
help="Always save the most recent model trained to epoch_latest.pt.",
)
parser.add_argument(
"--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
)
parser.add_argument(
"--val-frequency", type=int, default=1, help="How often to run evaluation with val data."
)
parser.add_argument(
"--resume",
default=None,
type=str,
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"--precision",
choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "pure_bf16", "pure_fp16", "fp32"],
default="amp",
help="Floating point precision."
)
parser.add_argument(
"--model",
type=str,
default="RN50",
help="Name of the vision backbone to use.",
)
parser.add_argument(
"--pretrained",
default='',
type=str,
help="Use a pretrained CLIP model weights with the specified tag or file path.",
)
parser.add_argument(
"--pretrained-image",
default=False,
action='store_true',
help="Load imagenet pretrained weights for image tower backbone if available.",
)
parser.add_argument(
"--lock-image",
default=False,
action='store_true',
help="Lock full image tower by disabling gradients.",
)
parser.add_argument(
"--lock-image-unlocked-groups",
type=int,
default=0,
help="Leave last n image tower layer groups unlocked.",
)
parser.add_argument(
"--lock-image-freeze-bn-stats",
default=False,
action='store_true',
help="Freeze BatchNorm running stats in image tower for any locked layers.",
)
parser.add_argument(
'--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override default image mean value of dataset')
parser.add_argument(
'--image-std', type=float, nargs='+', default=None, metavar='STD',
help='Override default image std deviation of dataset')
parser.add_argument(
'--image-interpolation',
default=None, type=str, choices=['bicubic', 'bilinear', 'random'],
help="Override default image resize interpolation"
)
parser.add_argument(
'--image-resize-mode',
default=None, type=str, choices=['shortest', 'longest', 'squash'],
help="Override default image resize (& crop) mode during inference"
)
parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs)
parser.add_argument(
"--grad-checkpointing",
default=False,
action='store_true',
help="Enable gradient checkpointing.",
)
parser.add_argument(
"--local-loss",
default=False,
action="store_true",
help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)"
)
parser.add_argument(
"--gather-with-grad",
default=False,
action="store_true",
help="enable full distributed gradient for feature gather"
)
parser.add_argument(
'--force-image-size', type=int, nargs='+', default=None,
help='Override default image size'
)
parser.add_argument(
"--force-quick-gelu",
default=False,
action='store_true',
help="Force use of QuickGELU activation for non-OpenAI transformer models.",
)
parser.add_argument(
"--force-patch-dropout",
default=None,
type=float,
help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper",
)
parser.add_argument(
"--force-custom-text",
default=False,
action='store_true',
help="Force use of CustomTextCLIP model (separate text-tower).",
)
parser.add_argument(
"--torchscript",
default=False,
action='store_true',
help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
)
parser.add_argument(
"--torchcompile",
default=False,
action='store_true',
help="torch.compile() the model, requires pytorch 2.0 or later.",
)
parser.add_argument(
"--trace",
default=False,
action='store_true',
help="torch.jit.trace the model for inference / eval only",
)
parser.add_argument(
"--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps."
)
# arguments for distributed training
parser.add_argument(
"--dist-url",
default="env://",
type=str,
help="url used to set up distributed training",
)
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--report-to",
default='',
type=str,
help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
)
parser.add_argument(
"--wandb-notes",
default='',
type=str,
help="Notes if logging with wandb"
)
parser.add_argument(
"--wandb-project-name",
type=str,
default='open-clip',
help="Name of the project if logging with wandb.",
)
parser.add_argument(
"--debug",
default=False,
action="store_true",
help="If true, more information is logged."
)
parser.add_argument(
"--copy-codebase",
default=False,
action="store_true",
help="If true, we copy the entire base on the log directory, and execute from there."
)
parser.add_argument(
"--horovod",
default=False,
action="store_true",
help="Use horovod for distributed training."
)
parser.add_argument(
"--ddp-static-graph",
default=False,
action='store_true',
help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
)
parser.add_argument(
"--no-set-device-rank",
default=False,
action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)."
)
parser.add_argument(
"--seed", type=int, default=0, help="Default random seed."
)
parser.add_argument(
"--grad-clip-norm", type=float, default=None, help="Gradient clip."
)
parser.add_argument(
"--lock-text",
default=False,
action='store_true',
help="Lock full text tower by disabling gradients.",
)
parser.add_argument(
"--lock-text-unlocked-layers",
type=int,
default=0,
help="Leave last n text tower layer groups unlocked.",
)
parser.add_argument(
"--lock-text-freeze-layer-norm",
default=False,
action='store_true',
help="Freeze LayerNorm running stats in text tower for any locked layers.",
)
parser.add_argument(
"--log-every-n-steps",
type=int,
default=100,
help="Log every n steps to tensorboard/console/wandb.",
)
parser.add_argument(
"--coca-caption-loss-weight",
type=float,
default=2.0,
help="Weight assigned to caption loss in CoCa."
)
parser.add_argument(
"--coca-contrastive-loss-weight",
type=float,
default=1.0,
help="Weight assigned to contrastive loss when training CoCa."
)
parser.add_argument(
"--remote-sync",
type=str,
default=None,
help="Optinoally sync with a remote path specified by this arg",
)
parser.add_argument(
"--remote-sync-frequency",
type=int,
default=300,
help="How frequently to sync to a remote directly if --remote-sync is not None.",
)
parser.add_argument(
"--remote-sync-protocol",
choices=["s3", "fsspec"],
default="s3",
help="How to do the remote sync backup if --remote-sync is not None.",
)
parser.add_argument(
"--delete-previous-checkpoint",
default=False,
action="store_true",
help="If true, delete previous checkpoint after storing a new one."
)
parser.add_argument(
"--distill-model",
default=None,
help='Which model arch to distill from, if any.'
)
parser.add_argument(
"--distill-pretrained",
default=None,
help='Which pre-trained weights to distill from, if any.'
)
parser.add_argument(
"--use-bnb-linear",
default=None,
help='Replace the network linear layers from the bitsandbytes library. '
'Allows int8 training/inference, etc.'
)
parser.add_argument(
"--siglip",
default=False,
action="store_true",
help='Use SigLip (sigmoid) loss.'
)
args = parser.parse_args(args)
# If some params are not passed, we use the default values based on model name.
default_params = get_default_params(args.model)
for name, val in default_params.items():
if getattr(args, name) is None:
setattr(args, name, val)
return args
import torch
from contextlib import suppress
def get_autocast(precision):
if precision == 'amp':
return torch.cuda.amp.autocast
elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
# amp_bfloat16 is more stable than amp float16 for clip training
return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
else:
return suppress
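# Illustrative usage (sketch), assuming `model`, `images`, `texts` are already on the right device:
#   autocast = get_autocast('amp')
#   with autocast():
#       out = model(images, texts)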
import argparse
import torch
import open_clip
import pandas as pd
from torch.utils.flop_counter import FlopCounterMode
try:
import fvcore
except ImportError:
fvcore = None
parser = argparse.ArgumentParser(description='OpenCLIP Profiler')
# benchmark specific args
parser.add_argument('--model', metavar='NAME', default='',
help='model(s) to profile')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for results')
parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore'])
parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling')
def profile_fvcore(
model,
image_input_size=(3, 224, 224),
text_input_size=(77,),
batch_size=1,
detailed=False,
force_cpu=False
):
if force_cpu:
model = model.to('cpu')
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input))
aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input))
if detailed:
fcs = fvcore.nn.flop_count_str(fca)
print(fcs)
return fca.total() / batch_size, aca.total() / batch_size
def profile_fvcore_text(
model,
text_input_size=(77,),
batch_size=1,
detailed=False,
force_cpu=False
):
if force_cpu:
model = model.to('cpu')
device = next(model.parameters()).device
example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
fca = fvcore.nn.FlopCountAnalysis(model, example_input)
aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
if detailed:
fcs = fvcore.nn.flop_count_str(fca)
print(fcs)
return fca.total() / batch_size, aca.total() / batch_size
def profile_fvcore_image(
model,
image_input_size=(3, 224, 224),
batch_size=1,
detailed=False,
force_cpu=False
):
if force_cpu:
model = model.to('cpu')
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
fca = fvcore.nn.FlopCountAnalysis(model, example_input)
aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
if detailed:
fcs = fvcore.nn.flop_count_str(fca)
print(fcs)
return fca.total() / batch_size, aca.total() / batch_size
def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False):
"""Profile the image encoder using torch.utils.flop_counter"""
if force_cpu:
model = model.to('cpu')
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
flop_counter = FlopCounterMode()
with flop_counter:
model(example_input)
total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
return total_flops / batch_size
def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False):
"""Profile the text encoder using torch.utils.flop_counter"""
if force_cpu:
model = model.to('cpu')
device = next(model.parameters()).device
example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
flop_counter = FlopCounterMode()
with flop_counter:
model(example_input)
total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
return total_flops / batch_size
def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False):
"""Profile the full model using torch.utils.flop_counter"""
if force_cpu:
model = model.to('cpu')
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
flop_counter = FlopCounterMode()
with flop_counter:
model(image_input, text_input)
total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
return total_flops / batch_size
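# Note: torch's FlopCounterMode reports FLOPs, while fvcore's FlopCountAnalysis counts
# multiply-accumulates (MACs); the result keys below ('gflops' vs 'gmacs') reflect that.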
def count_params(model):
return sum(m.numel() for m in model.parameters())
def profile_model(model_name, batch_size=1, profiler='torch'):
assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported'
if profiler == 'fvcore':
assert fvcore is not None, 'Please install fvcore.'
model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
model.eval()
if torch.cuda.is_available():
model = model.cuda()
if isinstance(model.visual.image_size, (tuple, list)):
image_input_size = (3,) + tuple(model.visual.image_size[-2:])
else:
image_input_size = (3, model.visual.image_size, model.visual.image_size)
text_input_size = (77,)
if hasattr(model, 'context_length') and model.context_length:
text_input_size = (model.context_length,)
results = {}
results['model'] = model_name
results['image_size'] = image_input_size[1]
model_cfg = open_clip.get_model_config(model_name)
if model_cfg:
vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg'])
text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg'])
results['image_width'] = int(vision_cfg.width)
results['text_width'] = int(text_cfg.width)
results['embed_dim'] = int(model_cfg['embed_dim'])
else:
results['image_width'] = 0
results['text_width'] = 0
results['embed_dim'] = 0
retries = 2
while retries:
retries -= 1
try:
results['mparams'] = round(count_params(model) / 1e6, 2)
results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
results['text_mparams'] = round(count_params(model.text) / 1e6, 2)
if profiler == 'fvcore':
macs, acts = profile_fvcore(
model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
image_macs, image_acts = profile_fvcore_image(
model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
text_macs, text_acts = profile_fvcore_text(
model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
results['gmacs'] = round(macs / 1e9, 2)
results['macts'] = round(acts / 1e6, 2)
results['image_gmacs'] = round(image_macs / 1e9, 2)
results['image_macts'] = round(image_acts / 1e6, 2)
results['text_gmacs'] = round(text_macs / 1e9, 2)
results['text_macts'] = round(text_acts / 1e6, 2)
elif profiler == 'torch':
image_flops = profile_torch_image(
model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
text_flops = profile_torch_text(
model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
total_flops = profile_torch(
model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
results['gflops'] = round(total_flops / 1e9, 2)
results['image_gflops'] = round(image_flops / 1e9, 2)
results['text_gflops'] = round(text_flops / 1e9, 2)
except RuntimeError as e:
pass
return results
def main():
args = parser.parse_args()
# FIXME accept a text file name to allow lists of models in txt/csv
if args.model == 'all':
parsed_model = open_clip.list_models()
else:
parsed_model = args.model.split(',')
results = []
models_with_errors = []
for m in parsed_model:
print('='*100)
print(f'Profiling {m}')
try:
row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler)
results.append(row)
except Exception as e:
print(f'Error profiling {m}: {e}')
import traceback
traceback.print_exc()
models_with_errors.append(m)
df = pd.DataFrame(results, columns=results[0].keys())
if 'gmacs' in df.columns:
df = df.sort_values(by=['gmacs', 'mparams', 'model'])
else:
df = df.sort_values(by=['gflops', 'mparams', 'model'])
print('='*100)
print('Done.')
print(df)
if args.results_file:
df.to_csv(args.results_file, index=False)
if models_with_errors:
print('Models with errors:', models_with_errors)
if __name__ == '__main__':
main()
import numpy as np
def assign_learning_rate(optimizer, new_lr):
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr
def _warmup_lr(base_lr, warmup_length, step):
return base_lr * (step + 1) / warmup_length
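# Linear warmup: at (0-indexed) step s, lr = base_lr * (s + 1) / warmup_length,
# e.g. base_lr=1e-3, warmup_length=100 gives 1e-5 at step 0 and 1e-3 at step 99.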
def const_lr(optimizer, base_lr, warmup_length, steps):
def _lr_adjuster(step):
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
lr = base_lr
assign_learning_rate(optimizer, lr)
return lr
return _lr_adjuster
def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):
def _lr_adjuster(step):
start_cooldown_step = steps - cooldown_steps
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
if step < start_cooldown_step:
lr = base_lr
else:
e = step - start_cooldown_step
es = steps - start_cooldown_step
# linear decay if power == 1; polynomial decay otherwise;
decay = (1 - (e/es)) ** cooldown_power
lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
assign_learning_rate(optimizer, lr)
return lr
return _lr_adjuster
def cosine_lr(optimizer, base_lr, warmup_length, steps):
def _lr_adjuster(step):
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
e = step - warmup_length
es = steps - warmup_length
lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
assign_learning_rate(optimizer, lr)
return lr
return _lr_adjuster
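# Illustrative usage (sketch): after warmup, lr follows 0.5 * (1 + cos(pi * e / es)) * base_lr.
#   scheduler = cosine_lr(optimizer, base_lr=1e-3, warmup_length=100, steps=10_000)
#   for step in range(10_000):
#       lr = scheduler(step)  # assigns the lr to optimizer.param_groups and returns it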
import json
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel
try:
import wandb
except ImportError:
wandb = None
from open_clip import get_input_dtype, CLIP, CustomTextCLIP
from open_clip_train.distributed import is_master
from open_clip_train.zero_shot import zero_shot_eval
from open_clip_train.precision import get_autocast
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def postprocess_clip_output(model_out):
return {
"image_features": model_out[0],
"text_features": model_out[1],
"logit_scale": model_out[2]
}
def unwrap_model(model):
if hasattr(model, 'module'):
return model.module
else:
return model
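# DistributedDataParallel wraps the underlying model as `.module`; unwrapped models are returned as-is.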
def backward(total_loss, scaler):
if scaler is not None:
scaler.scale(total_loss).backward()
else:
total_loss.backward()
def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None):
device = torch.device(args.device)
autocast = get_autocast(args.precision)
input_dtype = get_input_dtype(args.precision)
model.train()
if args.distill:
dist_model.eval()
data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch
dataloader = data['train'].dataloader
num_batches_per_epoch = dataloader.num_batches // args.accum_freq
sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
if args.accum_freq > 1:
accum_images, accum_texts, accum_features = [], [], {}
losses_m = {}
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
end = time.time()
for i, batch in enumerate(dataloader):
i_accum = i // args.accum_freq
step = num_batches_per_epoch * epoch + i_accum
if not args.skip_scheduler:
scheduler(step)
images, texts = batch
images = images.to(device=device, dtype=input_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
data_time_m.update(time.time() - end)
optimizer.zero_grad()
if args.accum_freq == 1:
with autocast():
model_out = model(images, texts)
logit_scale = model_out["logit_scale"]
if args.distill:
with torch.no_grad():
dist_model_out = dist_model(images, texts)
model_out.update({f'dist_{k}': v for k, v in dist_model_out.items()})
losses = loss(**model_out, output_dict=True)
total_loss = sum(losses.values())
losses["loss"] = total_loss
backward(total_loss, scaler)
else:
# First, cache the features without any gradient tracking.
with torch.no_grad():
with autocast():
model_out = model(images, texts)
for f in ("logit_scale", "logit_bias"):
model_out.pop(f, None)
for key, val in model_out.items():
if key in accum_features:
accum_features[key].append(val)
else:
accum_features[key] = [val]
accum_images.append(images)
accum_texts.append(texts)
# If (i + 1) % accum_freq is not zero, move on to the next batch.
if ((i + 1) % args.accum_freq) > 0:
# FIXME this makes data time logging unreliable when accumulating
continue
# Now, ready to take gradients for the last accum_freq batches.
# Re-do the forward pass for those batches, and use the cached features from the other batches as negatives.
# Call backwards each time, but only step optimizer at the end.
optimizer.zero_grad()
for j in range(args.accum_freq):
images = accum_images[j]
texts = accum_texts[j]
with autocast():
model_out = model(images, texts)
inputs_no_accum = {}
inputs_no_accum["logit_scale"] = logit_scale = model_out.pop("logit_scale")
if "logit_bias" in model_out:
inputs_no_accum["logit_bias"] = model_out.pop("logit_bias")
inputs = {}
for key, val in accum_features.items():
accumulated = accum_features[key]
inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:])
losses = loss(**inputs, **inputs_no_accum, output_dict=True)
del inputs
del inputs_no_accum
total_loss = sum(losses.values())
losses["loss"] = total_loss
backward(total_loss, scaler)
if scaler is not None:
if args.horovod:
optimizer.synchronize()
scaler.unscale_(optimizer)
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
with optimizer.skip_synchronize():
scaler.step(optimizer)
else:
if args.grad_clip_norm is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
scaler.step(optimizer)
scaler.update()
else:
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
optimizer.step()
# reset gradient accum, if enabled
if args.accum_freq > 1:
accum_images, accum_texts, accum_features = [], [], {}
# Note: we clamp to 4.6052 = ln(100), as in the original paper.
with torch.no_grad():
unwrap_model(model).logit_scale.clamp_(0, math.log(100))
batch_time_m.update(time.time() - end)
end = time.time()
batch_count = i_accum + 1
if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
batch_size = len(images)
num_samples = batch_count * batch_size * args.accum_freq * args.world_size
samples_per_epoch = dataloader.num_samples
percent_complete = 100.0 * batch_count / num_batches_per_epoch
# NOTE loss is coarsely sampled, just master node and per log update
for key, val in losses.items():
if key not in losses_m:
losses_m[key] = AverageMeter()
losses_m[key].update(val.item(), batch_size)
logit_scale_scalar = logit_scale.item()
loss_log = " ".join(
[
f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
for loss_name, loss_m in losses_m.items()
]
)
samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val
samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val
logging.info(
f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
f"Data (t): {data_time_m.avg:.3f} "
f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
f"LR: {optimizer.param_groups[0]['lr']:5f} "
f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log
)
# Save train loss / etc. Use instantaneous (non-averaged) meter values since loggers apply their own smoothing.
log_data = {
"data_time": data_time_m.val,
"batch_time": batch_time_m.val,
"samples_per_second": samples_per_second,
"samples_per_second_per_gpu": samples_per_second_per_gpu,
"scale": logit_scale_scalar,
"lr": optimizer.param_groups[0]["lr"]
}
log_data.update({name:val.val for name,val in losses_m.items()})
log_data = {"train/" + name: val for name, val in log_data.items()}
if tb_writer is not None:
for name, val in log_data.items():
tb_writer.add_scalar(name, val, step)
if args.wandb:
assert wandb is not None, 'Please install wandb.'
log_data['step'] = step # for backwards compatibility
wandb.log(log_data, step=step)
# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()
# end for
def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None):
metrics = {}
if not is_master(args):
return metrics
device = torch.device(args.device)
model.eval()
zero_shot_metrics = zero_shot_eval(model, data, epoch, args, tokenizer=tokenizer)
metrics.update(zero_shot_metrics)
autocast = get_autocast(args.precision)
input_dtype = get_input_dtype(args.precision)
if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
dataloader = data['val'].dataloader
num_samples = 0
samples_per_val = dataloader.num_samples
# FIXME this does not scale past small eval datasets
# all_image_features @ all_text_features will blow up memory and compute very quickly
cumulative_loss = 0.0
cumulative_gen_loss = 0.0
all_image_features, all_text_features = [], []
with torch.inference_mode():
for i, batch in enumerate(dataloader):
images, texts = batch
images = images.to(device=device, dtype=input_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
with autocast():
model_out = model(images, texts)
image_features = model_out["image_features"]
text_features = model_out["text_features"]
logit_scale = model_out["logit_scale"]
# features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
# however, system RAM is easily exceeded and compute time becomes problematic
all_image_features.append(image_features.cpu())
all_text_features.append(text_features.cpu())
logit_scale = logit_scale.mean()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
batch_size = images.shape[0]
labels = torch.arange(batch_size, device=device).long()
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
gen_loss = maybe_compute_generative_loss(model_out)
cumulative_loss += total_loss * batch_size
num_samples += batch_size
if is_master(args) and (i % 100) == 0:
logging.info(
f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
f"Clip Loss: {cumulative_loss / num_samples:.6f}\t")
if gen_loss is not None:
cumulative_gen_loss += gen_loss * batch_size
logging.info(
f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t")
val_metrics = get_clip_metrics(
image_features=torch.cat(all_image_features),
text_features=torch.cat(all_text_features),
logit_scale=logit_scale.cpu(),
)
loss = cumulative_loss / num_samples
metrics.update(
{**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
)
if gen_loss is not None:
gen_loss = cumulative_gen_loss / num_samples
metrics.update({"val_generative_loss": gen_loss.item()})
if not metrics:
return metrics
logging.info(
f"Eval Epoch: {epoch} "
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
)
log_data = {"val/" + name: val for name, val in metrics.items()}
if args.save_logs:
if tb_writer is not None:
for name, val in log_data.items():
tb_writer.add_scalar(name, val, epoch)
with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
f.write(json.dumps(metrics))
f.write("\n")
if args.wandb:
assert wandb is not None, 'Please install wandb.'
if 'train' in data:
dataloader = data['train'].dataloader
num_batches_per_epoch = dataloader.num_batches // args.accum_freq
step = num_batches_per_epoch * epoch
else:
step = None
log_data['epoch'] = epoch
wandb.log(log_data, step=step)
return metrics
def get_clip_metrics(image_features, text_features, logit_scale):
metrics = {}
logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
logits_per_text = logits_per_image.t().detach().cpu()
logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
ground_truth = torch.arange(len(text_features)).view(-1, 1)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{name}_mean_rank"] = preds.mean() + 1
metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{name}_R@{k}"] = np.mean(preds < k)
return metrics
def maybe_compute_generative_loss(model_out):
if "logits" in model_out and "labels" in model_out:
token_logits = model_out["logits"]
token_labels = model_out["labels"]
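# F.cross_entropy expects class logits in dim 1, so (batch, seq, vocab) is permuted to (batch, vocab, seq).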
return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels)
import logging
import torch
from tqdm import tqdm
from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
from open_clip_train.precision import get_autocast
def accuracy(output, target, topk=(1,)):
pred = output.topk(max(topk), 1, True, True)[1].t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
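# Returns raw counts of correct top-k predictions for the batch (not fractions);
# the caller divides by the number of samples.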
def run(model, classifier, dataloader, args):
autocast = get_autocast(args.precision)
input_dtype = get_input_dtype(args.precision)
with torch.inference_mode():
top1, top5, n = 0., 0., 0.
for images, target in tqdm(dataloader, unit_scale=args.batch_size):
images = images.to(device=args.device, dtype=input_dtype)
target = target.to(args.device)
with autocast():
# predict
output = model(image=images)
image_features = output['image_features'] if isinstance(output, dict) else output[0]
logits = 100. * image_features @ classifier
# measure accuracy
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
top1 += acc1
top5 += acc5
n += images.size(0)
top1 = (top1 / n)
top5 = (top5 / n)
return top1, top5
def zero_shot_eval(model, data, epoch, args, tokenizer=None):
if 'imagenet-val' not in data and 'imagenet-v2' not in data:
return {}
if args.zeroshot_frequency == 0:
return {}
if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
return {}
if args.distributed and not args.horovod:
model = model.module
logging.info('Starting zero-shot imagenet.')
if tokenizer is None:
tokenizer = get_tokenizer(args.model)
logging.info('Building zero-shot classifier')
autocast = get_autocast(args.precision)
with autocast():
classifier = build_zero_shot_classifier(
model,
tokenizer=tokenizer,
classnames=IMAGENET_CLASSNAMES,
templates=OPENAI_IMAGENET_TEMPLATES,
num_classes_per_batch=10,
device=args.device,
use_tqdm=True,
)
logging.info('Using classifier')
results = {}
if 'imagenet-val' in data:
top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
results['imagenet-zeroshot-val-top1'] = top1
results['imagenet-zeroshot-val-top5'] = top5
if 'imagenet-v2' in data:
top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
results['imagenetv2-zeroshot-val-top1'] = top1
results['imagenetv2-zeroshot-val-top5'] = top5
logging.info('Finished zero-shot imagenet.')
return results
import requests
import torch
from PIL import Image
import hashlib
import tempfile
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import patch
from urllib3 import HTTPResponse
from urllib3._collections import HTTPHeaderDict
import open_clip
from open_clip.pretrained import download_pretrained_from_url
class DownloadPretrainedTests(unittest.TestCase):
def create_response(self, data, status_code=200, content_type='application/octet-stream'):
fp = BytesIO(data)
headers = HTTPHeaderDict({
'Content-Type': content_type,
'Content-Length': str(len(data))
})
raw = HTTPResponse(fp, preload_content=False, headers=headers, status=status_code)
return raw
@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_url_from_openaipublic(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()
urllib.request.urlopen.return_value = self.create_response(file_contents)
with tempfile.TemporaryDirectory() as root:
url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()
@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()
urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
with tempfile.TemporaryDirectory() as root:
url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()
@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()
urllib.request.urlopen.return_value = self.create_response(file_contents)
with tempfile.TemporaryDirectory() as root:
local_file = Path(root) / 'RN50.pt'
local_file.write_bytes(file_contents)
url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_not_called()
@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()
urllib.request.urlopen.return_value = self.create_response(file_contents)
with tempfile.TemporaryDirectory() as root:
local_file = Path(root) / 'RN50.pt'
local_file.write_bytes(b'corrupted pretrained model')
url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()
@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_url_from_mlfoundations(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
urllib.request.urlopen.return_value = self.create_response(file_contents)
with tempfile.TemporaryDirectory() as root:
url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()
@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
with tempfile.TemporaryDirectory() as root:
url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()
@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_hfh(self, urllib):
model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model')
tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model')
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), 1e-3))
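# Sketch of running just this download test class in isolation (assumes pytest is
# available; invoke from wherever this test file lives):
#
#     python -m pytest -q -k DownloadPretrainedTests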
import pytest
import torch
from open_clip.hf_model import _POOLERS, HFTextEncoder
from transformers import AutoConfig
from transformers.modeling_outputs import BaseModelOutput
# test poolers
def test_poolers():
bs, sl, d = 2, 10, 5
h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d)
mask = torch.ones(bs, sl, dtype=torch.bool)
mask[:2, 6:] = False
x = BaseModelOutput(h)
for name, cls in _POOLERS.items():
pooler = cls()
res = pooler(x, mask)
assert res.shape == (bs, d), f"{name} returned wrong shape"
# test HFTextEncoder
@pytest.mark.parametrize("model_id", ["arampacha/roberta-tiny", "roberta-base", "xlm-roberta-base", "google/mt5-base"])
def test_pretrained_text_encoder(model_id):
bs, sl, d = 2, 10, 64
cfg = AutoConfig.from_pretrained(model_id)
model = HFTextEncoder(model_id, d, proj_type='linear')
x = torch.randint(0, cfg.vocab_size, (bs, sl))
with torch.no_grad():
emb = model(x)
assert emb.shape == (bs, d)
import os
import pytest
import torch
import open_clip
import util_test
os.environ['CUDA_VISIBLE_DEVICES'] = ''
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
# no need for the fusion performance here
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)
models_to_test = set(open_clip.list_models())
# testing exemptions
models_to_test = models_to_test.difference({
# not available with timm yet
# see https://github.com/mlfoundations/open_clip/issues/219
'convnext_xlarge',
'convnext_xxlarge',
'convnext_xxlarge_320',
'vit_medium_patch16_gap_256',
# exceeds GH runner memory limit
'ViT-bigG-14',
'ViT-e-14',
'mt5-xl-ViT-H-14',
'coca_base',
'coca_ViT-B-32',
'coca_roberta-ViT-B-32'
})
if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ:
external_model_list = os.environ['OPEN_CLIP_TEST_REG_MODELS']
with open(external_model_list, 'r') as f:
models_to_test = set(f.read().splitlines()).intersection(models_to_test)
print(f"Selected models from {external_model_list}: {models_to_test}")
# TODO: add "coca_ViT-B-32" once https://github.com/pytorch/pytorch/issues/92073 gets fixed
models_to_test = list(models_to_test)
models_to_test.sort()
models_to_test = [(model_name, False) for model_name in models_to_test]
models_to_jit_test = {"ViT-B-32"}
models_to_jit_test = list(models_to_jit_test)
models_to_jit_test = [(model_name, True) for model_name in models_to_jit_test]
models_to_test_fully = models_to_test + models_to_jit_test
@pytest.mark.regression_test
@pytest.mark.parametrize("model_name,jit", models_to_test_fully)
def test_inference_with_data(
model_name,
jit,
pretrained = None,
pretrained_hf = False,
precision = 'fp32',
force_quick_gelu = False,
):
util_test.seed_all()
model, _, preprocess_val = open_clip.create_model_and_transforms(
model_name,
pretrained = pretrained,
precision = precision,
jit = jit,
force_quick_gelu = force_quick_gelu,
pretrained_hf = pretrained_hf
)
model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
input_dir, output_dir = util_test.get_data_dirs()
# text
input_text_path = os.path.join(input_dir, 'random_text.pt')
gt_text_path = os.path.join(output_dir, f'{model_id}_random_text.pt')
if not os.path.isfile(input_text_path):
pytest.skip(reason = f"missing test data, expected at {input_text_path}")
if not os.path.isfile(gt_text_path):
pytest.skip(reason = f"missing test data, expected at {gt_text_path}")
input_text = torch.load(input_text_path)
gt_text = torch.load(gt_text_path)
y_text = util_test.inference_text(model, model_name, input_text)
assert (y_text == gt_text).all(), f"text output differs @ {input_text_path}"
# image
image_size = model.visual.image_size
if not isinstance(image_size, tuple):
image_size = (image_size, image_size)
input_image_path = os.path.join(input_dir, f'random_image_{image_size[0]}_{image_size[1]}.pt')
gt_image_path = os.path.join(output_dir, f'{model_id}_random_image.pt')
if not os.path.isfile(input_image_path):
pytest.skip(reason = f"missing test data, expected at {input_image_path}")
if not os.path.isfile(gt_image_path):
pytest.skip(reason = f"missing test data, expected at {gt_image_path}")
input_image = torch.load(input_image_path)
gt_image = torch.load(gt_image_path)
y_image = util_test.inference_image(model, preprocess_val, input_image)
assert (y_image == gt_image).all(), f"image output differs @ {input_image_path}"
if not jit:
model.eval()
model_out = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text)
if type(model) not in [open_clip.CLIP, open_clip.CustomTextCLIP]:
assert type(model_out) == dict
else:
model.output_dict = True
model_out_dict = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text)
assert (model_out_dict["image_features"] == model_out[0]).all()
assert (model_out_dict["text_features"] == model_out[1]).all()
assert (model_out_dict["logit_scale"] == model_out[2]).all()
model.output_dict = None
else:
model, _, preprocess_val = open_clip.create_model_and_transforms(
model_name,
pretrained = pretrained,
precision = precision,
jit = False,
force_quick_gelu = force_quick_gelu,
pretrained_hf = pretrained_hf
)
test_model = util_test.TestWrapper(model, model_name, output_dict=False)
test_model = torch.jit.script(test_model)
model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text)
assert model_out["test_output"].shape[-1] == 2
test_model = util_test.TestWrapper(model, model_name, output_dict=True)
test_model = torch.jit.script(test_model)
model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text)
assert model_out["test_output"].shape[-1] == 2
import torch
from PIL import Image
from open_clip.factory import get_tokenizer
import pytest
import open_clip
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
# no need for the fusion performance here
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)
test_simple_models = [
# model, pretrained, jit, force_custom_text
("ViT-B-32", "laion2b_s34b_b79k", False, False),
("ViT-B-32", "laion2b_s34b_b79k", True, False),
("ViT-B-32", "laion2b_s34b_b79k", True, True),
("roberta-ViT-B-32", "laion2b_s12b_b32k", False, False),
]
@pytest.mark.parametrize("model_type,pretrained,jit,force_custom_text", test_simple_models)
def test_inference_simple(
model_type,
pretrained,
jit,
force_custom_text,
):
model, _, preprocess = open_clip.create_model_and_transforms(
model_type,
pretrained=pretrained,
jit=jit,
force_custom_text=force_custom_text,
)
tokenizer = get_tokenizer(model_type)
current_dir = os.path.dirname(os.path.realpath(__file__))
image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
assert torch.allclose(text_probs.cpu()[0], torch.tensor([1.0, 0.0, 0.0]))
import pytest
from open_clip_train.data import get_dataset_size
@pytest.mark.parametrize(
"shards,expected_size",
[
('/path/to/shard.tar', 1),
('/path/to/shard_{000..000}.tar', 1),
('/path/to/shard_{000..009}.tar', 10),
('/path/to/shard_{000..009}_{000..009}.tar', 100),
('/path/to/shard.tar::/path/to/other_shard_{000..009}.tar', 11),
('/path/to/shard_{000..009}.tar::/path/to/other_shard_{000..009}.tar', 20),
(['/path/to/shard.tar'], 1),
(['/path/to/shard.tar', '/path/to/other_shard.tar'], 2),
]
)
def test_num_shards(shards, expected_size):
_, size = get_dataset_size(shards)
assert size == expected_size, f'Expected {expected_size} for {shards} but found {size} instead.'
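# For intuition, a rough stand-in for the shard counting exercised above -- a sketch
# assuming the `braceexpand` package (also used by webdataset); the real
# get_dataset_size may additionally consult sizes metadata rather than just counting
# shards:
#
#     import braceexpand
#
#     def count_shards(spec):
#         if isinstance(spec, list):
#             return sum(count_shards(s) for s in spec)
#         return sum(len(list(braceexpand.braceexpand(s))) for s in spec.split('::'))
#
#     count_shards('/path/to/shard_{000..009}.tar')  # -> 10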
import os
import sys
import pytest
import torch
from open_clip_train.main import main
os.environ["CUDA_VISIBLE_DEVICES"] = ""
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
# no need for the fusion performance here
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training():
main([
'--save-frequency', '1',
'--zeroshot-frequency', '1',
'--dataset-type', "synthetic",
'--train-num-samples', '16',
'--warmup', '1',
'--batch-size', '4',
'--lr', '1e-3',
'--wd', '0.1',
'--epochs', '1',
'--workers', '2',
'--model', 'RN50'
])
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training_coca():
main([
'--save-frequency', '1',
'--zeroshot-frequency', '1',
'--dataset-type', "synthetic",
'--train-num-samples', '16',
'--warmup', '1',
'--batch-size', '4',
'--lr', '1e-3',
'--wd', '0.1',
'--epochs', '1',
'--workers', '2',
'--model', 'coca_ViT-B-32'
])
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training_mt5():
main([
'--save-frequency', '1',
'--zeroshot-frequency', '1',
'--dataset-type', "synthetic",
'--train-num-samples', '16',
'--warmup', '1',
'--batch-size', '4',
'--lr', '1e-3',
'--wd', '0.1',
'--epochs', '1',
'--workers', '2',
'--model', 'mt5-base-ViT-B-32',
'--lock-text',
'--lock-text-unlocked-layers', '2'
])
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training_unfreezing_vit():
main([
'--save-frequency', '1',
'--zeroshot-frequency', '1',
'--dataset-type', "synthetic",
'--train-num-samples', '16',
'--warmup', '1',
'--batch-size', '4',
'--lr', '1e-3',
'--wd', '0.1',
'--epochs', '1',
'--workers', '2',
'--model', 'ViT-B-32',
'--lock-image',
'--lock-image-unlocked-groups', '5',
'--accum-freq', '2'
])
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training_clip_with_jit():
main([
'--save-frequency', '1',
'--zeroshot-frequency', '1',
'--dataset-type', "synthetic",
'--train-num-samples', '16',
'--warmup', '1',
'--batch-size', '4',
'--lr', '1e-3',
'--wd', '0.1',
'--epochs', '1',
'--workers', '2',
'--model', 'ViT-B-32',
'--torchscript'
])
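# The argument lists above mirror the training CLI; a roughly equivalent shell
# invocation of the smallest smoke test would be (sketch, assuming main.py is
# runnable as a module, as `from open_clip_train.main import main` suggests):
#
#     python -m open_clip_train.main \
#         --save-frequency 1 --zeroshot-frequency 1 --dataset-type synthetic \
#         --train-num-samples 16 --warmup 1 --batch-size 4 --lr 1e-3 --wd 0.1 \
#         --epochs 1 --workers 2 --model RN50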
import os
import pytest
import util_test
import collections
import tarfile
import io
from PIL import Image
from open_clip_train.data import get_wds_dataset
from open_clip_train.params import parse_args
from open_clip_train.main import random_seed
TRAIN_NUM_SAMPLES = 10_000
RTOL = 0.2
# NOTE: we use two test tar files, which are created on the fly and saved to data/input.
# test_data_000.tar has 10 samples, and the captions are 000_0, 000_1, ..., 000_9
# test_data_001.tar has 5 samples, and the captions are 001_0, 001_1, ..., 001_4
def build_inputs(test_name):
base_input_dir, _ = util_test.get_data_dirs()
input_dir = os.path.join(base_input_dir, test_name)
os.makedirs(input_dir, exist_ok=True)
def save_tar(idx, num_samples):
filename = os.path.join(input_dir, f'test_data_{idx:03d}.tar')
tar = tarfile.open(filename, 'w')
for sample_idx in range(num_samples):
# Image
image = Image.new('RGB', (32, 32))
info = tarfile.TarInfo(f'{sample_idx}.png')
bio = io.BytesIO()
image.save(bio, format='png')
size = bio.tell()
bio.seek(0)
info.size = size
tar.addfile(info, bio)
# Caption
info = tarfile.TarInfo(f'{sample_idx}.txt')
bio = io.BytesIO()
bio.write(f'{idx:03d}_{sample_idx}'.encode('utf-8'))
size = bio.tell()
bio.seek(0)
info.size = size
tar.addfile(info, bio)
tar.close()
save_tar(0, 10)
save_tar(1, 5)
return input_dir
def build_params(input_shards, seed=0):
args = parse_args([])
args.train_data = input_shards
args.train_num_samples = TRAIN_NUM_SAMPLES
args.dataset_resampled = True
args.seed = seed
args.workers = 1
args.world_size = 1
args.batch_size = 1
random_seed(seed)
preprocess_img = lambda x: x
tokenizer = lambda x: [x.strip()]
return args, preprocess_img, tokenizer
def get_dataloader(input_shards):
args, preprocess_img, tokenizer = build_params(input_shards)
dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
dataloader = dataset.dataloader
return dataloader
def test_single_source():
"""Test webdataset with a single tar file."""
input_dir = build_inputs('single_source')
input_shards = os.path.join(input_dir, 'test_data_000.tar')
dataloader = get_dataloader(input_shards)
counts = collections.defaultdict(int)
for sample in dataloader:
txts = sample[1]
for txt in txts:
counts[txt] += 1
for key, count in counts.items():
assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL)
def test_two_sources():
"""Test webdataset with a single two tar files."""
input_dir = build_inputs('two_sources')
input_shards = os.path.join(input_dir, 'test_data_{000..001}.tar')
dataloader = get_dataloader(input_shards)
counts = collections.defaultdict(int)
for sample in dataloader:
txts = sample[1]
for txt in txts:
counts[txt] += 1
for key, count in counts.items():
assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}'
def test_two_sources_same_weights():
"""Test webdataset with a two tar files, using --train-data-weights=1::1."""
input_dir = build_inputs('two_sources_same_weights')
input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}"
args, preprocess_img, tokenizer = build_params(input_shards)
args.train_data_upsampling_factors = '1::1'
dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
dataloader = dataset.dataloader
counts = collections.defaultdict(int)
for sample in dataloader:
txts = sample[1]
for txt in txts:
counts[txt] += 1
for key, count in counts.items():
assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}'
def test_two_sources_with_upsampling():
"""Test webdataset with a two tar files with upsampling."""
input_dir = build_inputs('two_sources_with_upsampling')
input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}"
args, preprocess_img, tokenizer = build_params(input_shards)
args.train_data_upsampling_factors = '1::2'
dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
dataloader = dataset.dataloader
counts = collections.defaultdict(int)
for sample in dataloader:
txts = sample[1]
for txt in txts:
counts[txt] += 1
for key, count in counts.items():
if key.startswith('000'):
assert count == pytest.approx(TRAIN_NUM_SAMPLES / 20, RTOL), f'{key}, {count}'
else:
assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL), f'{key}, {count}'
import os
import random
import numpy as np
from PIL import Image
import torch
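# open_clip is imported lazily here: when util_test.py is run as a script, main()
# below may `git checkout` another revision first and then import open_clip via
# importlib, so the module-scope import only happens when this file is imported
# as a library by the tests.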
if __name__ != '__main__':
import open_clip
os.environ['CUDA_VISIBLE_DEVICES'] = ''
def seed_all(seed = 0):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True, warn_only=False)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def inference_text(model, model_name, batches):
y = []
tokenizer = open_clip.get_tokenizer(model_name)
with torch.no_grad():
for x in batches:
x = tokenizer(x)
y.append(model.encode_text(x))
return torch.stack(y)
def inference_image(model, preprocess_val, batches):
y = []
with torch.no_grad():
for x in batches:
x = torch.stack([preprocess_val(img) for img in x])
y.append(model.encode_image(x))
return torch.stack(y)
def forward_model(model, model_name, preprocess_val, image_batch, text_batch):
y = []
tokenizer = open_clip.get_tokenizer(model_name)
with torch.no_grad():
for x_im, x_txt in zip(image_batch, text_batch):
x_im = torch.stack([preprocess_val(im) for im in x_im])
x_txt = tokenizer(x_txt)
y.append(model(x_im, x_txt))
if type(y[0]) == dict:
out = {}
for key in y[0].keys():
out[key] = torch.stack([batch_out[key] for batch_out in y])
else:
out = []
for i in range(len(y[0])):
out.append(torch.stack([batch_out[i] for batch_out in y]))
return out
def random_image_batch(batch_size, size):
h, w = size
data = np.random.randint(255, size = (batch_size, h, w, 3), dtype = np.uint8)
return [ Image.fromarray(d) for d in data ]
def random_text_batch(batch_size, min_length = 75, max_length = 75):
t = open_clip.tokenizer.SimpleTokenizer()
# every token decoded as string, exclude SOT and EOT, replace EOW with space
token_words = [
x[1].replace('</w>', ' ')
for x in t.decoder.items()
if x[0] not in t.all_special_ids
]
# strings of randomly chosen tokens
return [
''.join(random.choices(
token_words,
k = random.randint(min_length, max_length)
))
for _ in range(batch_size)
]
def create_random_text_data(
path,
min_length = 75,
max_length = 75,
batches = 1,
batch_size = 1
):
text_batches = [
random_text_batch(batch_size, min_length, max_length)
for _ in range(batches)
]
print(f"{path}")
torch.save(text_batches, path)
def create_random_image_data(path, size, batches = 1, batch_size = 1):
image_batches = [
random_image_batch(batch_size, size)
for _ in range(batches)
]
print(f"{path}")
torch.save(image_batches, path)
def get_data_dirs(make_dir = True):
data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
input_dir = os.path.join(data_dir, 'input')
output_dir = os.path.join(data_dir, 'output')
if make_dir:
os.makedirs(input_dir, exist_ok = True)
os.makedirs(output_dir, exist_ok = True)
    assert os.path.isdir(input_dir), f"data directory missing, expected at {input_dir}"
    assert os.path.isdir(output_dir), f"data directory missing, expected at {output_dir}"
return input_dir, output_dir
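# Expected on-disk layout (sketch; names match those produced by the create_*
# helpers below and consumed by test_inference_with_data):
#
#     data/input/random_text.pt
#     data/input/random_image_<H>_<W>.pt
#     data/output/<model_id>_random_text.pt
#     data/output/<model_id>_random_image.pt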
def create_test_data_for_model(
model_name,
pretrained = None,
precision = 'fp32',
jit = False,
pretrained_hf = False,
force_quick_gelu = False,
create_missing_input_data = True,
batches = 1,
batch_size = 1,
overwrite = False
):
model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
input_dir, output_dir = get_data_dirs()
output_file_text = os.path.join(output_dir, f'{model_id}_random_text.pt')
output_file_image = os.path.join(output_dir, f'{model_id}_random_image.pt')
text_exists = os.path.exists(output_file_text)
image_exists = os.path.exists(output_file_image)
if not overwrite and text_exists and image_exists:
return
seed_all()
model, _, preprocess_val = open_clip.create_model_and_transforms(
model_name,
pretrained = pretrained,
precision = precision,
jit = jit,
force_quick_gelu = force_quick_gelu,
pretrained_hf = pretrained_hf
)
# text
if overwrite or not text_exists:
input_file_text = os.path.join(input_dir, 'random_text.pt')
if create_missing_input_data and not os.path.exists(input_file_text):
create_random_text_data(
input_file_text,
batches = batches,
batch_size = batch_size
)
assert os.path.isfile(input_file_text), f"missing input data, expected at {input_file_text}"
input_data_text = torch.load(input_file_text)
output_data_text = inference_text(model, model_name, input_data_text)
print(f"{output_file_text}")
torch.save(output_data_text, output_file_text)
# image
if overwrite or not image_exists:
size = model.visual.image_size
if not isinstance(size, tuple):
size = (size, size)
input_file_image = os.path.join(input_dir, f'random_image_{size[0]}_{size[1]}.pt')
if create_missing_input_data and not os.path.exists(input_file_image):
create_random_image_data(
input_file_image,
size,
batches = batches,
batch_size = batch_size
)
assert os.path.isfile(input_file_image), f"missing input data, expected at {input_file_image}"
input_data_image = torch.load(input_file_image)
output_data_image = inference_image(model, preprocess_val, input_data_image)
print(f"{output_file_image}")
torch.save(output_data_image, output_file_image)
def create_test_data(
models,
batches = 1,
batch_size = 1,
overwrite = False
):
models = list(set(models).difference({
# not available with timm
# see https://github.com/mlfoundations/open_clip/issues/219
'timm-convnext_xlarge',
'timm-vit_medium_patch16_gap_256'
}).intersection(open_clip.list_models()))
models.sort()
print(f"generating test data for:\n{models}")
for model_name in models:
print(model_name)
create_test_data_for_model(
model_name,
batches = batches,
batch_size = batch_size,
overwrite = overwrite
)
return models
def _system_assert(string):
assert os.system(string) == 0
class TestWrapper(torch.nn.Module):
output_dict: torch.jit.Final[bool]
def __init__(self, model, model_name, output_dict=True) -> None:
super().__init__()
self.model = model
self.output_dict = output_dict
if type(model) in [open_clip.CLIP, open_clip.CustomTextCLIP]:
self.model.output_dict = self.output_dict
config = open_clip.get_model_config(model_name)
self.head = torch.nn.Linear(config["embed_dim"], 2)
def forward(self, image, text):
x = self.model(image, text)
x = x['image_features'] if self.output_dict else x[0]
assert x is not None # remove Optional[], type refinement for torchscript
out = self.head(x)
return {"test_output": out}
def main(args):
global open_clip
import importlib
import shutil
import subprocess
import argparse
parser = argparse.ArgumentParser(description = "Populate test data directory")
parser.add_argument(
'-a', '--all',
action = 'store_true',
help = "create test data for all models"
)
parser.add_argument(
'-m', '--model',
type = str,
default = [],
nargs = '+',
help = "model(s) to create test data for"
)
parser.add_argument(
'-f', '--model_list',
type = str,
help = "path to a text file containing a list of model names, one model per line"
)
parser.add_argument(
'-s', '--save_model_list',
type = str,
help = "path to save the list of models that data was generated for"
)
parser.add_argument(
'-g', '--git_revision',
type = str,
help = "git revision to generate test data for"
)
parser.add_argument(
'--overwrite',
action = 'store_true',
help = "overwrite existing output data"
)
parser.add_argument(
'-n', '--num_batches',
default = 1,
type = int,
help = "amount of data batches to create (default: 1)"
)
parser.add_argument(
'-b', '--batch_size',
default = 1,
type = int,
help = "test data batch size (default: 1)"
)
args = parser.parse_args(args)
model_list = []
if args.model_list is not None:
with open(args.model_list, 'r') as f:
model_list = f.read().splitlines()
if not args.all and len(args.model) < 1 and len(model_list) < 1:
print("error: at least one model name is required")
parser.print_help()
parser.exit(1)
if args.git_revision is not None:
stash_output = subprocess.check_output(['git', 'stash']).decode().splitlines()
has_stash = len(stash_output) > 0 and stash_output[0] != 'No local changes to save'
current_branch = subprocess.check_output(['git', 'branch', '--show-current'])
if len(current_branch) < 1:
# not on a branch -> detached head
current_branch = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
current_branch = current_branch.splitlines()[0].decode()
try:
            _system_assert(f'git checkout {args.git_revision}')
except AssertionError as e:
            _system_assert(f'git checkout -f {current_branch}')
            if has_stash:
                os.system('git stash pop')
raise e
open_clip = importlib.import_module('open_clip')
models = open_clip.list_models() if args.all else args.model + model_list
try:
models = create_test_data(
models,
batches = args.num_batches,
batch_size = args.batch_size,
overwrite = args.overwrite
)
finally:
if args.git_revision is not None:
test_dir = os.path.join(os.path.dirname(__file__), 'data')
test_dir_ref = os.path.join(os.path.dirname(__file__), 'data_ref')
if os.path.exists(test_dir_ref):
shutil.rmtree(test_dir_ref, ignore_errors = True)
if os.path.exists(test_dir):
os.rename(test_dir, test_dir_ref)
            _system_assert(f'git checkout {current_branch}')
            if has_stash:
                os.system('git stash pop')
os.rename(test_dir_ref, test_dir)
if args.save_model_list is not None:
print(f"Saving model list as {args.save_model_list}")
with open(args.save_model_list, 'w') as f:
for m in models:
print(m, file=f)
if __name__ == '__main__':
import sys
main(sys.argv[1:])
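# Example invocations (sketch; all flags are defined by the argparse parser above,
# and <rev> is a placeholder for a real git revision):
#
#     python util_test.py --model ViT-B-32 RN50 --num_batches 2 --batch_size 4
#     python util_test.py --all --git_revision <rev> --save_model_list models.txt --overwrite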