# --------------------------------------------------------
# Borrow code from:
# https://github.com/openai/CLIP/tree/main/clip/clip.py
# --------------------------------------------------------

import hashlib
import os
import urllib
import urllib.request
import warnings
from typing import List, Optional, Union

import oneflow as flow
import torch
from flowvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from PIL import Image
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from flowvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # Fall back to PIL's constant when flowvision does not expose InterpolationMode.
    BICUBIC = Image.BICUBIC


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()  # noqa:

# Model name -> checkpoint URL. The second-to-last path component of each URL
# is used by `_download` as the expected SHA256 digest of the file.
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",  # noqa: E501
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",  # noqa: E501
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",  # noqa: E501
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",  # noqa: E501
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",  # noqa: E501
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",  # noqa: E501
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",  # noqa: E501
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",  # noqa: E501
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",  # noqa: E501
}


def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at `path`.

    The file is hashed in `chunk_size`-byte pieces so that large checkpoints
    are never held in memory in full, and the handle is always closed.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str):
    """Download `url` into directory `root` and return the local file path.

    The expected SHA256 digest is taken from the second-to-last component of
    the URL path. If a file with the target name already exists and its digest
    matches, it is reused without touching the network; if it exists but does
    not match, a warning is issued and the file is downloaded again.

    Parameters
    ----------
    url : str
        Source URL of the form ``.../<sha256>/<filename>``
    root : str
        Directory to store the file in; created if it does not exist

    Returns
    -------
    str
        Path of the verified local file

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file, or if the
        downloaded file's digest does not match the expected one
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(
            f"{download_target} exists, but the SHA256 checksum does not match; "
            "re-downloading the file"
        )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Content-Length may be absent (e.g. chunked responses); tqdm accepts
        # total=None and then shows a progress counter without a known end.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(
            total=total,
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _convert_image_to_rgb(image):
    """Return `image` converted to RGB mode via its ``convert`` method."""
    return image.convert("RGB")


def _transform(n_px):
    """Build the preprocessing pipeline for a square input of side `n_px`.

    The pipeline resizes with bicubic interpolation, center-crops to `n_px`,
    converts to RGB, converts to a tensor and normalizes each channel with the
    fixed mean / std values below.
    """
    return Compose(
        [
            Resize(n_px, interpolation=BICUBIC),
            CenterCrop(n_px),
            _convert_image_to_rgb,
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(
    name: str,
    device: Union[str, torch.device] = "cuda" if flow.cuda.is_available() else "cpu",
    download_root: Optional[str] = None,
):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`,
        or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : flow.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], flow.Tensor]
        A flowvision transform that converts a PIL image into a tensor that
        the returned model can take as its input

    Raises
    ------
    RuntimeError
        If `name` is neither a known model name nor a path to an existing file
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, "rb") as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location="cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            # The failed JIT attempt may have advanced the file position, so
            # rewind before handing the same handle to torch.load.
            opened_file.seek(0)
            # NOTE(review): torch.load unpickles the checkpoint; only pass
            # paths to files from a trusted source.
            state_dict = torch.load(opened_file, map_location="cpu")

    # Explicit None check: an empty state dict must not fall through to
    # `model`, which is unbound when the JIT load failed.
    if state_dict is None:
        state_dict = model.state_dict()

    model = build_model(state_dict).to(device)
    if str(device) == "cpu":
        model.float()
    return model, _transform(model.visual.img_size)


def tokenize(
    texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
) -> Union[flow.IntTensor, flow.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens,
    shape = [number of input strings, context_length].

    Raises
    ------
    RuntimeError
        If an encoded input is longer than `context_length` and `truncate` is False
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    # Rows are zero-initialised, so positions past each sequence stay 0.
    result = flow.zeros(len(all_tokens), context_length, dtype=flow.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                # Keep the first `context_length` tokens and make sure the
                # sequence still ends with the end-of-text token.
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, : len(tokens)] = flow.tensor(tokens, dtype=flow.int)

    return result