Commit cbc25585 by limm (parent 1baf0566): add mmpretrain/ part
# File: mmpretrain/__init__.py
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import mmengine
from mmengine.utils import digit_version
from .apis import * # noqa: F401, F403
from .version import __version__
mmcv_minimum_version = '2.0.0'
mmcv_maximum_version = '2.4.0'
mmcv_version = digit_version(mmcv.__version__)
mmengine_minimum_version = '0.8.3'
mmengine_maximum_version = '1.0.0'
mmengine_version = digit_version(mmengine.__version__)
assert (mmcv_version >= digit_version(mmcv_minimum_version)
and mmcv_version < digit_version(mmcv_maximum_version)), \
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'
assert (mmengine_version >= digit_version(mmengine_minimum_version)
and mmengine_version < digit_version(mmengine_maximum_version)), \
f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
f'Please install mmengine>={mmengine_minimum_version}, ' \
f'<{mmengine_maximum_version}.'
__all__ = ['__version__']
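
# Editor's note, a minimal sketch of why `digit_version` is used for the
# range checks above: it converts a version string into a comparable tuple,
# so numeric ordering is respected where plain string comparison would fail.
#
#   >>> from mmengine.utils import digit_version
#   >>> digit_version('2.10.0') > digit_version('2.9.0')
#   True
#   >>> '2.10.0' > '2.9.0'  # lexicographic comparison gets this wrong
#   False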
# File: mmpretrain/apis/__init__.py
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseInferencer
from .feature_extractor import FeatureExtractor
from .image_caption import ImageCaptionInferencer
from .image_classification import ImageClassificationInferencer
from .image_retrieval import ImageRetrievalInferencer
from .model import (ModelHub, get_model, inference_model, init_model,
list_models)
from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
TextToImageRetrievalInferencer)
from .nlvr import NLVRInferencer
from .visual_grounding import VisualGroundingInferencer
from .visual_question_answering import VisualQuestionAnsweringInferencer
__all__ = [
'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub',
'ImageClassificationInferencer', 'ImageRetrievalInferencer',
'FeatureExtractor', 'ImageCaptionInferencer',
'TextToImageRetrievalInferencer', 'VisualGroundingInferencer',
'VisualQuestionAnsweringInferencer', 'ImageToTextRetrievalInferencer',
'BaseInferencer', 'NLVRInferencer'
]
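
# Editor's sketch of the public entry points re-exported above (the model
# name pattern is illustrative; results depend on your installation):
#
#   >>> from mmpretrain.apis import list_models, get_model
#   >>> names = list_models('resnet*in1k')      # wildcard filtering
#   >>> model = get_model(names[0], pretrained=False)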
# File: mmpretrain/apis/base.py
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from math import ceil
from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import default_collate
from mmengine.fileio import get_file_backend
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from mmpretrain.structures import DataSample
from mmpretrain.utils import track
from .model import get_model, list_models
ModelType = Union[BaseModel, str, Config]
InputType = Union[str, np.ndarray, list]
class BaseInferencer:
"""Base inferencer for various tasks.
The BaseInferencer provides the standard workflow for inference as follows:
1. Preprocess the input data by :meth:`preprocess`.
2. Forward the data to the model by :meth:`forward`. ``BaseInferencer``
assumes the model inherits from :class:`mmengine.models.BaseModel` and
will call `model.test_step` in :meth:`forward` by default.
3. Visualize the results by :meth:`visualize`.
4. Postprocess and return the results by :meth:`postprocess`.
    When calling a subclass of BaseInferencer (one that does not override
    ``__call__``), the workflow is executed in this order.
All subclasses of BaseInferencer could define the following class
attributes for customization:
- ``preprocess_kwargs``: The keys of the kwargs that will be passed to
:meth:`preprocess`.
- ``forward_kwargs``: The keys of the kwargs that will be passed to
:meth:`forward`
- ``visualize_kwargs``: The keys of the kwargs that will be passed to
:meth:`visualize`
- ``postprocess_kwargs``: The keys of the kwargs that will be passed to
:meth:`postprocess`
    All attributes mentioned above should be a ``set`` of keys (strings),
    and no key may appear in more than one set. :meth:`__call__` dispatches
    all the arguments to the corresponding methods according to these
    ``xxx_kwargs`` sets.
Subclasses inherited from ``BaseInferencer`` should implement
:meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`:
- _init_pipeline: Return a callable object to preprocess the input data.
- visualize: Visualize the results returned by :meth:`forward`.
- postprocess: Postprocess the results returned by :meth:`forward` and
:meth:`visualize`.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``cls.list_models()`` and you can also query it in
:doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works when the ``model`` is a model name). Defaults to None.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        device_map (str | dict | None): A map that specifies where each
            submodule should go. It doesn't need to cover every
            parameter/buffer name; once a given module name is included,
            every submodule of it will be sent to the same device. You can
            use `device_map="auto"` to generate the device map
            automatically. Defaults to None.
        offload_folder (str | None): If the `device_map` contains any value
            `"disk"`, the folder where the weights will be offloaded.
            Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only used
            when the ``model`` is a model name).
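    Example:
        A minimal sketch of the kwargs dispatch (``ToyInferencer`` is
        illustrative, not part of the library). Any argument listed in
        ``visualize_kwargs`` is routed to :meth:`visualize`, and likewise
        for the other sets:

        .. code-block:: python

            class ToyInferencer(BaseInferencer):
                visualize_kwargs = {'show'}

                def _init_pipeline(self, cfg):
                    return lambda x: x

                def visualize(self, inputs, preds, show=False):
                    ...

                def postprocess(self, preds, visualization,
                                return_datasample=False):
                    return preds

            # inferencer = ToyInferencer(model)
            # inferencer(inputs, show=True)  # `show` reaches `visualize`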
"""
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = set()
postprocess_kwargs: set = set()
def __init__(self,
model: ModelType,
pretrained: Union[bool, str] = True,
device: Union[str, torch.device, None] = None,
device_map=None,
offload_folder=None,
**kwargs) -> None:
if isinstance(model, BaseModel):
if isinstance(pretrained, str):
load_checkpoint(model, pretrained, map_location='cpu')
if device_map is not None:
from .utils import dispatch_model
model = dispatch_model(
model,
device_map=device_map,
offload_folder=offload_folder)
elif device is not None:
model.to(device)
else:
model = get_model(
model,
pretrained,
device=device,
device_map=device_map,
offload_folder=offload_folder,
**kwargs)
model.eval()
self.config = model._config
self.model = model
self.pipeline = self._init_pipeline(self.config)
self.visualizer = None
def __call__(
self,
inputs,
return_datasamples: bool = False,
batch_size: int = 1,
**kwargs,
) -> dict:
"""Call the inferencer.
Args:
inputs (InputsType): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
:obj:`BaseDataElement`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
            **kwargs: Keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
"""
(
preprocess_kwargs,
forward_kwargs,
visualize_kwargs,
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)
ori_inputs = self._inputs_to_list(inputs)
inputs = self.preprocess(
ori_inputs, batch_size=batch_size, **preprocess_kwargs)
preds = []
for data in track(
inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)):
preds.extend(self.forward(data, **forward_kwargs))
visualization = self.visualize(ori_inputs, preds, **visualize_kwargs)
results = self.postprocess(preds, visualization, return_datasamples,
**postprocess_kwargs)
return results
def _inputs_to_list(self, inputs: InputType) -> list:
"""Preprocess the inputs to a list.
Cast the input data to a list of data.
- list or tuple: return inputs
- str:
- Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to a file, a URL, or another type of string,
              depending on the task.
- other: return a list with one item.
Args:
inputs (str | array | list): Inputs for the inferencer.
Returns:
            list: A list of inputs for :meth:`preprocess`.
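        Example:
            An editor's sketch (the paths are hypothetical):

            .. code-block:: python

                self._inputs_to_list('demo/')           # all files in dir
                self._inputs_to_list('demo/demo.JPEG')  # ['demo/demo.JPEG']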
"""
if isinstance(inputs, str):
backend = get_file_backend(inputs)
if hasattr(backend, 'isdir') and backend.isdir(inputs):
            # Backends like HttpsBackend do not implement `isdir`, so only
            # backends that implement `isdir` can accept the inputs as a
            # directory.
file_list = backend.list_dir_or_file(inputs, list_dir=False)
inputs = [
backend.join_path(inputs, file) for file in file_list
]
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
return list(inputs)
def preprocess(self, inputs: InputType, batch_size: int = 1, **kwargs):
"""Process the inputs into a model-feedable format.
Customize your preprocess by overriding this method. Preprocess should
return an iterable object, of which each item will be used as the
input of ``model.test_step``.
``BaseInferencer.preprocess`` will return an iterable chunked data,
which will be used in __call__ like this:
.. code-block:: python
def __call__(self, inputs, batch_size=1, **kwargs):
chunked_data = self.preprocess(inputs, batch_size, **kwargs)
for batch in chunked_data:
preds = self.forward(batch, **kwargs)
Args:
inputs (InputsType): Inputs given by user.
batch_size (int): batch size. Defaults to 1.
Yields:
Any: Data processed by the ``pipeline`` and ``default_collate``.
"""
chunked_data = self._get_chunk_data(
map(self.pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
@torch.no_grad()
def forward(self, inputs: Union[dict, tuple], **kwargs):
"""Feed the inputs to the model."""
return self.model.test_step(inputs)
def visualize(self,
inputs: list,
preds: List[DataSample],
show: bool = False,
**kwargs) -> List[np.ndarray]:
"""Visualize predictions.
        Customize your visualization by overriding this method. ``visualize``
        should return visualization results, which could be np.ndarray or any
        other type of object.
Args:
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
preds (Any): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
Returns:
List[np.ndarray]: Visualization results.
"""
if show:
raise NotImplementedError(
f'The `visualize` method of {self.__class__.__name__} '
'is not implemented.')
@abstractmethod
def postprocess(
self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasample=False,
**kwargs,
) -> dict:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.
This method should be responsible for the following tasks:
1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.
Customize your postprocess by overriding this method. Make sure
``postprocess`` will return a dict with visualization results and
inference results.
Args:
preds (List[Dict]): Predictions of the model.
visualization (np.ndarray): Visualized predictions.
return_datasample (bool): Whether to return results as datasamples.
Defaults to False.
Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``
- ``visualization (Any)``: Returned by :meth:`visualize`
- ``predictions`` (dict or DataSample): Returned by
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
"""
@abstractmethod
def _init_pipeline(self, cfg: Config) -> Callable:
"""Initialize the test pipeline.
Return a pipeline to handle various input data, such as ``str``,
``np.ndarray``. It is an abstract method in BaseInferencer, and should
be implemented in subclasses.
        The returned pipeline will be used to process a single data sample.
It will be used in :meth:`preprocess` like this:
.. code-block:: python
def preprocess(self, inputs, batch_size, **kwargs):
...
dataset = map(self.pipeline, dataset)
...
"""
def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
"""Get batch data from dataset.
Args:
inputs (Iterable): An iterable dataset.
chunk_size (int): Equivalent to batch size.
Yields:
list: batch data.
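        Example:
            An editor's sketch: with ``chunk_size=2``, five items are
            yielded as chunks of sizes 2, 2 and 1:

            .. code-block:: python

                list(self._get_chunk_data(iter(range(5)), 2))
                # -> [[0, 1], [2, 3], [4]]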
"""
inputs_iter = iter(inputs)
while True:
try:
chunk_data = []
for _ in range(chunk_size):
processed_data = next(inputs_iter)
chunk_data.append(processed_data)
yield chunk_data
except StopIteration:
if chunk_data:
yield chunk_data
break
def _dispatch_kwargs(self, **kwargs) -> Tuple[dict, dict, dict, dict]:
"""Dispatch kwargs to preprocess(), forward(), visualize() and
postprocess() according to the actual demands.
Returns:
Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
forward, visualize and postprocess respectively.
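        Example:
            An editor's sketch, assuming ``visualize_kwargs = {'show'}``:

            .. code-block:: python

                pre, fwd, vis, post = self._dispatch_kwargs(show=True)
                # vis == {'show': True}; the other three dicts are empty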
"""
# Ensure each argument only matches one function
method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \
self.visualize_kwargs | self.postprocess_kwargs
union_kwargs = method_kwargs | set(kwargs.keys())
if union_kwargs != method_kwargs:
unknown_kwargs = union_kwargs - method_kwargs
raise ValueError(
                f'unknown arguments {unknown_kwargs} for `preprocess`, '
'`forward`, `visualize` and `postprocess`')
preprocess_kwargs = {}
forward_kwargs = {}
visualize_kwargs = {}
postprocess_kwargs = {}
for key, value in kwargs.items():
if key in self.preprocess_kwargs:
preprocess_kwargs[key] = value
if key in self.forward_kwargs:
forward_kwargs[key] = value
if key in self.visualize_kwargs:
visualize_kwargs[key] = value
if key in self.postprocess_kwargs:
postprocess_kwargs[key] = value
return (
preprocess_kwargs,
forward_kwargs,
visualize_kwargs,
postprocess_kwargs,
)
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List models defined in metafile of corresponding packages.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern)
# File: mmpretrain/apis/feature_extractor.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union
import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate
from mmpretrain.registry import TRANSFORMS
from .base import BaseInferencer, InputType
from .model import list_models
class FeatureExtractor(BaseInferencer):
"""The inferencer for extract features.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``FeatureExtractor.list_models()`` and you can also query it in
:doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works when the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only used
            when the ``model`` is a model name).
Example:
>>> from mmpretrain import FeatureExtractor
>>> inferencer = FeatureExtractor('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = inferencer('demo/demo.JPEG', stage='backbone')[0]
>>> for feat in feats:
>>> print(feat.shape)
torch.Size([256, 56, 56])
torch.Size([512, 28, 28])
torch.Size([1024, 14, 14])
torch.Size([2048, 7, 7])
""" # noqa: E501
def __call__(self,
inputs: InputType,
batch_size: int = 1,
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (str | array | list): The image path or array, or a list of
images.
batch_size (int): Batch size. Defaults to 1.
**kwargs: Other keyword arguments accepted by the `extract_feat`
method of the model.
Returns:
tensor | Tuple[tensor]: The extracted features.
"""
ori_inputs = self._inputs_to_list(inputs)
inputs = self.preprocess(ori_inputs, batch_size=batch_size)
preds = []
for data in inputs:
preds.extend(self.forward(data, **kwargs))
return preds
@torch.no_grad()
def forward(self, inputs: Union[dict, tuple], **kwargs):
inputs = self.model.data_preprocessor(inputs, False)['inputs']
outputs = self.model.extract_feat(inputs, **kwargs)
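        # Editor's note: `outputs` is batched (e.g. a tuple of per-stage
        # feature tensors); `scatter` below recursively picks the i-th
        # sample while preserving the nested structure.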
def scatter(feats, index):
if isinstance(feats, torch.Tensor):
return feats[index]
else:
# Sequence of tensor
return type(feats)([scatter(item, index) for item in feats])
results = []
for i in range(inputs.shape[0]):
results.append(scatter(outputs, i))
return results
def _init_pipeline(self, cfg: Config) -> Callable:
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
from mmpretrain.datasets import remove_transform
# Image loading is finished in `self.preprocess`.
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
'LoadImageFromFile')
test_pipeline = Compose(
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
return test_pipeline
def preprocess(self, inputs: List[InputType], batch_size: int = 1):
def load_image(input_):
img = imread(input_)
if img is None:
raise ValueError(f'Failed to read image {input_}.')
return dict(
img=img,
img_shape=img.shape[:2],
ori_shape=img.shape[:2],
)
pipeline = Compose([load_image, self.pipeline])
chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
def visualize(self):
raise NotImplementedError(
"The FeatureExtractor doesn't support visualization.")
def postprocess(self):
raise NotImplementedError(
"The FeatureExtractor doesn't need postprocessing.")
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List all available model names.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern)
# File: mmpretrain/apis/image_caption.py
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Callable, List, Optional
import numpy as np
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer, InputType
from .model import list_models
class ImageCaptionInferencer(BaseInferencer):
"""The inferencer for image caption.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``ImageCaptionInferencer.list_models()`` and you can also
query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works when the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only used
            when the ``model`` is a model name).
Example:
>>> from mmpretrain import ImageCaptionInferencer
>>> inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption')
>>> inferencer('demo/cat-dog.png')[0]
{'pred_caption': 'a puppy and a cat sitting on a blanket'}
""" # noqa: E501
visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'}
def __call__(self,
images: InputType,
return_datasamples: bool = False,
batch_size: int = 1,
**kwargs) -> dict:
"""Call the inferencer.
Args:
images (str | array | list): The image path or array, or a list of
images.
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
resize (int, optional): Resize the short edge of the image to the
specified length before visualization. Defaults to None.
draw_score (bool): Whether to draw the prediction scores
of prediction categories. Defaults to True.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
Returns:
list: The inference results.
"""
return super().__call__(images, return_datasamples, batch_size,
**kwargs)
def _init_pipeline(self, cfg: Config) -> Callable:
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
from mmpretrain.datasets import remove_transform
# Image loading is finished in `self.preprocess`.
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
'LoadImageFromFile')
test_pipeline = Compose(
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
return test_pipeline
def preprocess(self, inputs: List[InputType], batch_size: int = 1):
def load_image(input_):
img = imread(input_)
if img is None:
raise ValueError(f'Failed to read image {input_}.')
return dict(
img=img,
img_shape=img.shape[:2],
ori_shape=img.shape[:2],
)
pipeline = Compose([load_image, self.pipeline])
chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
def visualize(self,
ori_inputs: List[InputType],
preds: List[DataSample],
show: bool = False,
wait_time: int = 0,
resize: Optional[int] = None,
show_dir=None):
if not show and show_dir is None:
return None
if self.visualizer is None:
from mmpretrain.visualization import UniversalVisualizer
self.visualizer = UniversalVisualizer()
visualization = []
for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
image = imread(input_)
if isinstance(input_, str):
# The image loaded from path is BGR format.
image = image[..., ::-1]
name = Path(input_).stem
else:
name = str(i)
if show_dir is not None:
show_dir = Path(show_dir)
show_dir.mkdir(exist_ok=True)
out_file = str((show_dir / name).with_suffix('.png'))
else:
out_file = None
self.visualizer.visualize_image_caption(
image,
data_sample,
resize=resize,
show=show,
wait_time=wait_time,
name=name,
out_file=out_file)
visualization.append(self.visualizer.get_image())
if show:
self.visualizer.close()
return visualization
def postprocess(self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasamples=False) -> dict:
if return_datasamples:
return preds
results = []
for data_sample in preds:
results.append({'pred_caption': data_sample.get('pred_caption')})
return results
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List all available model names.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern, task='Image Caption')
# File: mmpretrain/apis/image_classification.py
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Callable, List, Optional, Union
import numpy as np
import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer, InputType, ModelType
from .model import list_models
class ImageClassificationInferencer(BaseInferencer):
"""The inferencer for image classification.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``ImageClassificationInferencer.list_models()`` and you can also
query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works when the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        classes (list, optional): Custom class names for the prediction
            results. If None, use the ``classes`` field of the model's
            dataset meta. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only used
            when the ``model`` is a model name).
Example:
1. Use a pre-trained model in MMPreTrain to inference an image.
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
>>> inferencer('demo/demo.JPEG')
        [{'pred_scores': array([...]),
          'pred_label': 65,
          'pred_score': 0.6649367809295654,
          'pred_class': 'sea snake'}]
2. Use a config file and checkpoint to inference multiple images on GPU,
and save the visualization results in a folder.
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer(
model='configs/resnet/resnet50_8xb32_in1k.py',
pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
device='cuda')
>>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/")
""" # noqa: E501
visualize_kwargs: set = {
'resize', 'rescale_factor', 'draw_score', 'show', 'show_dir',
'wait_time'
}
def __init__(self,
model: ModelType,
pretrained: Union[bool, str] = True,
device: Union[str, torch.device, None] = None,
classes=None,
**kwargs) -> None:
super().__init__(
model=model, pretrained=pretrained, device=device, **kwargs)
if classes is not None:
self.classes = classes
else:
self.classes = getattr(self.model, '_dataset_meta',
{}).get('classes')
def __call__(self,
inputs: InputType,
return_datasamples: bool = False,
batch_size: int = 1,
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (str | array | list): The image path or array, or a list of
images.
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
resize (int, optional): Resize the short edge of the image to the
specified length before visualization. Defaults to None.
rescale_factor (float, optional): Rescale the image by the rescale
factor for visualization. This is helpful when the image is too
large or too small for visualization. Defaults to None.
draw_score (bool): Whether to draw the prediction scores
of prediction categories. Defaults to True.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
Returns:
list: The inference results.
"""
return super().__call__(
inputs,
return_datasamples=return_datasamples,
batch_size=batch_size,
**kwargs)
def _init_pipeline(self, cfg: Config) -> Callable:
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
from mmpretrain.datasets import remove_transform
# Image loading is finished in `self.preprocess`.
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
'LoadImageFromFile')
test_pipeline = Compose(
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
return test_pipeline
def preprocess(self, inputs: List[InputType], batch_size: int = 1):
def load_image(input_):
img = imread(input_)
if img is None:
raise ValueError(f'Failed to read image {input_}.')
return dict(
img=img,
img_shape=img.shape[:2],
ori_shape=img.shape[:2],
)
pipeline = Compose([load_image, self.pipeline])
chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
def visualize(self,
ori_inputs: List[InputType],
preds: List[DataSample],
show: bool = False,
wait_time: int = 0,
resize: Optional[int] = None,
rescale_factor: Optional[float] = None,
draw_score=True,
show_dir=None):
if not show and show_dir is None:
return None
if self.visualizer is None:
from mmpretrain.visualization import UniversalVisualizer
self.visualizer = UniversalVisualizer()
visualization = []
for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
image = imread(input_)
if isinstance(input_, str):
# The image loaded from path is BGR format.
image = image[..., ::-1]
name = Path(input_).stem
else:
name = str(i)
if show_dir is not None:
show_dir = Path(show_dir)
show_dir.mkdir(exist_ok=True)
out_file = str((show_dir / name).with_suffix('.png'))
else:
out_file = None
self.visualizer.visualize_cls(
image,
data_sample,
classes=self.classes,
resize=resize,
show=show,
wait_time=wait_time,
rescale_factor=rescale_factor,
draw_gt=False,
draw_pred=True,
draw_score=draw_score,
name=name,
out_file=out_file)
visualization.append(self.visualizer.get_image())
if show:
self.visualizer.close()
return visualization
def postprocess(self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasamples=False) -> dict:
if return_datasamples:
return preds
results = []
for data_sample in preds:
pred_scores = data_sample.pred_score
pred_score = float(torch.max(pred_scores).item())
pred_label = torch.argmax(pred_scores).item()
result = {
'pred_scores': pred_scores.detach().cpu().numpy(),
'pred_label': pred_label,
'pred_score': pred_score,
}
if self.classes is not None:
result['pred_class'] = self.classes[pred_label]
results.append(result)
return results
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List all available model names.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern, task='Image Classification')
# File: mmpretrain/apis/image_retrieval.py
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Callable, List, Optional, Union
import numpy as np
import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import BaseDataset, Compose, default_collate
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer, InputType, ModelType
from .model import list_models
class ImageRetrievalInferencer(BaseInferencer):
"""The inferencer for image to image retrieval.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``ImageRetrievalInferencer.list_models()`` and you can also
query it in :doc:`/modelzoo_statistics`.
        prototype (str | list | dict | DataLoader | BaseDataset): The images
            to be retrieved. It can be the following types:
            - str: The directory of the images.
            - list: A list of paths of the images.
            - dict: A config dict of a prototype dataset.
            - BaseDataset: A prototype dataset.
            - DataLoader: A data loader to load the prototype data.
        prototype_cache (str, optional): The path of the generated prototype
            features. If it exists, the cache is loaded directly instead of
            regenerating the prototype features. If it does not exist, the
            generated features are saved to this path. Defaults to None.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works when the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only used
            when the ``model`` is a model name).
Example:
>>> from mmpretrain import ImageRetrievalInferencer
>>> inferencer = ImageRetrievalInferencer(
... 'resnet50-arcface_inshop',
... prototype='./demo/',
... prototype_cache='img_retri.pth')
>>> inferencer('demo/cat-dog.png', topk=2)[0][1]
{'match_score': tensor(0.4088, device='cuda:0'),
'sample_idx': 3,
'sample': {'img_path': './demo/dog.jpg'}}
""" # noqa: E501
visualize_kwargs: set = {
'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk'
}
postprocess_kwargs: set = {'topk'}
def __init__(
self,
model: ModelType,
prototype,
prototype_cache=None,
prepare_batch_size=8,
pretrained: Union[bool, str] = True,
device: Union[str, torch.device, None] = None,
**kwargs,
) -> None:
super().__init__(
model=model, pretrained=pretrained, device=device, **kwargs)
self.prototype_dataset = self._prepare_prototype(
prototype, prototype_cache, prepare_batch_size)
def _prepare_prototype(self, prototype, cache=None, batch_size=8):
from mmengine.dataset import DefaultSampler
from torch.utils.data import DataLoader
def build_dataloader(dataset):
return DataLoader(
dataset,
batch_size=batch_size,
collate_fn=default_collate,
sampler=DefaultSampler(dataset, shuffle=False),
persistent_workers=False,
)
if isinstance(prototype, str):
# A directory path of images
prototype = dict(
type='CustomDataset', with_label=False, data_root=prototype)
if isinstance(prototype, list):
test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
dataset = BaseDataset(
lazy_init=True, serialize_data=False, pipeline=test_pipeline)
dataset.data_list = [{
'sample_idx': i,
'img_path': file
} for i, file in enumerate(prototype)]
dataset._fully_initialized = True
dataloader = build_dataloader(dataset)
elif isinstance(prototype, dict):
# A config of dataset
from mmpretrain.registry import DATASETS
test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
prototype.setdefault('pipeline', test_pipeline)
dataset = DATASETS.build(prototype)
dataloader = build_dataloader(dataset)
elif isinstance(prototype, DataLoader):
dataset = prototype.dataset
dataloader = prototype
elif isinstance(prototype, BaseDataset):
dataset = prototype
dataloader = build_dataloader(dataset)
else:
raise TypeError(f'Unsupported prototype type {type(prototype)}.')
if cache is not None and Path(cache).exists():
self.model.prototype = cache
else:
self.model.prototype = dataloader
self.model.prepare_prototype()
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
if cache is None:
            logger.info('The prototype has been prepared, you can use '
                        '`save_prototype` to dump it into a pickle '
                        'file for future use.')
elif not Path(cache).exists():
self.save_prototype(cache)
logger.info(f'The prototype has been saved at {cache}.')
return dataset
def save_prototype(self, path):
self.model.dump_prototype(path)
def __call__(self,
inputs: InputType,
return_datasamples: bool = False,
batch_size: int = 1,
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (str | array | list): The image path or array, or a list of
images.
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
resize (int, optional): Resize the long edge of the image to the
specified length before visualization. Defaults to None.
draw_score (bool): Whether to draw the match scores.
Defaults to True.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
Returns:
list: The inference results.
"""
return super().__call__(inputs, return_datasamples, batch_size,
**kwargs)
def _init_pipeline(self, cfg: Config) -> Callable:
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
from mmpretrain.datasets import remove_transform
# Image loading is finished in `self.preprocess`.
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
'LoadImageFromFile')
test_pipeline = Compose(
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
return test_pipeline
def preprocess(self, inputs: List[InputType], batch_size: int = 1):
def load_image(input_):
img = imread(input_)
if img is None:
raise ValueError(f'Failed to read image {input_}.')
return dict(
img=img,
img_shape=img.shape[:2],
ori_shape=img.shape[:2],
)
pipeline = Compose([load_image, self.pipeline])
chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
def visualize(self,
ori_inputs: List[InputType],
preds: List[DataSample],
topk: int = 3,
resize: Optional[int] = 224,
show: bool = False,
wait_time: int = 0,
draw_score=True,
show_dir=None):
if not show and show_dir is None:
return None
if self.visualizer is None:
from mmpretrain.visualization import UniversalVisualizer
self.visualizer = UniversalVisualizer()
visualization = []
for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
image = imread(input_)
if isinstance(input_, str):
# The image loaded from path is BGR format.
image = image[..., ::-1]
name = Path(input_).stem
else:
name = str(i)
if show_dir is not None:
show_dir = Path(show_dir)
show_dir.mkdir(exist_ok=True)
out_file = str((show_dir / name).with_suffix('.png'))
else:
out_file = None
self.visualizer.visualize_image_retrieval(
image,
data_sample,
self.prototype_dataset,
topk=topk,
resize=resize,
draw_score=draw_score,
show=show,
wait_time=wait_time,
name=name,
out_file=out_file)
visualization.append(self.visualizer.get_image())
if show:
self.visualizer.close()
return visualization
def postprocess(
self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasamples=False,
topk=1,
) -> dict:
if return_datasamples:
return preds
results = []
for data_sample in preds:
match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
matches = []
for match_score, sample_idx in zip(match_scores, indices):
sample = self.prototype_dataset.get_data_info(
sample_idx.item())
sample_idx = sample.pop('sample_idx')
matches.append({
'match_score': match_score,
'sample_idx': sample_idx,
'sample': sample
})
results.append(matches)
return results
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List all available model names.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern, task='Image Retrieval')
# File: mmpretrain/apis/model.py
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import fnmatch
import os.path as osp
import re
import warnings
from os import PathLike
from pathlib import Path
from typing import List, Tuple, Union
from mmengine.config import Config
from modelindex.load_model_index import load
from modelindex.models.Model import Model
class ModelHub:
"""A hub to host the meta information of all pre-defined models."""
_models_dict = {}
__mmpretrain_registered = False
@classmethod
def register_model_index(cls,
model_index_path: Union[str, PathLike],
config_prefix: Union[str, PathLike, None] = None):
"""Parse the model-index file and register all models.
Args:
model_index_path (str | PathLike): The path of the model-index
file.
config_prefix (str | PathLike | None): The prefix of all config
file paths in the model-index file.
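        Example:
            An editor's sketch (the paths are hypothetical):

            .. code-block:: python

                ModelHub.register_model_index(
                    'my_project/model-index.yml', config_prefix='my_project/')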
"""
model_index = load(str(model_index_path))
model_index.build_models_with_collections()
for metainfo in model_index.models:
model_name = metainfo.name.lower()
            if model_name in cls._models_dict:
                raise ValueError(
                    'The model name {} conflicts between {} and {}.'.format(
                        model_name, osp.abspath(metainfo.filepath),
                        osp.abspath(cls._models_dict[model_name].filepath)))
metainfo.config = cls._expand_config_path(metainfo, config_prefix)
cls._models_dict[model_name] = metainfo
@classmethod
def get(cls, model_name):
"""Get the model's metainfo by the model name.
Args:
model_name (str): The name of model.
Returns:
modelindex.models.Model: The metainfo of the specified model.
"""
cls._register_mmpretrain_models()
# lazy load config
metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
if metainfo is None:
raise ValueError(
                f'Failed to find model "{model_name}". Please use '
'`mmpretrain.list_models` to get all available names.')
if isinstance(metainfo.config, str):
metainfo.config = Config.fromfile(metainfo.config)
return metainfo
@staticmethod
def _expand_config_path(metainfo: Model,
                            config_prefix: Union[str, PathLike, None] = None):
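        # Editor's note: resolve a relative config path against
        # `config_prefix` (by default, the directory of the metafile), and
        # keep absolute paths or None unchanged.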
if config_prefix is None:
config_prefix = osp.dirname(metainfo.filepath)
if metainfo.config is None or osp.isabs(metainfo.config):
config_path: str = metainfo.config
else:
config_path = osp.abspath(osp.join(config_prefix, metainfo.config))
return config_path
@classmethod
def _register_mmpretrain_models(cls):
# register models in mmpretrain
if not cls.__mmpretrain_registered:
from importlib_metadata import distribution
root = distribution('mmpretrain').locate_file('mmpretrain')
model_index_path = root / '.mim' / 'model-index.yml'
ModelHub.register_model_index(
model_index_path, config_prefix=root / '.mim')
cls.__mmpretrain_registered = True
@classmethod
def has(cls, model_name):
"""Whether a model name is in the ModelHub."""
return model_name in cls._models_dict
def get_model(model: Union[str, Config],
pretrained: Union[str, bool] = False,
device=None,
device_map=None,
offload_folder=None,
url_mapping: Tuple[str, str] = None,
**kwargs):
"""Get a pre-defined model or create a model from config.
Args:
model (str | Config): The name of model, the config file path or a
config instance.
        pretrained (bool | str): When using a name to specify the model, you
            can use ``True`` to load the pre-defined pretrained weights. You
            can also use a string to specify the path or link of the weights
            to load. Defaults to False.
device (str | torch.device | None): Transfer the model to the target
device. Defaults to None.
        device_map (str | dict | None): A map that specifies where each
            submodule should go. It doesn't need to cover every
            parameter/buffer name; once a given module name is included,
            every submodule of it will be sent to the same device. You can
            use `device_map="auto"` to generate the device map
            automatically. Defaults to None.
        offload_folder (str | None): If the `device_map` contains any value
            `"disk"`, the folder where the weights will be offloaded.
            Defaults to None.
        url_mapping (Tuple[str, str], optional): The mapping of the
            pretrained checkpoint link. For example, load the checkpoint
            from a local directory instead of downloading it with
            ``('https://.*/', './checkpoint')``. Defaults to None.
**kwargs: Other keyword arguments of the model config.
Returns:
mmengine.model.BaseModel: The result model.
Examples:
Get a ResNet-50 model and extract images feature:
>>> import torch
>>> from mmpretrain import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
... print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
Get Swin-Transformer model with pre-trained weights and inference:
>>> from mmpretrain import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
'sea snake'
""" # noqa: E501
if device_map is not None:
from .utils import dispatch_model
dispatch_model._verify_require()
metainfo = None
if isinstance(model, Config):
config = copy.deepcopy(model)
if pretrained is True and 'load_from' in config:
pretrained = config.load_from
elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py':
config = Config.fromfile(model)
if pretrained is True and 'load_from' in config:
pretrained = config.load_from
elif isinstance(model, str):
metainfo = ModelHub.get(model)
config = metainfo.config
if pretrained is True and metainfo.weights is not None:
pretrained = metainfo.weights
    else:
        raise TypeError('model must be a name, a path or a Config object, '
                        f'but got {type(model)}')
if pretrained is True:
warnings.warn('Unable to find pre-defined checkpoint of the model.')
pretrained = None
elif pretrained is False:
pretrained = None
if kwargs:
config.merge_from_dict({'model': kwargs})
config.model.setdefault('data_preprocessor',
config.get('data_preprocessor', None))
from mmengine.registry import DefaultScope
from mmpretrain.registry import MODELS
with DefaultScope.overwrite_default_scope('mmpretrain'):
model = MODELS.build(config.model)
dataset_meta = {}
if pretrained:
# Mapping the weights to GPU may cause unexpected video memory leak
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405
from mmengine.runner import load_checkpoint
if url_mapping is not None:
pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained)
checkpoint = load_checkpoint(model, pretrained, map_location='cpu')
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmpretrain 1.x
dataset_meta = checkpoint['meta']['dataset_meta']
elif 'CLASSES' in checkpoint.get('meta', {}):
# mmcls 0.x
dataset_meta = {'classes': checkpoint['meta']['CLASSES']}
if len(dataset_meta) == 0 and 'test_dataloader' in config:
from mmpretrain.registry import DATASETS
dataset_class = DATASETS.get(config.test_dataloader.dataset.type)
dataset_meta = getattr(dataset_class, 'METAINFO', {})
if device_map is not None:
model = dispatch_model(
model, device_map=device_map, offload_folder=offload_folder)
elif device is not None:
model.to(device)
model._dataset_meta = dataset_meta # save the dataset meta
model._config = config # save the config in the model
model._metainfo = metainfo # save the metainfo in the model
model.eval()
return model
def init_model(config, checkpoint=None, device=None, **kwargs):
"""Initialize a classifier from config file (deprecated).
It's only for compatibility, please use :func:`get_model` instead.
Args:
config (str | :obj:`mmengine.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
device (str | torch.device | None): Transfer the model to the target
device. Defaults to None.
**kwargs: Other keyword arguments of the model config.
Returns:
nn.Module: The constructed model.
"""
return get_model(config, checkpoint, device, **kwargs)
def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]:
"""List all models available in MMPretrain.
Args:
pattern (str | None): A wildcard pattern to match model names.
Defaults to None.
exclude_patterns (list | None): A list of wildcard patterns to
exclude names from the matched names. Defaults to None.
        task (str | None): The evaluation task of the model. Defaults to None.
Returns:
List[str]: a list of model names.
Examples:
List all models:
>>> from mmpretrain import list_models
>>> list_models()
List ResNet-50 models on ImageNet-1k dataset:
>>> from mmpretrain import list_models
>>> list_models('resnet*in1k')
['resnet50_8xb32_in1k',
'resnet50_8xb32-fp16_in1k',
'resnet50_8xb256-rsb-a1-600e_in1k',
'resnet50_8xb256-rsb-a2-300e_in1k',
'resnet50_8xb256-rsb-a3-100e_in1k']
        List Swin-Transformer models trained from scratch and exclude
        Swin-Transformer-V2 models:
>>> from mmpretrain import list_models
>>> list_models('swin', exclude_patterns=['swinv2', '*-pre'])
['swin-base_16xb64_in1k',
'swin-base_3rdparty_in1k',
'swin-base_3rdparty_in1k-384',
'swin-large_8xb8_cub-384px',
'swin-small_16xb64_in1k',
'swin-small_3rdparty_in1k',
'swin-tiny_16xb64_in1k',
'swin-tiny_3rdparty_in1k']
List all EVA models for image classification task.
>>> from mmpretrain import list_models
>>> list_models('eva', task='Image Classification')
['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px',
'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px',
'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px',
'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px',
'eva-l-p14_mim-pre_3rdparty_in1k-196px',
'eva-l-p14_mim-pre_3rdparty_in1k-336px']
"""
ModelHub._register_mmpretrain_models()
matches = set(ModelHub._models_dict.keys())
if pattern is not None:
# Always match keys with any postfix.
matches = set(fnmatch.filter(matches, pattern + '*'))
exclude_patterns = exclude_patterns or []
for exclude_pattern in exclude_patterns:
exclude = set(fnmatch.filter(matches, exclude_pattern + '*'))
matches = matches - exclude
if task is not None:
task_matches = []
for key in matches:
metainfo = ModelHub._models_dict[key]
if metainfo.results is None and task == 'null':
task_matches.append(key)
elif metainfo.results is None:
continue
elif task in [result.task for result in metainfo.results]:
task_matches.append(key)
matches = task_matches
return sorted(list(matches))
def inference_model(model, *args, **kwargs):
"""Inference an image with the inferencer.
Automatically select inferencer to inference according to the type of
model. It's a shortcut for a quick start, and for advanced usage, please
use the correspondding inferencer class.
Here is the mapping from task to inferencer:
- Image Classification: :class:`ImageClassificationInferencer`
- Image Retrieval: :class:`ImageRetrievalInferencer`
- Image Caption: :class:`ImageCaptionInferencer`
- Visual Question Answering: :class:`VisualQuestionAnsweringInferencer`
- Visual Grounding: :class:`VisualGroundingInferencer`
- Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer`
- Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer`
- NLVR: :class:`NLVRInferencer`
Args:
model (BaseModel | str | Config): The loaded model, the model
name or the config of the model.
*args: Positional arguments to call the inferencer.
        **kwargs: Other keyword arguments to initialize and call the
            corresponding inferencer.
Returns:
result (dict): The inference results.
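    Example:
        An editor's sketch (assumes the demo image from the repository):

        .. code-block:: python

            from mmpretrain import inference_model
            result = inference_model('resnet50_8xb32_in1k', 'demo/demo.JPEG')
            print(result['pred_class'])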
""" # noqa: E501
from mmengine.model import BaseModel
if isinstance(model, BaseModel):
metainfo = getattr(model, '_metainfo', None)
else:
metainfo = ModelHub.get(model)
from inspect import signature
from .image_caption import ImageCaptionInferencer
from .image_classification import ImageClassificationInferencer
from .image_retrieval import ImageRetrievalInferencer
from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
TextToImageRetrievalInferencer)
from .nlvr import NLVRInferencer
from .visual_grounding import VisualGroundingInferencer
from .visual_question_answering import VisualQuestionAnsweringInferencer
task_mapping = {
'Image Classification': ImageClassificationInferencer,
'Image Retrieval': ImageRetrievalInferencer,
'Image Caption': ImageCaptionInferencer,
'Visual Question Answering': VisualQuestionAnsweringInferencer,
'Visual Grounding': VisualGroundingInferencer,
'Text-To-Image Retrieval': TextToImageRetrievalInferencer,
'Image-To-Text Retrieval': ImageToTextRetrievalInferencer,
'NLVR': NLVRInferencer,
}
inferencer_type = None
if metainfo is not None and metainfo.results is not None:
tasks = set(result.task for result in metainfo.results)
inferencer_type = [
task_mapping.get(task) for task in tasks if task in task_mapping
]
if len(inferencer_type) > 1:
inferencer_names = [cls.__name__ for cls in inferencer_type]
warnings.warn('The model supports multiple tasks, auto select '
f'{inferencer_names[0]}, you can also use other '
f'inferencer {inferencer_names} directly.')
inferencer_type = inferencer_type[0]
if inferencer_type is None:
raise NotImplementedError('No available inferencer for the model')
init_kwargs = {
k: kwargs.pop(k)
for k in list(kwargs)
if k in signature(inferencer_type).parameters.keys()
}
inferencer = inferencer_type(model, **init_kwargs)
return inferencer(*args, **kwargs)[0]
# File: mmpretrain/apis/multimodal_retrieval.py
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
import mmengine
import numpy as np
import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import BaseDataset, Compose, default_collate
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from mmpretrain.utils import track
from .base import BaseInferencer
from .base import InputType as ImageType
from .base import ModelType
from .model import list_models
def filter_transforms(transforms: list, data_info: dict):
"""Filter pipeline to avoid KeyError with partial data info."""
data_info = deepcopy(data_info)
filtered_transforms = []
for t in transforms:
try:
data_info = t(data_info)
filtered_transforms.append(t)
except KeyError:
pass
return filtered_transforms
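
# Editor's sketch of how `filter_transforms` is used below: probe a shared
# multimodal test pipeline with partial data infos to split it into
# image-only and text-only branches (`test_transforms` is hypothetical here).
#
#   import numpy as np
#   img_tfs = filter_transforms(test_transforms,
#                               {'img': np.zeros((224, 224, 3), np.uint8)})
#   text_tfs = filter_transforms(test_transforms, {'text': 'example'})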
class TextToImageRetrievalInferencer(BaseInferencer):
"""The inferencer for text to image retrieval.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``TextToImageRetrievalInferencer.list_models()`` and you can also
query it in :doc:`/modelzoo_statistics`.
        prototype (str | list | dict | DataLoader | BaseDataset): The images
            to be retrieved. It can be the following types:
            - str: The directory of the images.
            - list: A list of paths of the images.
            - dict: A config dict of a prototype dataset.
            - BaseDataset: A prototype dataset.
            - DataLoader: A data loader to load the prototype data.
        prototype_cache (str, optional): The path of the generated prototype
            features. If it exists, the cache is loaded directly instead of
            regenerating the prototype features. If it does not exist, the
            generated features are saved to this path. Defaults to None.
        fast_match (bool): Some algorithms will record extra image features
            for further matching, which may consume a lot of memory. Set
            True to avoid this behavior. Defaults to True.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works when the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only used
            when the ``model`` is a model name).
Example:
>>> from mmpretrain import TextToImageRetrievalInferencer
>>> inferencer = TextToImageRetrievalInferencer(
... 'blip-base_3rdparty_retrieval',
... prototype='./demo/',
... prototype_cache='t2i_retri.pth')
>>> inferencer('A cat and a dog.')[0]
{'match_score': tensor(0.3855, device='cuda:0'),
'sample_idx': 1,
'sample': {'img_path': './demo/cat-dog.png'}}
""" # noqa: E501
visualize_kwargs: set = {
'draw_score', 'show_dir', 'show', 'wait_time', 'figsize', 'topk'
}
postprocess_kwargs: set = {'topk'}
def __init__(self,
model: ModelType,
prototype,
prototype_cache=None,
fast_match=True,
prepare_batch_size=8,
pretrained: Union[bool, str] = True,
device: Union[str, torch.device, None] = None,
**kwargs) -> None:
super().__init__(
model=model, pretrained=pretrained, device=device, **kwargs)
self.img_pipeline, self.text_pipeline = self.pipeline
if hasattr(self.model, 'fast_match'):
self.model.fast_match = fast_match
self.prototype_dataset = self._prepare_prototype(
prototype, prototype_cache, batch_size=prepare_batch_size)
def _prepare_prototype(self, prototype, cache=None, batch_size=8):
from mmengine.dataset import DefaultSampler
from torch.utils.data import DataLoader
def build_dataloader(dataset):
return DataLoader(
dataset,
batch_size=batch_size,
collate_fn=default_collate,
sampler=DefaultSampler(dataset, shuffle=False),
persistent_workers=False,
)
if isinstance(prototype, str):
# A directory path of images
prototype = dict(
type='CustomDataset', with_label=False, data_root=prototype)
if isinstance(prototype, list):
test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline]
dataset = BaseDataset(
lazy_init=True, serialize_data=False, pipeline=test_pipeline)
dataset.data_list = [{
'sample_idx': i,
'img_path': file
} for i, file in enumerate(prototype)]
dataset._fully_initialized = True
dataloader = build_dataloader(dataset)
elif isinstance(prototype, dict):
# A config of dataset
from mmpretrain.registry import DATASETS
test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline]
prototype.setdefault('pipeline', test_pipeline)
dataset = DATASETS.build(prototype)
dataloader = build_dataloader(dataset)
elif isinstance(prototype, DataLoader):
dataset = prototype.dataset
dataloader = prototype
elif isinstance(prototype, BaseDataset):
dataset = prototype
dataloader = build_dataloader(dataset)
else:
raise TypeError(f'Unsupported prototype type {type(prototype)}.')
if cache is not None and Path(cache).exists():
self.prototype = torch.load(cache)
else:
prototype = []
for data_batch in track(dataloader, 'Prepare prototype...'):
with torch.no_grad():
data_batch = self.model.data_preprocessor(
data_batch, False)
feats = self.model._run_forward(data_batch, mode='tensor')
prototype.append(feats)
prototype = {
k: torch.cat([d[k] for d in prototype])
for k in prototype[0]
}
self.prototype = prototype
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
if cache is None:
            logger.info('The prototype has been prepared, you can use '
                        '`save_prototype` to dump it into a pickle '
                        'file for future use.')
elif not Path(cache).exists():
self.save_prototype(cache)
logger.info(f'The prototype has been saved at {cache}.')
return dataset
def save_prototype(self, path):
torch.save(self.prototype, path)
def __call__(self,
inputs: ImageType,
return_datasamples: bool = False,
batch_size: int = 1,
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (str | array | list): The image path or array, or a list of
images.
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
resize (int, optional): Resize the long edge of the image to the
specified length before visualization. Defaults to None.
draw_score (bool): Whether to draw the match scores.
Defaults to True.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
Returns:
list: The inference results.
"""
return super().__call__(inputs, return_datasamples, batch_size,
**kwargs)
@torch.no_grad()
def forward(self, data: dict, **kwargs):
"""Feed the inputs to the model."""
data = self.model.data_preprocessor(data, False)
data_samples = data['data_samples']
feats = self.prototype.copy()
feats.update(self.model.extract_feat(data_samples=data_samples))
return self.model.predict_all(feats, data_samples, cal_i2t=False)[0]
def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        test_transforms = [TRANSFORMS.build(t) for t in test_pipeline_cfg]
        img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)}
        text_info = {'text': 'example'}
        img_pipeline = Compose(filter_transforms(test_transforms, img_info))
        text_pipeline = Compose(filter_transforms(test_transforms, text_info))
        return img_pipeline, text_pipeline
def preprocess(self, inputs: List[str], batch_size: int = 1):
def process_text(input_: str):
return self.text_pipeline({'text': input_})
chunked_data = self._get_chunk_data(
map(process_text, inputs), batch_size)
yield from map(default_collate, chunked_data)
def visualize(self,
ori_inputs: List[str],
preds: List[DataSample],
topk: int = 3,
figsize: Tuple[int, int] = (16, 9),
show: bool = False,
wait_time: int = 0,
draw_score=True,
show_dir=None):
if not show and show_dir is None:
return None
if self.visualizer is None:
from mmpretrain.visualization import UniversalVisualizer
self.visualizer = UniversalVisualizer()
visualization = []
for i, (text, data_sample) in enumerate(zip(ori_inputs, preds)):
name = str(i)
if show_dir is not None:
show_dir = Path(show_dir)
show_dir.mkdir(exist_ok=True)
out_file = str((show_dir / name).with_suffix('.png'))
else:
out_file = None
self.visualizer.visualize_t2i_retrieval(
text,
data_sample,
self.prototype_dataset,
topk=topk,
fig_cfg=dict(figsize=figsize),
draw_score=draw_score,
show=show,
wait_time=wait_time,
name=name,
out_file=out_file)
visualization.append(self.visualizer.get_image())
if show:
self.visualizer.close()
return visualization
def postprocess(
self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasamples=False,
topk=1,
) -> dict:
if return_datasamples:
return preds
results = []
for data_sample in preds:
match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
matches = []
for match_score, sample_idx in zip(match_scores, indices):
sample = self.prototype_dataset.get_data_info(
sample_idx.item())
sample_idx = sample.pop('sample_idx')
matches.append({
'match_score': match_score,
'sample_idx': sample_idx,
'sample': sample
})
results.append(matches)
return results
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List all available model names.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern, task='Text-To-Image Retrieval')
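
# A hedged usage sketch for the text-to-image inferencer above. The model
# name and image directory are illustrative; query real names via
# ``TextToImageRetrievalInferencer.list_models()``.
#
#   >>> from mmpretrain import TextToImageRetrievalInferencer
#   >>> inferencer = TextToImageRetrievalInferencer(
#   ...     'blip-base_3rdparty_retrieval',    # assumed model name
#   ...     prototype='./demo/',               # a directory of candidate images
#   ...     prototype_cache='t2i_retri.pth')   # cache features for reuse
#   >>> inferencer('A bird standing on a branch.')[0]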
class ImageToTextRetrievalInferencer(BaseInferencer):
"""The inferencer for image to text retrieval.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``ImageToTextRetrievalInferencer.list_models()`` and you can
also query it in :doc:`/modelzoo_statistics`.
        prototype (str | list): The texts to be retrieved. It can be the
            following types:

            - str: The file path to load the string list.
            - list: A list of strings.
prototype_cache (str, optional): The path of the generated prototype
features. If exists, directly load the cache instead of re-generate
the prototype features. If not exists, save the generated features
to the path. Defaults to None.
        fast_match (bool): Some algorithms will record extra image features
            for further matching, which may consume large memory. Set True
            to avoid this behavior. Defaults to True.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only works
            if the ``model`` is a model name).
Example:
>>> from mmpretrain import ImageToTextRetrievalInferencer
>>> inferencer = ImageToTextRetrievalInferencer(
... 'blip-base_3rdparty_retrieval',
... prototype=['cat', 'dog', 'snake', 'bird'],
... prototype_cache='i2t_retri.pth')
        >>> inferencer('demo/bird.JPEG')[0]
        {'match_score': tensor(0.3855, device='cuda:0'),
         'sample_idx': 3,
         'text': 'bird'}
""" # noqa: E501
visualize_kwargs: set = {
'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk'
}
postprocess_kwargs: set = {'topk'}
def __init__(self,
model: ModelType,
prototype,
prototype_cache=None,
fast_match=True,
prepare_batch_size=8,
pretrained: Union[bool, str] = True,
device: Union[str, torch.device, None] = None,
**kwargs) -> None:
super().__init__(
model=model, pretrained=pretrained, device=device, **kwargs)
self.img_pipeline, self.text_pipeline = self.pipeline
if hasattr(self.model, 'fast_match'):
self.model.fast_match = fast_match
self.prototype_dataset = self._prepare_prototype(
prototype, cache=prototype_cache, batch_size=prepare_batch_size)
def _prepare_prototype(self, prototype, cache=None, batch_size=8):
from mmengine.dataset import DefaultSampler
from torch.utils.data import DataLoader
def build_dataloader(dataset):
return DataLoader(
[
self.text_pipeline({
'sample_idx': i,
'text': text
}) for i, text in enumerate(dataset)
],
batch_size=batch_size,
collate_fn=default_collate,
sampler=DefaultSampler(dataset, shuffle=False),
persistent_workers=False,
)
if isinstance(prototype, str):
            # A file path of a list of strings
dataset = mmengine.list_from_file(prototype)
elif mmengine.utils.is_seq_of(prototype, str):
dataset = prototype
else:
raise TypeError(f'Unsupported prototype type {type(prototype)}.')
dataloader = build_dataloader(dataset)
if cache is not None and Path(cache).exists():
self.prototype = torch.load(cache)
else:
prototype = []
for data_batch in track(dataloader, 'Prepare prototype...'):
with torch.no_grad():
data_batch = self.model.data_preprocessor(
data_batch, False)
feats = self.model._run_forward(data_batch, mode='tensor')
prototype.append(feats)
prototype = {
k: torch.cat([d[k] for d in prototype])
for k in prototype[0]
}
self.prototype = prototype
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
if cache is None:
                logger.info('The prototype has been prepared, and you can '
                            'use `save_prototype` to dump it into a pickle '
                            'file for future use.')
elif not Path(cache).exists():
self.save_prototype(cache)
logger.info(f'The prototype has been saved at {cache}.')
return dataset
def save_prototype(self, path):
torch.save(self.prototype, path)
def __call__(self,
inputs: ImageType,
return_datasamples: bool = False,
batch_size: int = 1,
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (str | array | list): The image path or array, or a list of
images.
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
resize (int, optional): Resize the long edge of the image to the
specified length before visualization. Defaults to None.
draw_score (bool): Whether to draw the match scores.
Defaults to True.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
Returns:
list: The inference results.
"""
return super().__call__(inputs, return_datasamples, batch_size,
**kwargs)
@torch.no_grad()
def forward(self, data: dict, **kwargs):
"""Feed the inputs to the model."""
data = self.model.data_preprocessor(data, False)
feats = self.prototype.copy()
feats.update(self.model.extract_feat(images=data['images']))
return self.model.predict_all(
feats, data['data_samples'], cal_t2i=False)[0]
def _init_pipeline(self, cfg: Config) -> Callable:
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        test_transforms = [TRANSFORMS.build(t) for t in test_pipeline_cfg]
        img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)}
        text_info = {'text': 'example'}
        img_pipeline = Compose(filter_transforms(test_transforms, img_info))
        text_pipeline = Compose(filter_transforms(test_transforms, text_info))
        return img_pipeline, text_pipeline
def preprocess(self, inputs: List[ImageType], batch_size: int = 1):
def load_image(input_):
img = imread(input_)
if img is None:
raise ValueError(f'Failed to read image {input_}.')
return dict(
img=img,
img_shape=img.shape[:2],
ori_shape=img.shape[:2],
)
pipeline = Compose([load_image, self.img_pipeline])
chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
def visualize(self,
ori_inputs: List[ImageType],
preds: List[DataSample],
topk: int = 3,
resize: Optional[int] = 224,
show: bool = False,
wait_time: int = 0,
draw_score=True,
show_dir=None):
if not show and show_dir is None:
return None
if self.visualizer is None:
from mmpretrain.visualization import UniversalVisualizer
self.visualizer = UniversalVisualizer()
visualization = []
for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
image = imread(input_)
if isinstance(input_, str):
# The image loaded from path is BGR format.
image = image[..., ::-1]
name = Path(input_).stem
else:
name = str(i)
if show_dir is not None:
show_dir = Path(show_dir)
show_dir.mkdir(exist_ok=True)
out_file = str((show_dir / name).with_suffix('.png'))
else:
out_file = None
self.visualizer.visualize_i2t_retrieval(
image,
data_sample,
self.prototype_dataset,
topk=topk,
resize=resize,
draw_score=draw_score,
show=show,
wait_time=wait_time,
name=name,
out_file=out_file)
visualization.append(self.visualizer.get_image())
if show:
self.visualizer.close()
return visualization
def postprocess(
self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasamples=False,
topk=1,
) -> dict:
if return_datasamples:
return preds
results = []
for data_sample in preds:
match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
matches = []
for match_score, sample_idx in zip(match_scores, indices):
text = self.prototype_dataset[sample_idx.item()]
matches.append({
'match_score': match_score,
'sample_idx': sample_idx,
'text': text
})
results.append(matches)
return results
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List all available model names.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern, task='Image-To-Text Retrieval')
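
# A hedged note on prototype caching for the inferencer above: the features
# computed in ``_prepare_prototype`` can be dumped with ``save_prototype`` and
# reused through ``prototype_cache`` (file name is illustrative):
#
#   >>> inferencer.save_prototype('i2t_retri.pth')   # dump the features
#   >>> inferencer = ImageToTextRetrievalInferencer(
#   ...     'blip-base_3rdparty_retrieval',
#   ...     prototype=['cat', 'dog', 'snake', 'bird'],
#   ...     prototype_cache='i2t_retri.pth')         # loaded, not recomputed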
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer
from .model import list_models
InputType = Tuple[Union[str, np.ndarray], Union[str, np.ndarray], str]
InputsType = Union[List[InputType], InputType]
class NLVRInferencer(BaseInferencer):
"""The inferencer for Natural Language for Visual Reasoning.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``NLVRInferencer.list_models()`` and you can also
query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only works
            if the ``model`` is a model name).
"""
visualize_kwargs: set = {
'resize', 'draw_score', 'show', 'show_dir', 'wait_time'
}
def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (tuple, List[tuple]): The input data tuples, every tuple
should include three items (left image, right image, text).
The image can be a path or numpy array.
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
resize (int, optional): Resize the short edge of the image to the
specified length before visualization. Defaults to None.
draw_score (bool): Whether to draw the prediction scores
of prediction categories. Defaults to True.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
Returns:
list: The inference results.
"""
assert isinstance(inputs, (tuple, list))
if isinstance(inputs, tuple):
inputs = [inputs]
for input_ in inputs:
assert isinstance(input_, tuple)
assert len(input_) == 3
return super().__call__(
inputs,
return_datasamples=return_datasamples,
batch_size=batch_size,
**kwargs)
def _init_pipeline(self, cfg: Config) -> Callable:
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
assert test_pipeline_cfg[0]['type'] == 'ApplyToList'
list_pipeline = deepcopy(test_pipeline_cfg[0])
if list_pipeline.scatter_key == 'img_path':
# Remove `LoadImageFromFile`
list_pipeline.transforms.pop(0)
list_pipeline.scatter_key = 'img'
test_pipeline = Compose(
[TRANSFORMS.build(list_pipeline)] +
[TRANSFORMS.build(t) for t in test_pipeline_cfg[1:]])
return test_pipeline
def preprocess(self, inputs: InputsType, batch_size: int = 1):
def load_image(input_):
img1 = imread(input_[0])
img2 = imread(input_[1])
text = input_[2]
if img1 is None:
raise ValueError(f'Failed to read image {input_[0]}.')
if img2 is None:
raise ValueError(f'Failed to read image {input_[1]}.')
return dict(
img=[img1, img2],
img_shape=[img1.shape[:2], img2.shape[:2]],
ori_shape=[img1.shape[:2], img2.shape[:2]],
text=text,
)
pipeline = Compose([load_image, self.pipeline])
chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
def postprocess(self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasamples=False) -> dict:
if return_datasamples:
return preds
results = []
for data_sample in preds:
pred_scores = data_sample.pred_score
pred_score = float(torch.max(pred_scores).item())
pred_label = torch.argmax(pred_scores).item()
result = {
'pred_scores': pred_scores.detach().cpu().numpy(),
'pred_label': pred_label,
'pred_score': pred_score,
}
results.append(result)
return results
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List all available model names.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern, task='NLVR')
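
# A hedged usage sketch for NLVRInferencer. The model name and image paths
# are illustrative; query real names via ``NLVRInferencer.list_models()``.
#
#   >>> inferencer = NLVRInferencer('blip-base_3rdparty_nlvr')  # assumed name
#   >>> result = inferencer(
#   ...     ('demo/left.png', 'demo/right.png', 'Both images contain a dog.'))[0]
#   >>> result['pred_label']  # 1 if the statement matches the pair, else 0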
# Copyright (c) OpenMMLab. All rights reserved.
import os
from collections import defaultdict
from contextlib import contextmanager
from itertools import chain
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from mmpretrain.utils import require
@require('torch>=1.9.0', 'https://pytorch.org/get-started/locally/')
@require('accelerate')
def dispatch_model(
model,
device_map: Union[str, dict],
max_memory: Optional[dict] = None,
no_split_module_classes: Optional[List[str]] = None,
        offload_folder: Optional[str] = None,
offload_buffers: bool = False,
preload_module_classes: Optional[List[str]] = None,
):
"""Split and dispatch a model across devices.
    The function depends on the `accelerate` package. Refer to
    https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling
Args:
model (torch.nn.Module): The model to dispatch.
        device_map (str | dict): A map that specifies where each
            submodule should go. It doesn't need to be refined to each
            parameter/buffer name; once a given module name is inside, every
            submodule of it will be sent to the same device. You can use
            `device_map="auto"` to automatically generate the device map.
        max_memory (dict | None): A dictionary mapping device identifiers to
            the maximum memory available to them. Will default to the maximum
            memory available for each GPU and the available CPU RAM if unset.
            Defaults to None.
        no_split_module_classes (List[str] | None): A list of layer class
            names that should never be split across devices (for instance,
            any layer that has a residual connection). If None, try to get
            the settings from the model class. Defaults to None.
offload_folder (str | None): If the `device_map` contains any value
`"disk"`, the folder where we will offload weights.
offload_buffers (bool): In the layers that are offloaded on the CPU
or the hard drive, whether or not to offload the buffers as
well as the parameters. Defaults to False.
preload_module_classes (List[str] | None): A list of classes whose
instances should load all their weights (even in the submodules) at
the beginning of the forward. This should only be used for classes
that have submodules which are registered but not called directly
during the forward, for instance if a `dense` linear layer is
registered, but at forward, `dense.weight` and `dense.bias` are
used in some operations instead of calling `dense` directly.
Defaults to None.
"""
from accelerate import dispatch_model, infer_auto_device_map
# Check valid device_map string.
valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential']
if isinstance(device_map, str) and device_map not in valid_map_option:
raise ValueError('If passing a string for `device_map`, please choose '
f'from {valid_map_option}.')
# Generate device map automatically
if isinstance(device_map, str):
if no_split_module_classes is None:
no_split_module_classes = getattr(model, '_no_split_modules', None)
if no_split_module_classes is None:
raise ValueError(f'{model.__class__.__name__} does not support '
f"`device_map='{device_map}'` yet.")
if device_map != 'sequential':
from accelerate.utils import get_balanced_memory
max_memory = get_balanced_memory(
model,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
dtype=None,
low_zero=(device_map == 'balanced_low_0'),
)
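            # Leave ~10% headroom on the first device; the main device usually
            # also holds model outputs and intermediate tensors.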
max_memory[0] *= 0.9
device_map = infer_auto_device_map(
model,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
dtype=None,
)
if 'disk' in device_map.values():
if offload_folder is None:
raise ValueError(
'The current `device_map` had weights offloaded to the disk. '
'Please provide an `offload_folder` for them.')
os.makedirs(offload_folder, exist_ok=True)
main_device = next(
(d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu')
model = dispatch_model(
model,
device_map=device_map,
main_device=main_device,
offload_dir=offload_folder,
offload_buffers=offload_buffers,
preload_module_classes=preload_module_classes,
)
if hasattr(model, 'data_preprocessor'):
model.data_preprocessor._device = torch.device(main_device)
return model
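
# A minimal usage sketch for ``dispatch_model``, assuming ``accelerate`` is
# installed and the model class defines ``_no_split_modules`` (required when
# ``device_map`` is a string); the model name below is illustrative.
#
#   >>> from mmpretrain import get_model
#   >>> model = get_model('llava-7b-v1_caption', pretrained=True)  # assumed name
#   >>> model = dispatch_model(model, device_map='auto')
#   >>> # Submodules now live on the devices chosen by `infer_auto_device_map`,
#   >>> # and `model.data_preprocessor` targets the main device.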
@contextmanager
def init_empty_weights(include_buffers: bool = False):
"""A context manager under which models are initialized with all parameters
on the meta device.
With this context manager, we can create an empty model. Useful when just
initializing the model would blow the available RAM.
Besides move the parameters to meta device, this method will also avoid
load checkpoint from `mmengine.runner.load_checkpoint` and
`transformers.PreTrainedModel.from_pretrained`.
Modified from https://github.com/huggingface/accelerate
Args:
include_buffers (bool): Whether put all buffers on the meta device
during initialization.
"""
device = torch.device('meta')
# move parameter and buffer to meta device
old_register_parameter = nn.Module.register_parameter
if include_buffers:
old_register_buffer = nn.Module.register_buffer
# See https://github.com/huggingface/accelerate/pull/699
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ['empty', 'zeros', 'ones', 'full']
}
def register_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
module._parameters[name] = param_cls(
module._parameters[name].to(device), **kwargs)
def register_buffer(module, name, buffer, *args, **kwargs):
old_register_buffer(module, name, buffer, *args, **kwargs)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs['device'] = device
return fn(*args, **kwargs)
return wrapper
# Patch load_checkpoint
import mmengine.runner.checkpoint as mmengine_load
old_load_checkpoint = mmengine_load.load_checkpoint
def patch_load_checkpoint(*args, **kwargs):
return {}
# Patch transformers from pretrained
try:
from transformers import PreTrainedModel
from transformers.models.auto.auto_factory import (AutoConfig,
_BaseAutoModelClass)
with_transformers = True
except ImportError:
with_transformers = False
@classmethod
def patch_auto_model(cls, pretrained_model_name_or_path, *model_args,
**kwargs):
cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path,
*model_args, **kwargs)
return cls.from_config(cfg)
@classmethod
def patch_pretrained_model(cls, pretrained_model_name_or_path, *model_args,
**kwargs):
cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path,
*model_args, **kwargs)
return cls(cfg)
if with_transformers:
old_pretrained_model = PreTrainedModel.from_pretrained
old_auto_model = _BaseAutoModelClass.from_pretrained
try:
nn.Module.register_parameter = register_parameter
mmengine_load.load_checkpoint = patch_load_checkpoint
if with_transformers:
PreTrainedModel.from_pretrained = patch_pretrained_model
_BaseAutoModelClass.from_pretrained = patch_auto_model
if include_buffers:
nn.Module.register_buffer = register_buffer
for func in tensor_constructors_to_patch.keys():
tensor_constructor = patch_tensor_constructor(
getattr(torch, func))
setattr(torch, func, tensor_constructor)
yield
finally:
nn.Module.register_parameter = old_register_parameter
mmengine_load.load_checkpoint = old_load_checkpoint
if with_transformers:
PreTrainedModel.from_pretrained = old_pretrained_model
_BaseAutoModelClass.from_pretrained = old_auto_model
if include_buffers:
nn.Module.register_buffer = old_register_buffer
for func, ori in tensor_constructors_to_patch.items():
setattr(torch, func, ori)
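
# A hedged sketch of ``init_empty_weights``: modules built inside the context
# allocate no real storage, which makes it possible to inspect the structure
# of a very large model (or plan a device map) before loading real weights.
#
#   >>> import torch.nn as nn
#   >>> with init_empty_weights():
#   ...     big = nn.Linear(4096, 4096)
#   >>> big.weight.device
#   device(type='meta')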
def compute_module_sizes(
model: nn.Module,
dtype: Union[str, torch.dtype, None] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None):
"""Compute the size of each submodule of a given model."""
def get_dtype(dtype):
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
if dtype is not None:
            assert isinstance(dtype, torch.dtype)
return dtype
def dtype_bytes(dtype: torch.dtype):
if dtype is torch.bool:
return 1
if dtype.is_floating_point:
return torch.finfo(dtype).bits / 8
else:
return torch.iinfo(dtype).bits / 8
if dtype is not None:
dtype = get_dtype(dtype)
dtype_size = dtype_bytes(dtype)
if special_dtypes is not None:
special_dtypes = {
key: dtype_bytes(dtype)
for key, dtype in special_dtypes.items()
}
module_sizes = defaultdict(int)
for name, tensor in chain(
model.named_parameters(recurse=True),
model.named_buffers(recurse=True)):
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes[name]
elif dtype is None:
size = tensor.numel() * tensor.element_size()
else:
size = tensor.numel() * min(dtype_size, tensor.element_size())
name_parts = name.split('.')
for idx in range(len(name_parts) + 1):
module_sizes['.'.join(name_parts[:idx])] += size
return module_sizes
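
# A small sanity check for ``compute_module_sizes``. Sizes are in bytes and
# are accumulated on every prefix of each parameter/buffer name, so the entry
# for the empty string is the total model size.
#
#   >>> import torch.nn as nn
#   >>> model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
#   >>> sizes = compute_module_sizes(model)
#   >>> sizes['']   # (4 * 2 + 2) float32 values = 40 bytes in total
#   40
#   >>> sizes['0']  # the Linear submodule holds all of them
#   40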
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Callable, List, Optional, Union
import numpy as np
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer
from .model import list_models
class VisualGroundingInferencer(BaseInferencer):
"""The inferencer for visual grounding.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``VisualGroundingInferencer.list_models()`` and you can also
query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only works
            if the ``model`` is a model name).
Example:
>>> from mmpretrain import VisualGroundingInferencer
>>> inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco')
>>> inferencer('demo/cat-dog.png', 'dog')[0]
{'pred_bboxes': tensor([[ 36.6000, 29.6000, 355.8000, 395.2000]])}
""" # noqa: E501
visualize_kwargs: set = {
'resize', 'show', 'show_dir', 'wait_time', 'line_width', 'bbox_color'
}
def __call__(self,
images: Union[str, np.ndarray, list],
texts: Union[str, list],
return_datasamples: bool = False,
batch_size: int = 1,
**kwargs) -> dict:
"""Call the inferencer.
Args:
images (str | array | list): The image path or array, or a list of
images.
texts (str | list): The text to do visual grounding.
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
resize (int, optional): Resize the short edge of the image to the
specified length before visualization. Defaults to None.
draw_score (bool): Whether to draw the prediction scores
of prediction categories. Defaults to True.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
line_width (int): The line width of the bbox. Defaults to 3.
bbox_color (str | tuple): The color of the bbox.
Defaults to 'green'.
Returns:
list: The inference results.
"""
if not isinstance(images, (list, tuple)):
assert isinstance(texts, str)
inputs = [{'img': images, 'text': texts}]
else:
inputs = []
for i in range(len(images)):
input_ = {'img': images[i], 'text': texts[i]}
inputs.append(input_)
return super().__call__(inputs, return_datasamples, batch_size,
**kwargs)
def _init_pipeline(self, cfg: Config) -> Callable:
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
from mmpretrain.datasets import remove_transform
# Image loading is finished in `self.preprocess`.
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
'LoadImageFromFile')
test_pipeline = Compose(
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
return test_pipeline
def preprocess(self, inputs: List[dict], batch_size: int = 1):
def load_image(input_: dict):
img = imread(input_['img'])
if img is None:
raise ValueError(f'Failed to read image {input_}.')
return {**input_, 'img': img}
pipeline = Compose([load_image, self.pipeline])
chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
def visualize(self,
ori_inputs: List[dict],
preds: List[DataSample],
show: bool = False,
wait_time: int = 0,
resize: Optional[int] = None,
line_width: int = 3,
bbox_color: Union[str, tuple] = 'green',
show_dir=None):
if not show and show_dir is None:
return None
if self.visualizer is None:
from mmpretrain.visualization import UniversalVisualizer
self.visualizer = UniversalVisualizer()
visualization = []
for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
image = imread(input_['img'])
if isinstance(input_['img'], str):
# The image loaded from path is BGR format.
image = image[..., ::-1]
name = Path(input_['img']).stem
else:
name = str(i)
if show_dir is not None:
show_dir = Path(show_dir)
show_dir.mkdir(exist_ok=True)
out_file = str((show_dir / name).with_suffix('.png'))
else:
out_file = None
self.visualizer.visualize_visual_grounding(
image,
data_sample,
resize=resize,
show=show,
wait_time=wait_time,
line_width=line_width,
bbox_color=bbox_color,
name=name,
out_file=out_file)
visualization.append(self.visualizer.get_image())
if show:
self.visualizer.close()
return visualization
def postprocess(self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasamples=False) -> dict:
if return_datasamples:
return preds
results = []
for data_sample in preds:
results.append({'pred_bboxes': data_sample.get('pred_bboxes')})
return results
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List all available model names.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern, task='Visual Grounding')
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Callable, List, Optional, Union
import numpy as np
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample
from .base import BaseInferencer
from .model import list_models
class VisualQuestionAnsweringInferencer(BaseInferencer):
"""The inferencer for visual question answering.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``VisualQuestionAnsweringInferencer.list_models()`` and you can
also query it in :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only works if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only works
            if the ``model`` is a model name).
Example:
>>> from mmpretrain import VisualQuestionAnsweringInferencer
>>> inferencer = VisualQuestionAnsweringInferencer('ofa-base_3rdparty-zeroshot_vqa')
>>> inferencer('demo/cat-dog.png', "What's the animal next to the dog?")[0]
{'question': "What's the animal next to the dog?", 'pred_answer': 'cat'}
""" # noqa: E501
visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'}
def __call__(self,
images: Union[str, np.ndarray, list],
questions: Union[str, list],
return_datasamples: bool = False,
batch_size: int = 1,
objects: Optional[List[str]] = None,
**kwargs) -> dict:
"""Call the inferencer.
Args:
images (str | array | list): The image path or array, or a list of
images.
            questions (str | list): The question corresponding to each image.
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
            objects (List[List[str]], optional): Some algorithms, like OFA
                fine-tuned VQA models, require an extra object description
                list for every image. Defaults to None.
resize (int, optional): Resize the short edge of the image to the
specified length before visualization. Defaults to None.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
Returns:
list: The inference results.
"""
if not isinstance(images, (list, tuple)):
assert isinstance(questions, str)
inputs = [{'img': images, 'question': questions}]
if objects is not None:
assert isinstance(objects[0], str)
inputs[0]['objects'] = objects
else:
inputs = []
for i in range(len(images)):
input_ = {'img': images[i], 'question': questions[i]}
if objects is not None:
input_['objects'] = objects[i]
inputs.append(input_)
return super().__call__(inputs, return_datasamples, batch_size,
**kwargs)
def _init_pipeline(self, cfg: Config) -> Callable:
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
from mmpretrain.datasets import remove_transform
# Image loading is finished in `self.preprocess`.
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
'LoadImageFromFile')
test_pipeline = Compose(
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
return test_pipeline
def preprocess(self, inputs: List[dict], batch_size: int = 1):
def load_image(input_: dict):
img = imread(input_['img'])
if img is None:
raise ValueError(f'Failed to read image {input_}.')
return {**input_, 'img': img}
pipeline = Compose([load_image, self.pipeline])
chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
def visualize(self,
ori_inputs: List[dict],
preds: List[DataSample],
show: bool = False,
wait_time: int = 0,
resize: Optional[int] = None,
show_dir=None):
if not show and show_dir is None:
return None
if self.visualizer is None:
from mmpretrain.visualization import UniversalVisualizer
self.visualizer = UniversalVisualizer()
visualization = []
for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
image = imread(input_['img'])
if isinstance(input_['img'], str):
# The image loaded from path is BGR format.
image = image[..., ::-1]
name = Path(input_['img']).stem
else:
name = str(i)
if show_dir is not None:
show_dir = Path(show_dir)
show_dir.mkdir(exist_ok=True)
out_file = str((show_dir / name).with_suffix('.png'))
else:
out_file = None
self.visualizer.visualize_vqa(
image,
data_sample,
resize=resize,
show=show,
wait_time=wait_time,
name=name,
out_file=out_file)
visualization.append(self.visualizer.get_image())
if show:
self.visualizer.close()
return visualization
def postprocess(self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasamples=False) -> dict:
if return_datasamples:
return preds
results = []
for data_sample in preds:
results.append({
'question': data_sample.get('question'),
'pred_answer': data_sample.get('pred_answer'),
})
return results
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List all available model names.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
"""
return list_models(pattern=pattern, task='Visual Question Answering')
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new-format config file, and its usage may change in the future.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import CIFAR10, PackInputs, RandomCrop, RandomFlip
from mmpretrain.evaluation import Accuracy
# dataset settings
dataset_type = CIFAR10
data_preprocessor = dict(
num_classes=10,
# RGB format normalization parameters
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
# loaded images are already RGB format
to_rgb=False)
train_pipeline = [
dict(type=RandomCrop, crop_size=32, padding=4),
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
dict(type=PackInputs),
]
test_pipeline = [
dict(type=PackInputs),
]
train_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/cifar10',
split='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)
val_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/cifar10/',
split='test',
pipeline=test_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, ))
test_dataloader = val_dataloader
test_evaluator = val_evaluator
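
# A hedged sketch of inheriting this new-format config from another config
# file via mmengine's ``read_base`` (the relative import path is assumed):
#
#   from mmengine.config import read_base
#
#   with read_base():
#       from .._base_.datasets.cifar10_bs16 import *  # assumed file location
#
#   # Imported fields are plain Python objects and can be overridden in place:
#   train_dataloader.update(batch_size=64)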
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new-format config file, and its usage may change in the future.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (CUB, CenterCrop, LoadImageFromFile,
PackInputs, RandomCrop, RandomFlip, Resize)
from mmpretrain.evaluation import Accuracy
# dataset settings
dataset_type = CUB
data_preprocessor = dict(
num_classes=200,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type=LoadImageFromFile),
dict(type=Resize, scale=510),
dict(type=RandomCrop, crop_size=384),
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
dict(type=PackInputs),
]
test_pipeline = [
dict(type=LoadImageFromFile),
dict(type=Resize, scale=510),
dict(type=CenterCrop, crop_size=384),
dict(type=PackInputs),
]
train_dataloader = dict(
batch_size=8,
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/CUB_200_2011',
split='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)
val_dataloader = dict(
batch_size=8,
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/CUB_200_2011',
split='test',
pipeline=test_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, ))
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new-format config file, and its usage may change in the future.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (ImageNet21k, LoadImageFromFile, PackInputs,
RandomFlip, RandomResizedCrop)
# dataset settings
dataset_type = ImageNet21k
data_preprocessor = dict(
num_classes=21842,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type=LoadImageFromFile),
dict(type=RandomResizedCrop, scale=224),
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
dict(type=PackInputs),
]
train_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet21k',
split='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new-format config file, and its usage may change in the future.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (AutoAugment, CenterCrop, ImageNet,
LoadImageFromFile, PackInputs, RandomErasing,
RandomFlip, RandomResizedCrop, ResizeEdge)
from mmpretrain.evaluation import Accuracy
# dataset settings
dataset_type = ImageNet
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type=LoadImageFromFile),
dict(type=RandomResizedCrop, scale=224, backend='pillow'),
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
dict(
type=AutoAugment,
policies='imagenet',
hparams=dict(pad_val=[round(x) for x in bgr_mean])),
dict(
type=RandomErasing,
erase_prob=0.2,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type=PackInputs),
]
test_pipeline = [
dict(type=LoadImageFromFile),
dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'),
dict(type=CenterCrop, crop_size=224),
dict(type=PackInputs),
]
train_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)
val_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new-format config file, and its usage may change in the future.
from mmengine.dataset import DefaultSampler, default_collate
from mmpretrain.datasets import (BEiTMaskGenerator, ColorJitter, ImageNet,
LoadImageFromFile, PackInputs, RandomFlip,
RandomResizedCropAndInterpolationWithTwoPic)
from mmpretrain.models import TwoNormDataPreprocessor
dataset_type = ImageNet
data_root = 'data/imagenet/'
data_preprocessor = dict(
type=TwoNormDataPreprocessor,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
second_mean=[127.5, 127.5, 127.5],
second_std=[127.5, 127.5, 127.5],
to_rgb=True)
train_pipeline = [
dict(type=LoadImageFromFile),
dict(
type=ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4,
hue=0.),
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
dict(
type=RandomResizedCropAndInterpolationWithTwoPic,
size=224,
second_size=224,
interpolation='bicubic',
second_interpolation='bicubic',
scale=(0.2, 1.0)),
dict(
type=BEiTMaskGenerator,
input_size=(14, 14),
num_masking_patches=75,
max_num_patches=75,
min_num_patches=16),
dict(type=PackInputs)
]
train_dataloader = dict(
batch_size=256,
num_workers=8,
persistent_workers=True,
sampler=dict(type=DefaultSampler, shuffle=True),
collate_fn=dict(type=default_collate),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='train',
pipeline=train_pipeline))
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new-format config file, and its usage may change in the future.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
PackInputs, RandomFlip, RandomResizedCrop,
ResizeEdge)
from mmpretrain.evaluation import Accuracy
# dataset settings
dataset_type = ImageNet
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type=LoadImageFromFile),
dict(type=RandomResizedCrop, scale=224),
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
dict(type=PackInputs),
]
test_pipeline = [
dict(type=LoadImageFromFile),
dict(type=ResizeEdge, scale=256, edge='short'),
dict(type=CenterCrop, crop_size=224),
dict(type=PackInputs),
]
train_dataloader = dict(
batch_size=32,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)
val_dataloader = dict(
batch_size=32,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new-format config file, and its usage may change in the future.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
PackInputs, RandomFlip, RandomResizedCrop,
ResizeEdge)
from mmpretrain.evaluation import Accuracy
# dataset settings
dataset_type = ImageNet
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type=LoadImageFromFile),
dict(type=RandomResizedCrop, scale=224, backend='pillow'),
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
dict(type=PackInputs),
]
test_pipeline = [
dict(type=LoadImageFromFile),
dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'),
dict(type=CenterCrop, crop_size=224),
dict(type=PackInputs),
]
train_dataloader = dict(
batch_size=32,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)
val_dataloader = dict(
batch_size=32,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator