Commit f55a786e authored by luopl's avatar luopl
Browse files

Initial commit

parents
Pipeline #1081 canceled with stages
#!/bin/sh
# Evaluate a trained checkpoint on six semantic-segmentation benchmarks
# (ADE20k-150, ADE20k-847, Pascal VOC, Pascal VOC-b, Pascal Context 59/459)
# and print the "copypaste" summary lines from each run's log.
#
# Usage: sh eval.sh [CONFIG_FILE] [NUM_GPUS] [OUTPUT_DIR] [OPTS]
#   CONFIG_FILE  passed to train_net.py as --config
#   NUM_GPUS     passed to train_net.py as --num-gpus
#   OUTPUT_DIR   must contain model_final.pth; each benchmark writes to
#                OUTPUT_DIR/eval-<tag>
#   OPTS         any further arguments, appended verbatim to every run
#
# A failing benchmark does not stop the remaining ones (no `set -e`), which
# matches the behaviour of the original flat script.

usage='sh eval.sh [CONFIG_FILE] [NUM_GPUS] [OUTPUT_DIR] [OPTS]'

# Print an error plus the usage line on stderr and exit non-zero.
die() {
  printf '%s Run with "%s"\n' "$1" "$usage" >&2
  exit 1
}

config=${1:-}
gpus=${2:-}
output=${3:-}

[ -n "$config" ] || die 'No config file found!'
[ -n "$gpus" ] || die 'Number of gpus not specified!'
[ -n "$output" ] || die 'No output directory found!'

# Drop the three fixed arguments; "$@" now holds only the extra overrides.
shift 3

# run_eval TAG DATASET [OPTS...]
#   TAG      selects datasets/TAG.json and the OUTPUT_DIR/eval-TAG directory
#   DATASET  registered dataset name, passed as the tuple ("DATASET",)
#   OPTS     extra overrides forwarded unchanged (no re-splitting/globbing)
run_eval() {
  tag=$1
  dataset=$2
  shift 2
  python train_net.py --config "$config" \
    --num-gpus "$gpus" \
    --dist-url "auto" \
    --eval-only \
    OUTPUT_DIR "$output/eval-$tag" \
    MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON "datasets/$tag.json" \
    DATASETS.TEST "(\"$dataset\",)" \
    MODEL.SEM_SEG_HEAD.POOLING_SIZES "[1,1]" \
    MODEL.WEIGHTS "$output/model_final.pth" \
    "$@"
}

# ADE20k-150
run_eval ade150 ade20k_150_test_sem_seg "$@"
# ADE20k-847
run_eval ade847 ade20k_full_sem_seg_freq_val_all "$@"
# Pascal VOC
run_eval voc20 voc_2012_test_sem_seg "$@"
# Pascal VOC-b
run_eval voc20b voc_2012_test_background_sem_seg "$@"
# Pascal Context 59
run_eval pc59 context_59_test_sem_seg "$@"
# Pascal Context 459
run_eval pc459 context_459_test_sem_seg "$@"

# Summarise: print the copy-paste result lines of every run, in run order.
for tag in ade150 ade847 voc20 voc20b pc59 pc459; do
  grep copypaste "$output/eval-$tag/log.txt"
done
\ No newline at end of file
# Unique model identifier
modelCode=683
# Model name
modelName=sed_pytorch
# Model description
modelDescription=用于开放词汇语义分割的简单编码器-解码器-SED模型的推理和训练
# Application scenarios
appScenario=训练,推理,科研,制造,医疗,家居,教育
# Framework type
frameType=Pytorch
logs/
wandb/
models/
features/
results/
tests/data/
*.pt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
sync.sh
gpu1sync.sh
.idea
*.pdf
**/._*
**/*DS_*
**.jsonl
src/sbatch
src/misc
.vscode
src/debug
core.*
# Allow
!src/evaluation/misc/results_dbs/*
\ No newline at end of file
cff-version: 1.1.0
message: If you use this software, please cite it as below.
authors:
- family-names: Ilharco
given-names: Gabriel
- family-names: Wortsman
given-names: Mitchell
- family-names: Wightman
given-names: Ross
- family-names: Gordon
given-names: Cade
- family-names: Carlini
given-names: Nicholas
- family-names: Taori
given-names: Rohan
- family-names: Dave
given-names: Achal
- family-names: Shankar
given-names: Vaishaal
- family-names: Namkoong
given-names: Hongseok
- family-names: Miller
given-names: John
- family-names: Hajishirzi
given-names: Hannaneh
- family-names: Farhadi
given-names: Ali
- family-names: Schmidt
given-names: Ludwig
title: OpenCLIP
version: v0.1
doi: 10.5281/zenodo.5143773
date-released: 2021-07-28
## 2.10.1
* `hf-hub:org/model_id` support for loading models w/ config and weights in Hugging Face Hub
## 2.10.0
* Added a ViT-bigG-14 model.
* Added an up-to-date example slurm script for large training jobs.
* Added an option to sync logs and checkpoints to S3 during training.
* New options for LR schedulers, constant and constant with cooldown
* Fix wandb autoresuming when resume is not set
* ConvNeXt `base` & `base_w` pretrained models added
* `timm-` model prefix removed from configs
* `timm` augmentation + regularization (dropout / drop-path) supported
## 2.9.3
* Fix wandb collapsing multiple parallel runs into a single one
## 2.9.2
* Fix braceexpand memory explosion for complex webdataset urls
## 2.9.1
* Fix release
## 2.9.0
* Add training feature to auto-resume from the latest checkpoint on restart via `--resume latest`
* Allow webp in webdataset
* Fix logging for number of samples when using gradient accumulation
* Add model configs for convnext xxlarge
## 2.8.2
* wrapped patchdropout in a torch.nn.Module
## 2.8.1
* relax protobuf dependency
* override the default patch dropout value in 'vision_cfg'
## 2.8.0
* better support for HF models
* add support for gradient accumulation
* CI fixes
* add support for patch dropout
* add convnext configs
## 2.7.0
* add multilingual H/14 xlm roberta large
## 2.6.1
* fix setup.py _read_reqs
## 2.6.0
* Make openclip training usable from pypi.
* Add xlm roberta large vit h 14 config.
## 2.5.0
* pretrained B/32 xlm roberta base: first multilingual clip trained on laion5B
* pretrained B/32 roberta base: first clip trained using an HF text encoder
## 2.4.1
* Add missing hf_tokenizer_name in CLIPTextCfg.
## 2.4.0
* Fix #211, missing RN50x64 config. Fix type of dropout param for ResNet models
* Bring back LayerNorm impl that casts to input for non bf16/fp16
* zero_shot.py: set correct tokenizer based on args
* training/params.py: remove hf params and get them from model config
## 2.3.1
* Implement grad checkpointing for hf model.
* custom_text: True if hf_model_name is set
* Disable hf tokenizer parallelism
## 2.3.0
* Generalizable Text Transformer with HuggingFace Models (@iejMac)
## 2.2.0
* Support for custom text tower
* Add checksum verification for pretrained model weights
## 2.1.0
* lot including sota models, bfloat16 option, better loading, better metrics
## 1.2.0
* ViT-B/32 trained on Laion2B-en
* add missing openai RN50x64 model
## 1.1.1
* ViT-B/16+
* Add grad checkpointing support
* more robust data loader
Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
Ludwig Schmidt
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
include src/open_clip/bpe_simple_vocab_16e6.txt.gz
include src/open_clip/model_configs/*.json
# None of these targets produce a file of the same name, so mark them phony:
# make will then always run them, even if such a file or directory appears.
.PHONY: install install-training install-test test

install: ## [Local development] Upgrade pip, install requirements, install package.
	python -m pip install -U pip
	python -m pip install -e .

install-training: ## Install training requirements
	python -m pip install -r requirements-training.txt

install-test: ## [Local development] Install test requirements
	python -m pip install -r requirements-test.txt

test: ## [Local development] Run unit tests
	python -m pytest -x -s -v tests
This diff is collapsed.
[pytest]
markers =
regression_test
pytest-split==0.8.0
pytest==7.2.0
transformers
timm==0.6.11
torch>=1.9.0
torchvision
webdataset>=0.2.5
regex
ftfy
tqdm
pandas
braceexpand
huggingface_hub
transformers
timm
fsspec
torch>=1.9.0
torchvision
regex
ftfy
tqdm
huggingface_hub
sentencepiece
protobuf==3.20.*
timm
""" Setup
"""
from setuptools import setup, find_packages
from codecs import open
from os import path
# Absolute path of the directory that contains this setup script.
here = path.abspath(path.dirname(__file__))

# The README doubles as the package's long description.
readme_path = path.join(here, 'README.md')
with open(readme_path, encoding='utf-8') as readme:
    long_description = readme.read()
def _read_reqs(relpath):
fullpath = path.join(path.dirname(__file__), relpath)
with open(fullpath) as f:
return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))]
REQUIREMENTS = _read_reqs("requirements.txt")
TRAINING_REQUIREMENTS = _read_reqs("requirements-training.txt")

# Execute version.py in this module's namespace; it is expected to define
# __version__, which setup() uses below. The file is opened in a context
# manager so the handle is closed, and the path is anchored at this script's
# directory (like the README above) instead of the current working directory.
with open(path.join(here, 'src', 'open_clip', 'version.py')) as version_file:
    exec(version_file.read())

setup(
    name='open_clip_torch',
    version=__version__,
    description='OpenCLIP',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/mlfoundations/open_clip',
    author='',
    author_email='',
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],

    # Note that this is a string of words separated by whitespace, not a list.
    keywords='CLIP pretrained',
    # Packages live under src/ and are discovered from there.
    package_dir={'': 'src'},
    packages=find_packages(where='src'),
    include_package_data=True,
    install_requires=REQUIREMENTS,
    # `pip install open_clip_torch[training]` pulls in the training extras.
    extras_require={
        "training": TRAINING_REQUIREMENTS,
    },
    python_requires='>=3.7',
)
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform, AugmentationCfg
# Default image normalization statistics (three values each; presumably one
# per RGB channel -- TODO confirm). factory.create_model falls back to these
# for model.visual.image_mean / image_std when the pretrained config does not
# provide 'mean' / 'std'.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
from .tokenizer import HFTokenizer, tokenize
# Model names starting with this prefix are handled as Hugging Face Hub ids.
HF_HUB_PREFIX = 'hf-hub:'

# Locations (directories or single files) scanned for model config JSON files.
_MODEL_CONFIG_PATHS = [Path(__file__).parent.joinpath("model_configs/")]

# Registry mapping model name -> architecture config dict.
_MODEL_CONFIGS = dict()
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = ('.json',)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f'*{ext}'))
for cf in config_files:
with open(cf, 'r') as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
_rescan_model_configs() # initial populate of model config registry
def list_models():
    """Return the names of all registered model architectures, in registry order."""
    return [model_name for model_name in _MODEL_CONFIGS]
def add_model_config(path):
    """Register an extra config file or directory and rescan the registry.

    Args:
        path: a ``Path`` or anything ``Path()`` accepts.
    """
    config_location = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(config_location)
    _rescan_model_configs()
def get_model_config(model_name):
    """Return a deep copy of the registered config for ``model_name``.

    Returns ``None`` when the name is not in the registry. The copy lets
    callers modify the result without touching the registry entry.
    """
    if model_name not in _MODEL_CONFIGS:
        return None
    return deepcopy(_MODEL_CONFIGS[model_name])
def get_tokenizer(model_name):
    """Pick the tokenizer for ``model_name``.

    * ``hf-hub:<id>`` names get an ``HFTokenizer`` built from ``<id>``.
    * Otherwise the model config is looked up; if its ``text_cfg`` names an
      ``hf_tokenizer_name`` an ``HFTokenizer`` for it is returned, else the
      module-level ``tokenize`` function is returned.
    """
    if model_name.startswith(HF_HUB_PREFIX):
        return HFTokenizer(model_name[len(HF_HUB_PREFIX):])

    text_cfg = get_model_config(model_name)['text_cfg']
    if 'hf_tokenizer_name' in text_cfg:
        return HFTokenizer(text_cfg['hf_tokenizer_name'])
    return tokenize
def load_state_dict(checkpoint_path: str, map_location='cpu'):
    """Load a checkpoint file and return its state dict.

    Args:
        checkpoint_path: file passed to ``torch.load``.
        map_location: forwarded to ``torch.load`` (defaults to ``'cpu'``).

    Returns:
        ``checkpoint['state_dict']`` when the loaded object is a dict with
        that key, otherwise the loaded object itself. If the first key starts
        with ``'module.'``, that prefix is removed from every key carrying it.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    prefix = 'module.'
    # Guard against an empty state dict (next() would raise StopIteration) and
    # match the full 'module.' prefix so unrelated keys that merely begin with
    # 'module' are not truncated.
    if state_dict and next(iter(state_dict)).startswith(prefix):
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    return state_dict
def load_checkpoint(model, checkpoint_path, strict=True):
    """Load weights from ``checkpoint_path`` into ``model``.

    Args:
        model: object exposing ``load_state_dict(state_dict, strict=...)``.
        checkpoint_path: checkpoint file read via ``load_state_dict``.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        Whatever ``model.load_state_dict`` returns.
    """
    weights = load_state_dict(checkpoint_path)

    # A checkpoint with a top-level 'positional_embedding' going into a model
    # without that attribute is in the old layout and must be converted.
    needs_conversion = 'positional_embedding' in weights and not hasattr(model, 'positional_embedding')
    if needs_conversion:
        weights = convert_to_custom_text_state_dict(weights)

    resize_pos_embed(weights, model)
    return model.load_state_dict(weights, strict=strict)
def create_model(
    model_name: str,
    pretrained: Optional[str] = None,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_patch_dropout: Optional[float] = None,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    pretrained_image: bool = False,
    pretrained_hf: bool = True,
    cache_dir: Optional[str] = None,
):
    """Build a CLIP / CustomTextCLIP model and optionally load pretrained weights.

    Args:
        model_name: name of a registered model config, or ``hf-hub:<model_id>``,
            in which case the checkpoint and ``open_clip_config.json`` are
            fetched with ``download_pretrained_from_hf``. For other names any
            '/' is replaced by '-'.
        pretrained: ``'openai'`` (case-insensitive) delegates construction to
            ``load_openai_model``; otherwise a tag resolved through
            ``get_pretrained_cfg`` or a path to an existing checkpoint file.
        precision: passed to ``get_cast_dtype``; ``'fp16'`` / ``'bf16'``
            additionally convert the weights via ``convert_weights_to_lp``.
        device: device string or ``torch.device`` the model is moved to.
        jit: if true, the built model is passed through ``torch.jit.script``
            (for ``'openai'`` the flag is handed to ``load_openai_model``).
        force_quick_gelu: sets ``model_cfg['quick_gelu'] = True``.
        force_custom_text: forces the ``CustomTextCLIP`` class.
        force_patch_dropout: when not None, overrides
            ``vision_cfg['patch_dropout']``.
        force_image_size: when not None, overrides ``vision_cfg['image_size']``.
        pretrained_image: sets ``vision_cfg['timm_model_pretrained'] = True``;
            only allowed when ``vision_cfg`` has a ``'timm_model_name'``.
        pretrained_hf: stored as ``text_cfg['hf_model_pretrained']`` when
            ``text_cfg`` has an ``'hf_model_name'``.
        cache_dir: forwarded to the download helpers.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: no config exists for ``model_name``, or ``pretrained``
            is neither a known tag nor an existing file.
        AssertionError: ``pretrained_image`` requested for a non-timm tower.

    Note:
        The ``force_*`` / ``pretrained_*`` overrides and the device, precision
        and jit post-processing are applied on the non-``'openai'`` path only.
        That scoping is reconstructed from the statements' meaning, because
        the file's indentation was lost -- TODO confirm against the upstream
        file.
    """
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        pretrained_cfg = config['preprocess_cfg']
        model_cfg = config['model_cfg']
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        pretrained_cfg = {}
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            jit=jit,
            cache_dir=cache_dir,
        )
    else:
        # A config downloaded from the hub wins over the local registry.
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        if pretrained_image:
            if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        # 'custom_text' is popped because it is not a model constructor argument.
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        if custom_text:
            if is_hf_model:
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if pretrained:
            # Resolve `pretrained` as a known tag first, then as a local file.
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                # Message fix: the two sentences used to run together and the
                # parenthesis around the tag list was never closed.
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}. '
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logging.warning(error_str)
                raise RuntimeError(error_str)
        elif has_hf_hub_prefix:
            # Weights come from the checkpoint downloaded from the hub above.
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path)

        model.to(device=device)
        if precision in ("fp16", "bf16"):
            convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

        # set image / mean metadata from pretrained_cfg if available, or use default
        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

        if jit:
            model = torch.jit.script(model)

    return model
def create_model_and_transforms(
    model_name: str,
    pretrained: Optional[str] = None,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_patch_dropout: Optional[float] = None,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    pretrained_image: bool = False,
    pretrained_hf: bool = True,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
    cache_dir: Optional[str] = None,
):
    """Create a model together with its training and validation transforms.

    All model-related arguments are forwarded to ``create_model``.
    ``image_mean`` / ``image_std`` default to the values stored on
    ``model.visual`` when not given; ``aug_cfg`` is used for the training
    transform only.

    Returns:
        ``(model, preprocess_train, preprocess_val)``.
    """
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
    )

    # Explicit statistics win; otherwise use what the model carries (if any).
    mean = image_mean or getattr(model.visual, 'image_mean', None)
    std = image_std or getattr(model.visual, 'image_std', None)

    train_transform = image_transform(
        model.visual.image_size, is_train=True, mean=mean, std=std, aug_cfg=aug_cfg)
    val_transform = image_transform(
        model.visual.image_size, is_train=False, mean=mean, std=std)

    return model, train_transform, val_transform
def create_model_from_pretrained(
    model_name: str,
    pretrained: str,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    return_transform: bool = True,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    cache_dir: Optional[str] = None,
):
    """Create a model from a pretrained tag or checkpoint file.

    ``pretrained`` must be a tag accepted by ``is_pretrained_cfg`` for
    ``model_name`` or a path that exists; otherwise ``RuntimeError`` is raised
    before any model is built.

    Returns:
        The model alone when ``return_transform`` is false, otherwise
        ``(model, preprocess)`` where ``preprocess`` is the non-training
        image transform.
    """
    is_known_tag = is_pretrained_cfg(model_name, pretrained)
    if not (is_known_tag or os.path.exists(pretrained)):
        raise RuntimeError(
            f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
            f' Use open_clip.list_pretrained() to find one.')

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        cache_dir=cache_dir,
    )
    if not return_transform:
        return model

    # Explicit statistics win; otherwise use what the model carries (if any).
    mean = image_mean or getattr(model.visual, 'image_mean', None)
    std = image_std or getattr(model.visual, 'image_std', None)
    preprocess = image_transform(model.visual.image_size, is_train=False, mean=mean, std=std)
    return model, preprocess
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment