Commit cbc25585 authored by limm's avatar limm
Browse files

add mmpretrain/ part

parent 1baf0566
Pipeline #2801 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class MiniGPT4Dataset(BaseDataset):
    """Dataset for training MiniGPT4.

    The dataset root is expected to contain an ``image`` directory with
    ``<id>.jpg`` files and a conversation annotation file (a JSON list).
    Every entry of the annotation is a dict with an ``id`` (image id) and a
    ``conversation`` string; English conversations start with ``###Ask: ``
    while Chinese ones use ``###问:``.

    Args:
        data_root (str): The root directory for ``ann_file`` and ``image``.
        ann_file (str): Conversation file path.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def load_data_list(self) -> List[dict]:
        """Build the sample list from the conversation annotation file."""
        backend = get_file_backend(self.data_root)
        ann_path = backend.join_path(self.data_root, self.ann_file)
        conversations = mmengine.load(ann_path)

        # Map every distinct image id to a contiguous integer index, in
        # first-appearance order.
        image_index = {}
        for item in conversations:
            image_index.setdefault(item['id'], len(image_index))

        image_dir = backend.join_path(self.data_root, 'image')
        samples = []
        for item in conversations:
            content = item['conversation']
            samples.append({
                'image_id': image_index[item['id']],
                'img_path': backend.join_path(image_dir,
                                              '{}.jpg'.format(item['id'])),
                'chat_content': content,
                # English conversations are marked by the "###Ask: " prefix;
                # everything else is treated as Chinese.
                'lang': 'en' if content.startswith('###Ask: ') else 'zh',
            })
        return samples
# Copyright (c) OpenMMLab. All rights reserved.
import codecs
from typing import List, Optional
from urllib.parse import urljoin
import mmengine.dist as dist
import numpy as np
import torch
from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path
from mmengine.logging import MMLogger
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
from .utils import (download_and_extract_archive, open_maybe_compressed_file,
rm_suffix)
@DATASETS.register_module()
class MNIST(BaseDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py

    Args:
        data_root (str): The root directory of the MNIST Dataset.
        split (str, optional): The dataset split, supports "train" and "test".
            Default to "train".
        metainfo (dict, optional): Meta information for dataset, such as
            categories information. Defaults to None.
        download (bool): Whether to download the dataset if not exists.
            Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """ # noqa: E501

    url_prefix = 'http://yann.lecun.com/exdb/mnist/'
    # train images and labels as (gz filename, md5) pairs
    train_list = [
        ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'],
        ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'],
    ]
    # test images and labels as (gz filename, md5) pairs
    test_list = [
        ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'],
        ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'],
    ]
    METAINFO = {'classes': MNIST_CATEGORITES}

    def __init__(self,
                 data_root: str = '',
                 split: str = 'train',
                 metainfo: Optional[dict] = None,
                 download: bool = True,
                 data_prefix: str = '',
                 test_mode: bool = False,
                 **kwargs):
        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but get '{split}'"
        self.split = split

        # To handle the BC-breaking: `split` supersedes the legacy
        # `test_mode` flag, so warn when they disagree.
        if split == 'train' and test_mode:
            logger = MMLogger.get_current_instance()
            logger.warning('split="train" but test_mode=True. '
                           'The training set will be used.')

        if not data_root and not data_prefix:
            # Fix: the original message was 'to' + 'specify' with no
            # separating space ("to``specify").
            raise RuntimeError('Please set ``data_root`` to '
                               'specify the dataset path')

        self.download = download
        super().__init__(
            # The MNIST dataset doesn't need specify annotation file
            ann_file='',
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=dict(root=data_prefix),
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""
        root = self.data_prefix['root']
        backend = get_file_backend(root, enable_singleton=True)

        # Only the main process downloads; other ranks wait at the barrier
        # below until the files exist on shared storage.
        if dist.is_main_process() and not self._check_exists():
            if not isinstance(backend, LocalBackend):
                raise RuntimeError(f'The dataset on {root} is not integrated, '
                                   'please manually handle it.')

            if self.download:
                self._download()
            else:
                raise RuntimeError(
                    f'Cannot find {self.__class__.__name__} dataset in '
                    f"{self.data_prefix['root']}, you can specify "
                    '`download=True` to download automatically.')

        dist.barrier()
        assert self._check_exists(), \
            'Download failed or shared storage is unavailable. Please ' \
            f'download the dataset manually through {self.url_prefix}.'

        if not self.test_mode:
            file_list = self.train_list
        else:
            file_list = self.test_list

        # load data from SN3 files (the .gz suffix is stripped because the
        # archives are extracted in place).
        imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0])))
        gt_labels = read_label_file(
            join_path(root, rm_suffix(file_list[1][0])))

        data_infos = []
        for img, gt_label in zip(imgs, gt_labels):
            gt_label = np.array(gt_label, dtype=np.int64)
            info = {'img': img.numpy(), 'gt_label': gt_label}
            data_infos.append(info)
        return data_infos

    def _check_exists(self):
        """Check the exists of data files."""
        root = self.data_prefix['root']

        for filename, _ in (self.train_list + self.test_list):
            # get extracted filename of data
            extract_filename = rm_suffix(filename)
            fpath = join_path(root, extract_filename)
            if not exists(fpath):
                return False
        return True

    def _download(self):
        """Download and extract data files."""
        root = self.data_prefix['root']

        for filename, md5 in (self.train_list + self.test_list):
            url = urljoin(self.url_prefix, filename)
            download_and_extract_archive(
                url, download_root=root, filename=filename, md5=md5)

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [f"Prefix of data: \t{self.data_prefix['root']}"]
        return body
@DATASETS.register_module()
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
    Dataset.

    All download, verification and SN3 parsing logic is inherited from
    :class:`MNIST`; this subclass only overrides the download location, the
    file checksums and the category names.

    Args:
        data_root (str): The root directory of the Fashion-MNIST Dataset.
        split (str, optional): The dataset split, supports "train" and "test".
            Default to "train".
        metainfo (dict, optional): Meta information for dataset, such as
            categories information. Defaults to None.
        download (bool): Whether to download the dataset if not exists.
            Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
    # train images and labels as (gz filename, md5) pairs
    train_list = [
        ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'],
        ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'],
    ]
    # test images and labels as (gz filename, md5) pairs
    test_list = [
        ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'],
        ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'],
    ]
    METAINFO = {'classes': FASHIONMNIST_CATEGORITES}
def get_int(b: bytes) -> int:
    """Interpret big-endian bytes as an unsigned int.

    ``int.from_bytes`` replaces the roundabout hex-encode-then-parse idiom
    ``int(codecs.encode(b, 'hex'), 16)``; it is also robust to empty input
    (returns 0 instead of raising ``ValueError``).
    """
    return int.from_bytes(b, byteorder='big')
def read_sn3_pascalvincent_tensor(path: str,
                                  strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-
    io.lsh').

    Argument may be a filename, compressed filename, or file object.
    """
    # Lazily attach the dtype lookup table to the function object so it is
    # built only once (keyed by the dtype code of the magic word).
    if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype('>i2'), 'i2'),
            12: (torch.int32, np.dtype('>i4'), 'i4'),
            13: (torch.float32, np.dtype('>f4'), 'f4'),
            14: (torch.float64, np.dtype('>f8'), 'f8')
        }
    # Slurp the (possibly compressed) file.
    with open_maybe_compressed_file(path) as stream:
        raw = stream.read()
    # The 4-byte magic word encodes the dtype (high bytes) and the number of
    # dimensions (low byte).
    magic = get_int(raw[0:4])
    ndim = magic % 256
    dtype_code = magic // 256
    assert 1 <= ndim <= 3
    assert 8 <= dtype_code <= 14
    _, np_dtype, np_cast = read_sn3_pascalvincent_tensor.typemap[dtype_code]
    # One 4-byte big-endian size per dimension follows the magic word.
    shape = [get_int(raw[4 * (k + 1):4 * (k + 2)]) for k in range(ndim)]
    flat = np.frombuffer(raw, dtype=np_dtype, offset=(4 * (ndim + 1)))
    assert flat.shape[0] == np.prod(shape) or not strict
    return torch.from_numpy(flat.astype(np_cast, copy=False)).view(*shape)
def read_label_file(path: str) -> torch.Tensor:
    """Read labels from SN3 label file."""
    with open(path, 'rb') as stream:
        tensor = read_sn3_pascalvincent_tensor(stream, strict=False)
    # Label files are 1-D uint8 vectors; widen to int64 for use as targets.
    assert tensor.dtype == torch.uint8
    assert tensor.ndimension() == 1
    return tensor.long()
def read_image_file(path: str) -> torch.Tensor:
    """Read images from SN3 image file."""
    with open(path, 'rb') as stream:
        tensor = read_sn3_pascalvincent_tensor(stream, strict=False)
    # Image files are 3-D uint8 arrays: (num_images, height, width).
    assert tensor.dtype == torch.uint8
    assert tensor.ndimension() == 3
    return tensor
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
class MultiLabelDataset(BaseDataset):
    """Multi-label Dataset.

    This dataset supports annotation files in the `OpenMMLab 2.0 style
    annotation format`, e.g.:

    .. code-block:: none

        {
            "metainfo":
            {
              "classes":['A', 'B', 'C'....]
            },
            "data_list":
            [
              {
                "img_path": "test_img1.jpg",
                'gt_label': [0, 1],
              },
              {
                "img_path": "test_img2.jpg",
                'gt_label': [2],
              },
            ]
            ....
        }

    Args:
        ann_file (str): Annotation file path.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a
            smaller dataset. Defaults to None which means using all
            ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy. Defaults
            to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotation during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, which is not necessary
            to load the annotation file. ``Basedataset`` can skip loading
            annotations to save time. Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
            None img. The maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If is string, it should be a file path, and the every line of
              the file is a name of a class.
            - If is a sequence of string, every item is a name of class.
            - If is None, use categories information in ``metainfo`` argument,
              annotation file or the class attribute ``METAINFO``.

            Defaults to None.
    """

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image categories of specified index.
        """
        data_info = self.get_data_info(idx)
        return data_info['gt_label']
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from os import PathLike
from typing import Optional, Sequence
import mmengine
from mmcv.transforms import Compose
from mmengine.fileio import get_file_backend
from .builder import DATASETS
def expanduser(path):
    """Expand a leading ``~`` in *path*; non-path inputs pass through."""
    return osp.expanduser(path) if isinstance(path, (str, PathLike)) else path
def isabs(uri):
    """Return truthy if *uri* is an absolute filesystem path or a URI with
    a scheme separator (``://``), e.g. ``s3://bucket/key``."""
    if '://' in uri:
        return True
    return osp.isabs(uri)
@DATASETS.register_module()
class MultiTaskDataset:
    """Custom dataset for multi-task dataset.

    To use the dataset, please generate and provide an annotation file in the
    below format:

    .. code-block:: json

        {
          "metainfo": {
            "tasks":
              [
              'gender'
              'wear'
              ]
          },
          "data_list": [
            {
              "img_path": "a.jpg",
              gt_label:{
                "gender": 0,
                "wear": [1, 0, 1, 0]
              }
            },
            {
              "img_path": "b.jpg",
              gt_label:{
                "gender": 1,
                "wear": [1, 0, 1, 0]
              }
            }
          ]
        }

    Assume we put our dataset in the ``data/mydataset`` folder in the
    repository and organize it as the below format: ::

        mmpretrain/
        └── data
            └── mydataset
                ├── annotation
                │   ├── train.json
                │   ├── test.json
                │   └── val.json
                ├── train
                │   ├── a.jpg
                │   └── ...
                ├── test
                │   ├── b.jpg
                │   └── ...
                └── val
                    ├── c.jpg
                    └── ...

    We can use the below config to build datasets:

    .. code:: python

        >>> from mmpretrain.datasets import build_dataset
        >>> train_cfg = dict(
        ...     type="MultiTaskDataset",
        ...     ann_file="annotation/train.json",
        ...     data_root="data/mydataset",
        ...     # The `img_path` field in the train annotation file is relative
        ...     # to the `train` folder.
        ...     data_prefix='train',
        ... )
        >>> train_dataset = build_dataset(train_cfg)

    Or we can put all files in the same folder: ::

        mmpretrain/
        └── data
            └── mydataset
                ├── train.json
                ├── test.json
                ├── val.json
                ├── a.jpg
                ├── b.jpg
                ├── c.jpg
                └── ...

    And we can use the below config to build datasets:

    .. code:: python

        >>> from mmpretrain.datasets import build_dataset
        >>> train_cfg = dict(
        ...     type="MultiTaskDataset",
        ...     ann_file="train.json",
        ...     data_root="data/mydataset",
        ...     # the `data_prefix` is not required since all paths are
        ...     # relative to the `data_root`.
        ... )
        >>> train_dataset = build_dataset(train_cfg)

    Args:
        ann_file (str): The annotation file path. It can be either absolute
            path or relative path to the ``data_root``.
        metainfo (dict, optional): The extra meta information. It should be
            a dict with the same format as the ``"metainfo"`` field in the
            annotation file. Defaults to None.
        data_root (str, optional): The root path of the data directory. It's
            the prefix of the ``data_prefix`` and the ``ann_file``. And it can
            be a remote path like "s3://openmmlab/xxx/". Defaults to None.
        data_prefix (str, optional): The base folder relative to the
            ``data_root`` for the ``"img_path"`` field in the annotation file.
            Defaults to None.
        pipeline (Sequence[dict]): A list of dict, where each element
            represents a operation defined in
            :mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple.
        test_mode (bool): in train mode or test mode. Defaults to False.
    """
    # Default meta information; may be overridden by the annotation file and
    # the ``metainfo`` argument (see ``_get_meta_info``).
    METAINFO = dict()

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: Optional[str] = None,
                 pipeline: Sequence = (),
                 test_mode: bool = False):

        self.data_root = expanduser(data_root)

        # Inference the file client
        if self.data_root is not None:
            self.file_backend = get_file_backend(uri=self.data_root)
        else:
            self.file_backend = None

        # Both paths are made absolute relative to ``data_root`` (no-op for
        # already-absolute paths, see ``_join_root``).
        self.ann_file = self._join_root(expanduser(ann_file))
        self.data_prefix = self._join_root(data_prefix)

        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        # Eagerly load annotations; this also sets ``self._metainfo``.
        self.data_list = self.load_data_list(self.ann_file, metainfo)

    def _join_root(self, path):
        """Join ``self.data_root`` with the specified path.

        If the path is an absolute path, just return the path. And if the
        path is None, return ``self.data_root``.

        Examples:
            >>> self.data_root = 'a/b/c'
            >>> self._join_root('d/e/')
            'a/b/c/d/e'
            >>> self._join_root('https://openmmlab.com')
            'https://openmmlab.com'
            >>> self._join_root(None)
            'a/b/c'
        """
        if path is None:
            return self.data_root
        if isabs(path):
            return path

        joined_path = self.file_backend.join_path(self.data_root, path)
        return joined_path

    @classmethod
    def _get_meta_info(cls, in_metainfo: Optional[dict] = None) -> dict:
        """Collect meta information from the dictionary of meta.

        Args:
            in_metainfo (dict, optional): Meta information dict.
                Defaults to None.

        Returns:
            dict: Parsed meta information.
        """
        # `cls.METAINFO` will be overwritten by in_meta
        metainfo = copy.deepcopy(cls.METAINFO)
        if in_metainfo is None:
            return metainfo

        metainfo.update(in_metainfo)

        return metainfo

    def load_data_list(self, ann_file, metainfo_override=None):
        """Load annotations from an annotation file.

        Args:
            ann_file (str): Absolute annotation file path if ``self.root=None``
                or relative path if ``self.root=/path/to/data/``.
            metainfo_override (dict, optional): Extra meta information that
                overrides the ``"metainfo"`` field of the annotation file.

        Returns:
            list[dict]: A list of annotation.
        """
        annotations = mmengine.load(ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        if 'data_list' not in annotations:
            raise ValueError('The annotation file must have the `data_list` '
                             'field.')
        metainfo = annotations.get('metainfo', {})
        raw_data_list = annotations['data_list']

        # Set meta information.
        assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
            f'annotation file should be a dict, but got {type(metainfo)}'
        if metainfo_override is not None:
            assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
                f'argument should be a dict, but got {type(metainfo_override)}'
            metainfo.update(metainfo_override)
        self._metainfo = self._get_meta_info(metainfo)

        data_list = []
        for i, raw_data in enumerate(raw_data_list):
            try:
                data_list.append(self.parse_data_info(raw_data))
            except AssertionError as e:
                # Re-raise with the item index so bad entries are easy to
                # locate in large annotation files.
                raise RuntimeError(
                    f'The format check fails during parse the item {i} of '
                    f'the annotation file with error: {e}')
        return data_list

    def parse_data_info(self, raw_data):
        """Parse raw annotation to target format.

        This method will return a dict which contains the data information of a
        sample.

        Args:
            raw_data (dict): Raw data information load from ``ann_file``

        Returns:
            dict: Parsed annotation.
        """
        assert isinstance(raw_data, dict), \
            f'The item should be a dict, but got {type(raw_data)}'
        assert 'img_path' in raw_data, \
            "The item doesn't have `img_path` field."
        data = dict(
            img_path=self._join_root(raw_data['img_path']),
            gt_label=raw_data['gt_label'],
        )
        return data

    @property
    def metainfo(self) -> dict:
        """Get meta information of dataset.

        Returns:
            dict: meta information collected from ``cls.METAINFO``,
            annotation file and metainfo argument during instantiation.
        """
        return copy.deepcopy(self._metainfo)

    def prepare_data(self, idx):
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        # Deep-copy so pipeline transforms cannot mutate the cached entry.
        results = copy.deepcopy(self.data_list[idx])
        return self.pipeline(results)

    def __len__(self):
        """Get the length of the whole dataset.

        Returns:
            int: The length of filtered dataset.
        """
        return len(self.data_list)

    def __getitem__(self, idx):
        """Get the idx-th image and data information of dataset after
        ``self.pipeline``.

        Args:
            idx (int): The index of of the data.

        Returns:
            dict: The idx-th image and data information after
            ``self.pipeline``.
        """
        return self.prepare_data(idx)

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = [f'Number of samples: \t{self.__len__()}']
        if self.data_root is not None:
            body.append(f'Root location: \t{self.data_root}')
        body.append(f'Annotation file: \t{self.ann_file}')
        if self.data_prefix is not None:
            body.append(f'Prefix of images: \t{self.data_prefix}')
        # -------------------- extra repr --------------------
        # NOTE(review): assumes the metainfo always carries a 'tasks' key —
        # __repr__ raises KeyError otherwise; confirm against callers.
        tasks = self.metainfo['tasks']
        body.append(f'For {len(tasks)} tasks')
        for task in tasks:
            body.append(f'    {task} ')
        # ----------------------------------------------------

        if len(self.pipeline.transforms) > 0:
            body.append('With transforms:')
            for t in self.pipeline.transforms:
                body.append(f'    {t}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)
# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import List
from mmengine.fileio import get_file_backend, list_from_file
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
class NLVR2(BaseDataset):
    """NLVR2 dataset (Natural Language for Visual Reasoning).

    Each line of ``ann_file`` is a JSON object containing a natural-language
    ``sentence``, a ``label`` ("True"/"False") and an ``identifier`` whose
    prefix names the pair of images the sentence refers to.

    NOTE(review): the original docstring said "COCO Caption dataset", which
    was copied from another class and did not describe this loader.
    """

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        data_list = []
        img_prefix = self.data_prefix['img_path']
        file_backend = get_file_backend(img_prefix)
        examples = list_from_file(self.ann_file)

        for example in examples:
            # One JSON object per line.
            example = json.loads(example)
            # Drop the trailing sub-id so the identifier matches the image
            # file prefix.
            prefix = example['identifier'].rsplit('-', 1)[0]
            train_data = {}
            train_data['text'] = example['sentence']
            # Binary classification target: "True" -> 1, "False" -> 0.
            train_data['gt_label'] = {'True': 1, 'False': 0}[example['label']]
            # Every sample pairs two images: "<prefix>-img0.png" and
            # "<prefix>-img1.png".
            train_data['img_path'] = [
                file_backend.join_path(img_prefix, prefix + f'-img{i}.png')
                for i in range(2)
            ]

            data_list.append(train_data)

        return data_list
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend
from pycocotools.coco import COCO
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class NoCaps(BaseDataset):
    """NoCaps dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``..
        ann_file (str): Annotation file path.
        data_prefix (dict): Prefix for data field. Defaults to
            ``dict(img_path='')``.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        prefix = self.data_prefix['img_path']
        # COCO needs a local file, so fetch the annotation first if remote.
        with mmengine.get_local_path(self.ann_file) as local_ann:
            coco = COCO(local_ann)
        backend = get_file_backend(prefix)

        samples = []
        for ann in coco.anns.values():
            img_id = ann['image_id']
            samples.append({
                'image_id': img_id,
                'img_path': backend.join_path(
                    prefix, coco.imgs[img_id]['file_name']),
                # Ground-truth captions are not provided here; the field is
                # kept with a None placeholder.
                'gt_caption': None,
            })
        return samples
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class OCRVQA(BaseDataset):
    """OCR-VQA dataset.

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``
            and ``question_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path for training and validation.
        split (str): 'train', 'val' or 'test'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self, data_root: str, data_prefix: str, ann_file: str,
                 split: str, **kwarg):
        assert split in ['train', 'val', 'test'], \
            '`split` must be train, val or test'
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwarg,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        # Numeric split codes used inside the annotation file.
        split_names = {1: 'train', 2: 'val', 3: 'test'}
        annotations = mmengine.load(self.ann_file)

        # ann example
        # "761183272": {
        #     "imageURL": \
        #         "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg",
        #     "questions": [
        #         "Who wrote this book?",
        #         "What is the title of this book?", ...],
        #     "answers": [
        #         "Sandra Boynton",
        #         "Mom's Family Wall Calendar 2016", ...],
        #     "title": "Mom's Family Wall Calendar 2016",
        #     "authorName": "Sandra Boynton",
        #     "genre": "Calendars",
        #     "split": 1
        # },

        samples = []
        for img_key, record in annotations.items():
            # Keep only entries belonging to the requested split.
            if split_names[record['split']] != self.split:
                continue
            # Local image name is "<key><ext>"; skip unsupported extensions.
            ext = osp.splitext(record['imageURL'])[1]
            if ext not in ['.jpg', '.png']:
                continue
            img_path = mmengine.join_path(self.data_prefix['img_path'],
                                          img_key + ext)
            # One sample per (question, answer) pair of the record.
            for question, answer in zip(record['questions'],
                                        record['answers']):
                samples.append({
                    'img_path': img_path,
                    'question': question,
                    'gt_answer': answer,
                    'gt_answer_weight': [1.0],
                    'imageURL': record['imageURL'],
                    'title': record['title'],
                    'authorName': record['authorName'],
                    'genre': record['genre'],
                })
        return samples
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine import get_file_backend, list_from_file
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import OxfordIIITPet_CATEGORIES
@DATASETS.register_module()
class OxfordIIITPet(BaseDataset):
    """The Oxford-IIIT Pets Dataset.

    Support the `Oxford-IIIT Pets Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ Dataset.
    The dataset directory is expected to contain an ``images`` folder and an
    ``annotations`` folder with ``trainval.txt`` / ``test.txt`` split lists.

    Args:
        data_root (str): The root directory for Oxford-IIIT Pets dataset.
        split (str, optional): The dataset split, supports "trainval" and "test".
            Default to "trainval".

    Examples:
        >>> from mmpretrain.datasets import OxfordIIITPet
        >>> train_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='trainval')
        >>> train_dataset
        Dataset OxfordIIITPet
            Number of samples:  3680
            Number of categories:       37
            Root of dataset:    data/Oxford-IIIT_Pets
        >>> test_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='test')
        >>> test_dataset
        Dataset OxfordIIITPet
            Number of samples:  3669
            Number of categories:       37
            Root of dataset:    data/Oxford-IIIT_Pets
    """ # noqa: E501

    METAINFO = {'classes': OxfordIIITPet_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'trainval', **kwargs):

        splits = ['trainval', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but get '{split}'"
        self.split = split

        self.backend = get_file_backend(data_root, enable_singleton=True)
        # The split list lives under annotations/<split>.txt.
        ann_name = 'trainval.txt' if split == 'trainval' else 'test.txt'

        super(OxfordIIITPet, self).__init__(
            ann_file=self.backend.join_path('annotations', ann_name),
            data_root=data_root,
            data_prefix='images',
            test_mode=split == 'test',
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""
        data_list = []
        # Each line: "<image_name> <class_id> <species_id> <breed_id>".
        for line in list_from_file(self.ann_file):
            img_name, class_id, _, _ = line.split()
            data_list.append(
                dict(
                    img_path=self.backend.join_path(self.img_prefix,
                                                    f'{img_name}.jpg'),
                    # class ids in the file are 1-based
                    gt_label=int(class_id) - 1,
                ))
        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        return [f'Root of dataset: \t{self.data_root}']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
from mmpretrain.registry import DATASETS
from .categories import PLACES205_CATEGORIES
from .custom import CustomDataset
@DATASETS.register_module()
class Places205(CustomDataset):
    """`Places205 <http://places.csail.mit.edu/downloadData.html>`_ Dataset.

    A thin wrapper over :class:`CustomDataset` that pins the Places205 class
    names and the set of recognized image extensions.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults
            to ''.
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        **kwargs: Other keyword arguments in :class:`CustomDataset` and
            :class:`BaseDataset`.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    METAINFO = {'classes': PLACES205_CATEGORIES}

    def __init__(self,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 **kwargs):
        # Fall back to the class-level extension list unless the caller
        # supplied one explicitly.
        kwargs.setdefault('extensions', self.IMG_EXTENSIONS)
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            metainfo=metainfo,
            **kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List
import mmengine
import numpy as np
from mmengine.dataset import BaseDataset
from pycocotools.coco import COCO
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class RefCOCO(BaseDataset):
    """RefCOCO dataset.

    RefCOCO is a popular dataset used for the task of visual grounding.
    You can access the RefCOCO dataset from the official source:
    https://github.com/lichengunc/refer

    Each loaded sample pairs one referring expression ("sent") with the
    XYXY bounding box of the referred object.

    Args:
        ann_file (str): Annotation file path.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str): Prefix for training data.
        split_file (str): Pickle file with the referring-expression entries
            and their split assignment.
        split (str): Which split to load. Defaults to 'train'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root,
                 ann_file,
                 data_prefix,
                 split_file,
                 split='train',
                 **kwargs):
        self.split_file = split_file
        self.split = split

        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def _join_prefix(self):
        # Resolve split_file against data_root, mirroring how the base class
        # resolves ann_file.
        if self.split_file and not mmengine.is_abs(self.split_file):
            self.split_file = osp.join(self.data_root, self.split_file)

        return super()._join_prefix()

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        with mmengine.get_local_path(self.ann_file) as ann_path:
            coco = COCO(ann_path)
        split_info = mmengine.load(self.split_file, file_format='pkl')
        img_prefix = self.data_prefix['img_path']
        join_path = mmengine.fileio.get_file_backend(img_prefix).join_path

        data_list = []
        for refer in split_info:
            if refer['split'] != self.split:
                continue

            ann = coco.anns[refer['ann_id']]
            image = coco.imgs[ann['image_id']]
            # Convert the COCO-style XYWH box to XYXY.
            bbox = np.array(ann['bbox'], dtype=np.float32)
            bbox[2:4] = bbox[0:2] + bbox[2:4]

            # One sample per referring sentence.
            for sentence in refer['sentences']:
                data_list.append({
                    'img_path': join_path(img_prefix, image['file_name']),
                    'image_id': ann['image_id'],
                    'ann_id': ann['id'],
                    'text': sentence['sent'],
                    'gt_bboxes': bbox[None, :],
                })

        if not data_list:
            raise ValueError(f'No sample in split "{self.split}".')
        return data_list
# Copyright (c) OpenMMLab. All rights reserved.
from .repeat_aug import RepeatAugSampler
from .sequential import SequentialSampler
# Public sampler API re-exported by this package.
__all__ = ['RepeatAugSampler', 'SequentialSampler']
import math
from typing import Iterator, Optional, Sized
import torch
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
from torch.utils.data import Sampler
from mmpretrain.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class RepeatAugSampler(Sampler):
    """Distributed sampler with repeated augmentation.

    Restricts data loading to a subset of the dataset while repeating every
    sample, so that different augmented versions of the same sample are seen
    by different processes (GPUs). Heavily based on
    ``torch.utils.data.DistributedSampler``.

    This sampler was taken from
    https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
    Used in
    Copyright (c) 2015-present, Facebook, Inc.

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
        num_repeats (int): The repeat times of every sample. Defaults to 3.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Defaults to None.
    """

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 num_repeats: int = 3,
                 seed: Optional[int] = None):
        self.rank, self.world_size = get_dist_info()
        self.dataset = dataset
        self.shuffle = shuffle
        if not self.shuffle and is_main_process():
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.warning('The RepeatAugSampler always picks a '
                           'fixed part of data if `shuffle=False`.')

        # All ranks must agree on the seed for consistent shuffling.
        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.num_repeats = num_repeats

        # Number of repeated samples handled by this rank.
        self.num_samples = math.ceil(
            len(self.dataset) * num_repeats / self.world_size)
        # Total number of repeated samples over all ranks.
        self.total_size = self.num_samples * self.world_size
        # Number of samples this rank actually yields each epoch.
        self.num_selected_samples = math.ceil(
            len(self.dataset) / self.world_size)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        num_items = len(self.dataset)
        # deterministically shuffle based on epoch and seed
        if self.shuffle:
            gen = torch.Generator()
            gen.manual_seed(self.seed + self.epoch)
            order = torch.randperm(num_items, generator=gen).tolist()
        else:
            order = list(range(num_items))

        # Repeat every index num_repeats times: [0, 0, 0, 1, 1, 1, 2, ...].
        repeated = []
        for idx in order:
            repeated.extend([idx] * self.num_repeats)
        # Pad with leading indices so the list divides evenly among ranks.
        shortfall = self.total_size - len(repeated)
        repeated += repeated[:shortfall]
        assert len(repeated) == self.total_size

        # Each rank takes an interleaved slice of the repeated list.
        mine = repeated[self.rank:self.total_size:self.world_size]
        assert len(mine) == self.num_samples

        # Yield at most the number of distinct samples per rank.
        return iter(mine[:self.num_selected_samples])

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_selected_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator
import torch
from mmengine.dataset import DefaultSampler
from mmpretrain.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class SequentialSampler(DefaultSampler):
    """Sequential sampler which supports different subsample policy.

    Args:
        dataset (Sized): The dataset.
        round_up (bool): Whether to add extra samples to make the number of
            samples evenly divisible by the world size. Defaults to True.
        subsample_type (str): The method to subsample data on different rank.
            Supported type:

            - ``'default'``: Original torch behavior. Sample the examples one
              by one for each GPU in turn. For instance, 8 examples on 2 GPUs,
              GPU0: [0,2,4,6], GPU1: [1,3,5,7]
            - ``'sequential'``: Subsample all examples to n chunk sequentially.
              For instance, 8 examples on 2 GPUs,
              GPU0: [0,1,2,3], GPU1: [4,5,6,7]
    """

    def __init__(self, subsample_type: str = 'default', **kwargs) -> None:
        super().__init__(shuffle=False, **kwargs)
        if subsample_type not in ['default', 'sequential']:
            # Fixed typo in the message ("typer" -> "type").
            raise ValueError(f'Unsupported subsample type "{subsample_type}",'
                             ' please choose from ["default", "sequential"]')
        self.subsample_type = subsample_type

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        # A plain range is enough; no need to allocate a torch tensor.
        indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible.
        if self.round_up:
            # Integer ceiling division avoids the float-precision issues of
            # the former ``int(total_size / len(indices) + 1)`` for very
            # large datasets, and yields the same result after slicing.
            repeats = -(-self.total_size // len(indices))
            indices = (indices * repeats)[:self.total_size]

        # Subsample according to the configured policy.
        if self.subsample_type == 'default':
            # Interleaved: rank r takes positions r, r + world, ...
            indices = indices[self.rank:self.total_size:self.world_size]
        elif self.subsample_type == 'sequential':
            # Contiguous: rank r takes the r-th chunk.
            chunk = self.total_size // self.world_size
            start = self.rank * chunk
            indices = indices[start:start + chunk]

        return iter(indices)
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List, Sequence
import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class ScienceQA(BaseDataset):
    """ScienceQA dataset.

    This dataset is used to load the multimodal data of ScienceQA dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        split (str): The split of dataset. Options: ``train``, ``val``,
            ``test``, ``trainval``, ``minival``, and ``minitest``.
        split_file (str): The split file of dataset, which contains the
            ids of data samples in the split.
        ann_file (str): Annotation file path.
        image_only (bool): Whether only to load data with image. Defaults to
            False.
        data_prefix (dict): Prefix for data field. Defaults to
            ``dict(img_path='')``.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    # Annotation keys copied verbatim into every data info dict.
    _COPY_KEYS = ('question', 'choices', 'hint', 'task', 'grade', 'subject',
                  'topic', 'category', 'skill', 'lecture', 'solution',
                  'split')

    def __init__(self,
                 data_root: str,
                 split: str,
                 split_file: str,
                 ann_file: str,
                 image_only: bool = False,
                 data_prefix: dict = dict(img_path=''),
                 pipeline: Sequence[Callable] = (),
                 **kwargs):
        valid_splits = ('train', 'val', 'test', 'trainval', 'minival',
                        'minitest')
        assert split in valid_splits, f'Invalid split {split}'
        self.split = split
        self.split_file = os.path.join(data_root, split_file)
        self.image_only = image_only
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            data_prefix=data_prefix,
            pipeline=pipeline,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        img_prefix = self.data_prefix['img_path']
        annotations = mmengine.load(self.ann_file)
        current_data_split = mmengine.load(self.split_file)[self.split]
        file_backend = get_file_backend(img_prefix)

        data_list = []
        for data_id in current_data_split:
            ann = annotations[data_id]
            has_image = ann['image'] is not None
            # Optionally keep only the samples that carry an image.
            if self.image_only and not has_image:
                continue
            img_path = (
                file_backend.join_path(img_prefix, data_id, ann['image'])
                if has_image else None)

            info = {
                'image_id': data_id,
                'gt_answer': ann['answer'],
                'image_name': ann['image'],
                'img_path': img_path,
                'has_image': has_image,
            }
            for key in self._COPY_KEYS:
                info[key] = ann[key]
            data_list.append(info)
        return data_list
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import mat4py
from mmengine import get_file_backend
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import STANFORDCARS_CATEGORIES
@DATASETS.register_module()
class StanfordCars(BaseDataset):
    """The Stanford Cars Dataset.

    Support the `Stanford Cars Dataset <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset.
    The official website provides two ways to organize the dataset.
    Therefore, after downloading and decompression, the dataset directory structure is as follows.

    Stanford Cars dataset directory: ::

        Stanford_Cars
        ├── car_ims
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        └── cars_annos.mat

    or ::

        Stanford_Cars
        ├── cars_train
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        ├── cars_test
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        └── devkit
            ├── cars_meta.mat
            ├── cars_train_annos.mat
            ├── cars_test_annos.mat
            ├── cars_test_annos_withlabels.mat
            ├── eval_train.m
            └── train_perfect_preds.txt

    Args:
        data_root (str): The root directory for Stanford Cars dataset.
        split (str, optional): The dataset split, supports "train"
            and "test". Default to "train".

    Examples:
        >>> from mmpretrain.datasets import StanfordCars
        >>> train_dataset = StanfordCars(data_root='data/Stanford_Cars', split='train')
        >>> train_dataset
        Dataset StanfordCars
            Number of samples:  8144
            Number of categories:       196
            Root of dataset:    data/Stanford_Cars
        >>> test_dataset = StanfordCars(data_root='data/Stanford_Cars', split='test')
        >>> test_dataset
        Dataset StanfordCars
            Number of samples:  8041
            Number of categories:       196
            Root of dataset:    data/Stanford_Cars
    """  # noqa: E501

    METAINFO = {'classes': STANFORDCARS_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'train', **kwargs):
        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but get '{split}'"
        self.split = split
        test_mode = split == 'test'
        self.backend = get_file_backend(data_root, enable_singleton=True)

        anno_file_path = self.backend.join_path(data_root, 'cars_annos.mat')
        if self.backend.exists(anno_file_path):
            # First layout: a single ``cars_annos.mat`` covering both splits.
            ann_file = 'cars_annos.mat'
            data_prefix = ''
        else:
            # Second layout: split-specific annotation files in ``devkit``.
            if test_mode:
                ann_file = self.backend.join_path(
                    'devkit', 'cars_test_annos_withlabels.mat')
                data_prefix = 'cars_test'
            else:
                ann_file = self.backend.join_path('devkit',
                                                  'cars_train_annos.mat')
                data_prefix = 'cars_train'
            if not self.backend.exists(
                    self.backend.join_path(data_root, ann_file)):
                doc_url = 'https://mmpretrain.readthedocs.io/en/latest/api/datasets.html#stanfordcars'  # noqa: E501
                # NOTE: the previous message used a line continuation inside
                # the f-string, which leaked a long run of indentation
                # whitespace into the user-facing error text.
                raise RuntimeError(
                    'The dataset is incorrectly organized, please refer to '
                    f'{doc_url} and reorganize your folders.')

        super().__init__(
            ann_file=ann_file,
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load annotations from the ``.mat`` file of either layout."""
        data = mat4py.loadmat(self.ann_file)['annotations']
        data_list = []
        if 'test' in data.keys():
            # First layout: one file with a per-sample ``test`` flag.
            img_paths, labels, test = data['relative_im_path'], data[
                'class'], data['test']
            num = len(img_paths)
            assert num == len(labels) == len(test), 'get error ann file'
            for i in range(num):
                # Keep only the samples belonging to the requested split
                # (``test[i]`` is 1 for test samples, 0 for train samples).
                if bool(test[i]) != self.test_mode:
                    continue
                img_path = self.backend.join_path(self.img_prefix,
                                                  img_paths[i])
                # Labels in the mat file are 1-based.
                gt_label = labels[i] - 1
                data_list.append(dict(img_path=img_path, gt_label=gt_label))
        else:
            # Second layout: the annotation file only covers one split.
            img_names, labels = data['fname'], data['class']
            num = len(img_names)
            assert num == len(labels), 'get error ann file'
            for i in range(num):
                img_path = self.backend.join_path(self.img_prefix,
                                                  img_names[i])
                gt_label = labels[i] - 1
                data_list.append(dict(img_path=img_path, gt_label=gt_label))
        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Root of dataset: \t{self.data_root}',
        ]
        return body
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine import get_file_backend, list_from_file
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import SUN397_CATEGORIES
@DATASETS.register_module()
class SUN397(BaseDataset):
    """The SUN397 Dataset.

    Support the `SUN397 Dataset <https://vision.princeton.edu/projects/2010/SUN/>`_ Dataset.
    After downloading and decompression, the dataset directory structure is as follows.

    SUN397 dataset directory: ::

        SUN397
        ├── SUN397
        │   ├── a
        │   │   ├── abbey
        │   │   │   ├── sun_aaalbzqrimafwbiv.jpg
        │   │   │   └── ...
        │   │   ├── airplane_cabin
        │   │   │   ├── sun_aadqdkqaslqqoblu.jpg
        │   │   │   └── ...
        │   │   └── ...
        │   ├── b
        │   │   └── ...
        │   ├── c
        │   │   └── ...
        │   └── ...
        └── Partitions
            ├── ClassName.txt
            ├── Training_01.txt
            ├── Testing_01.txt
            └── ...

    Args:
        data_root (str): The root directory for the SUN397 dataset.
        split (str, optional): The dataset split, supports "train" and "test".
            Default to "train".

    Examples:
        >>> from mmpretrain.datasets import SUN397
        >>> train_dataset = SUN397(data_root='data/SUN397', split='train')
        >>> train_dataset
        Dataset SUN397
            Number of samples:  19850
            Number of categories:       397
            Root of dataset:    data/SUN397
        >>> test_dataset = SUN397(data_root='data/SUN397', split='test')
        >>> test_dataset
        Dataset SUN397
            Number of samples:  19850
            Number of categories:       397
            Root of dataset:    data/SUN397

    **Note that some images are not a jpg file although the name ends with ".jpg".
    The backend of SUN397 should be "pillow" as below to read these images properly,**

    .. code-block:: python

        pipeline = [
            dict(type='LoadImageFromFile', imdecode_backend='pillow'),
            dict(type='RandomResizedCrop', scale=224),
            dict(type='PackInputs')
            ]
    """  # noqa: E501

    METAINFO = {'classes': SUN397_CATEGORIES}

    def __init__(self, data_root: str, split: str = 'train', **kwargs):
        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but get '{split}'"
        self.split = split
        self.backend = get_file_backend(data_root, enable_singleton=True)
        partition = 'Training_01.txt' if split == 'train' else 'Testing_01.txt'
        super(SUN397, self).__init__(
            ann_file=self.backend.join_path('Partitions', partition),
            data_root=data_root,
            test_mode=split == 'test',
            data_prefix='SUN397',
            **kwargs)

    def load_data_list(self):
        """Build the sample list from the partition file."""
        data_list = []
        for line in list_from_file(self.ann_file):
            # Lines look like ``/a/abbey/sun_xxx.jpg``; drop the leading
            # slash before joining with the image prefix.
            img_path = self.backend.join_path(self.img_prefix, line[1:])
            # The class name is the directory path without the one-letter
            # bucket, joined by underscores (e.g. ``airplane_cabin``).
            class_name = '_'.join(line.split('/')[2:-1])
            gt_label = self.METAINFO['classes'].index(class_name)
            data_list.append(dict(img_path=img_path, gt_label=gt_label))
        return data_list

    def __getitem__(self, idx: int) -> dict:
        try:
            return super().__getitem__(idx)
        except AttributeError:
            # Non-jpg payloads behind ``.jpg`` names make the default
            # decoding backend fail with an AttributeError downstream.
            raise RuntimeError(
                'Some images in the SUN397 dataset are not a jpg file '
                'although the name ends with ".jpg". The backend of SUN397 '
                'should be "pillow" to read these images properly.')

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        return [f'Root of dataset: \t{self.data_root}']
# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class TextVQA(BaseDataset):
    """TextVQA dataset.

    val image:
        https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
    test image:
        https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
    val json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
    test json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json

    folder structure:

    .. code-block::

        data/textvqa
        ├── annotations
        │   ├── TextVQA_0.5.1_test.json
        │   └── TextVQA_0.5.1_val.json
        └── images
            ├── test_images
            └── train_images

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Each annotation entry provides at least ``question``,
        ``question_id`` and ``image_id``; train/val entries additionally
        carry the list of human ``answers``.
        """
        annotations = mmengine.load(self.ann_file)['data']

        data_list = []
        for ann in annotations:
            data_info = dict(
                question=ann['question'],
                question_id=ann['question_id'],
                image_id=ann['image_id'],
                img_path=mmengine.join_path(self.data_prefix['img_path'],
                                            ann['image_id'] + '.jpg'),
            )

            # Test annotations have no answers; skip the soft labels then.
            if 'answers' in ann:
                # Deduplicate the human answers and weight each unique
                # answer by its relative frequency.
                count = Counter(ann['answers'])
                total = sum(count.values())
                data_info['gt_answer'] = list(count.keys())
                data_info['gt_answer_weight'] = [
                    c / total for c in count.values()
                ]

            data_list.append(data_info)
        return data_list
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms import (CenterCrop, LoadImageFromFile, Normalize,
RandomFlip, RandomGrayscale, RandomResize, Resize)
from mmpretrain.registry import TRANSFORMS
from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform,
Brightness, ColorTransform, Contrast, Cutout,
Equalize, GaussianBlur, Invert, Posterize,
RandAugment, Rotate, Sharpness, Shear, Solarize,
SolarizeAdd, Translate)
from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
PILToNumpy, Transpose)
from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
ColorJitter, EfficientNetCenterCrop,
EfficientNetRandomCrop, Lighting,
MAERandomResizedCrop, RandomCrop, RandomErasing,
RandomResizedCrop,
RandomResizedCropAndInterpolationWithTwoPic,
RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator)
from .utils import get_transform_idx, remove_transform
from .wrappers import ApplyToList, MultiView
# Register the general-purpose mmcv transforms into mmpretrain's own
# TRANSFORMS registry so they can be referenced in config files directly.
for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
          RandomGrayscale, RandomResize, Resize):
    TRANSFORMS.register_module(module=t)
# Public API of the transforms subpackage (local transforms plus the
# re-exported mmcv ones registered above).
__all__ = [
    'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop',
    'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert',
    'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
    'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
    'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
    'PackInputs', 'Albumentations', 'EfficientNetRandomCrop',
    'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform',
    'PackMultiTaskInputs', 'GaussianBlur', 'BEiTMaskGenerator',
    'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize',
    'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
    'ApplyToList', 'CleanCaption', 'RandomTranslatePad',
    'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx',
    'remove_transform', 'MAERandomResizedCrop'
]
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from copy import deepcopy
from math import ceil
from numbers import Number
from typing import List, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform, Compose, RandomChoice
from mmcv.transforms.utils import cache_randomness
from mmengine.utils import is_list_of, is_seq_of
from PIL import Image, ImageFilter
from mmpretrain.registry import TRANSFORMS
def merge_hparams(policy: dict, hparams: dict) -> dict:
    """Merge hyperparameters into a policy config.

    Only the hyperparameters accepted by the policy's ``__init__`` are
    merged, and values already set in the policy are never overridden.

    Args:
        policy (dict): Original policy config dict.
        hparams (dict): Hyperparameters need to be merged.

    Returns:
        dict: A new policy config dict with ``hparams`` merged in.
    """
    merged = deepcopy(policy)
    op = TRANSFORMS.get(merged['type'])
    assert op is not None, f'Invalid policy type "{merged["type"]}".'
    # Only forward hyperparameters the transform's constructor accepts.
    accepted = inspect.getfullargspec(op.__init__).args
    merged.update({
        key: value
        for key, value in hparams.items()
        if key in accepted and key not in merged
    })
    return merged
@TRANSFORMS.register_module()
class AutoAugment(RandomChoice):
    """Auto augmentation.

    This data augmentation is proposed in `AutoAugment: Learning Augmentation
    Policies from Data <https://arxiv.org/abs/1805.09501>`_.

    Args:
        policies (str | list[list[dict]]): The policies of auto augmentation.
            If string, use preset policies collection like "imagenet". If list,
            Each item is a sub policies, composed by several augmentation
            policy dicts. When AutoAugment is called, a random sub policies in
            ``policies`` will be selected to augment images.
        hparams (dict): Configs of hyperparameters. Hyperparameters will be
            used in policies that require these arguments if these arguments
            are not set in policy dicts. Defaults to ``dict(pad_val=128)``.

    .. admonition:: Available preset policies

        - ``"imagenet"``: Policy for ImageNet, come from
          `DeepVoltaire/AutoAugment`_

    .. _DeepVoltaire/AutoAugment: https://github.com/DeepVoltaire/AutoAugment
    """

    def __init__(self,
                 policies: Union[str, List[List[dict]]],
                 hparams: dict = dict(pad_val=128)):
        if isinstance(policies, str):
            assert policies in AUTOAUG_POLICIES, 'Invalid policies, ' \
                f'please choose from {list(AUTOAUG_POLICIES.keys())}.'
            policies = AUTOAUG_POLICIES[policies]
        self.hparams = hparams
        self.policies = [[merge_hparams(t, hparams) for t in sub]
                         for sub in policies]
        # Build the transforms from the *merged* policies so that ``hparams``
        # (e.g. ``pad_val``) actually reaches the constructed transforms.
        # The previous code built from the raw, un-merged ``policies``,
        # which made the merge above repr-only.
        transforms = [[TRANSFORMS.build(t) for t in sub]
                      for sub in self.policies]
        super().__init__(transforms=transforms)

    def __repr__(self) -> str:
        policies_str = ''
        for sub in self.policies:
            policies_str += '\n    ' + ', \t'.join([t['type'] for t in sub])
        repr_str = self.__class__.__name__
        repr_str += f'(policies:{policies_str}\n)'
        return repr_str
@TRANSFORMS.register_module()
class RandAugment(BaseTransform):
    r"""Random augmentation.

    This data augmentation is proposed in `RandAugment: Practical automated
    data augmentation with a reduced search space
    <https://arxiv.org/abs/1909.13719>`_.

    Args:
        policies (str | list[dict]): The policies of random augmentation.
            If string, use preset policies collection like "timm_increasing".
            If list, each item is one specific augmentation policy dict.
            The policy dict should have these keys:

            - ``type`` (str), The type of augmentation.
            - ``magnitude_range`` (Sequence[number], optional): For those
              augmentation have magnitude, you need to specify the magnitude
              level mapping range. For example, assume ``total_level`` is 10,
              ``magnitude_level=3`` specify magnitude is 3 if
              ``magnitude_range=(0, 10)`` while specify magnitude is 7 if
              ``magnitude_range=(10, 0)``.
            - other keyword arguments of the augmentation.
        num_policies (int): Number of policies to select from policies each
            time.
        magnitude_level (int | float): Magnitude level for all the augmentation
            selected.
        magnitude_std (Number | str): Deviation of magnitude noise applied.

            - If positive number, the magnitude obeys normal distribution
              :math:`\mathcal{N}(magnitude_level, magnitude_std)`.
            - If 0 or negative number, magnitude remains unchanged.
            - If str "inf", the magnitude obeys uniform distribution
              :math:`Uniform(min, magnitude)`.
        total_level (int | float): Total level for the magnitude. Defaults to
            10.
        hparams (dict): Configs of hyperparameters. Hyperparameters will be
            used in policies that require these arguments if these arguments
            are not set in policy dicts. Defaults to ``dict(pad_val=128)``.

    .. admonition:: Available preset policies

        - ``"timm_increasing"``: The ``_RAND_INCREASING_TRANSFORMS`` policy
          from `timm`_

    .. _timm: https://github.com/rwightman/pytorch-image-models

    Examples:
        To use "timm-increasing" policies collection, select two policies every
        time, and magnitude_level of every policy is 6 (total is 10 by default)

        >>> import numpy as np
        >>> from mmpretrain.datasets import RandAugment
        >>> transform = RandAugment(
        ...     policies='timm_increasing',
        ...     num_policies=2,
        ...     magnitude_level=6,
        ... )
        >>> data = {'img': np.random.randint(0, 256, (224, 224, 3))}
        >>> results = transform(data)
        >>> print(results['img'].shape)
        (224, 224, 3)

        If you want the ``magnitude_level`` randomly changes every time, you
        can use ``magnitude_std`` to specify the random distribution. For
        example, a normal distribution :math:`\mathcal{N}(6, 0.5)`.

        >>> transform = RandAugment(
        ...     policies='timm_increasing',
        ...     num_policies=2,
        ...     magnitude_level=6,
        ...     magnitude_std=0.5,
        ... )

        You can also use your own policies:

        >>> policies = [
        ...     dict(type='AutoContrast'),
        ...     dict(type='Rotate', magnitude_range=(0, 30)),
        ...     dict(type='ColorTransform', magnitude_range=(0, 0.9)),
        ... ]
        >>> transform = RandAugment(
        ...     policies=policies,
        ...     num_policies=2,
        ...     magnitude_level=6
        ... )

    Note:
        ``magnitude_std`` will introduce some randomness to policy, modified by
        https://github.com/rwightman/pytorch-image-models.

        When magnitude_std=0, we calculate the magnitude as follows:

        .. math::
            \text{magnitude} = \frac{\text{magnitude\_level}}
            {\text{total\_level}} \times (\text{val2} - \text{val1})
            + \text{val1}
    """
    def __init__(self,
                 policies: Union[str, List[dict]],
                 num_policies: int,
                 magnitude_level: int,
                 magnitude_std: Union[Number, str] = 0.,
                 total_level: int = 10,
                 hparams: dict = dict(pad_val=128)):
        # Resolve a preset collection name into the concrete policy list.
        if isinstance(policies, str):
            assert policies in RANDAUG_POLICIES, 'Invalid policies, ' \
                f'please choose from {list(RANDAUG_POLICIES.keys())}.'
            policies = RANDAUG_POLICIES[policies]
        assert is_list_of(policies, dict), 'policies must be a list of dict.'
        assert isinstance(magnitude_std, (Number, str)), \
            '`magnitude_std` must be of number or str type, ' \
            f'got {type(magnitude_std)} instead.'
        if isinstance(magnitude_std, str):
            assert magnitude_std == 'inf', \
                '`magnitude_std` must be of number or "inf", ' \
                f'got "{magnitude_std}" instead.'
        assert num_policies > 0, 'num_policies must be greater than 0.'
        assert magnitude_level >= 0, 'magnitude_level must be no less than 0.'
        assert total_level > 0, 'total_level must be greater than 0.'
        self.num_policies = num_policies
        self.magnitude_level = magnitude_level
        self.magnitude_std = magnitude_std
        self.total_level = total_level
        self.hparams = hparams
        self.policies = []
        self.transforms = []
        # Magnitude settings shared by every policy that declares a
        # ``magnitude_range``; they override any per-policy values.
        randaug_cfg = dict(
            magnitude_level=magnitude_level,
            total_level=total_level,
            magnitude_std=magnitude_std)
        for policy in policies:
            self._check_policy(policy)
            # ``merge_hparams`` returns a copy, so the caller's policy dicts
            # are never mutated by the pop/update below.
            policy = merge_hparams(policy, hparams)
            policy.pop('magnitude_key', None)  # For backward compatibility
            if 'magnitude_range' in policy:
                policy.update(randaug_cfg)
            self.policies.append(policy)
            self.transforms.append(TRANSFORMS.build(policy))
    def __iter__(self):
        """Iterate all transforms."""
        return iter(self.transforms)
    def _check_policy(self, policy):
        """Check whether the sub-policy dict is available."""
        assert isinstance(policy, dict) and 'type' in policy, \
            'Each policy must be a dict with key "type".'
        type_name = policy['type']
        if 'magnitude_range' in policy:
            magnitude_range = policy['magnitude_range']
            assert is_seq_of(magnitude_range, Number), \
                f'`magnitude_range` of RandAugment policy {type_name} ' \
                'should be a sequence with two numbers.'
    @cache_randomness
    def random_policy_indices(self) -> List[int]:
        """Return the random chosen transform indices."""
        indices = np.arange(len(self.policies))
        # Sampling with replacement: the same policy may be picked twice.
        return np.random.choice(indices, size=self.num_policies).tolist()
    def transform(self, results: dict) -> Optional[dict]:
        """Randomly choose a sub-policy to apply."""
        chosen_policies = [
            self.transforms[i] for i in self.random_policy_indices()
        ]
        sub_pipeline = Compose(chosen_policies)
        return sub_pipeline(results)
    def __repr__(self) -> str:
        policies_str = ''
        for policy in self.policies:
            policies_str += '\n    ' + f'{policy["type"]}'
            if 'magnitude_range' in policy:
                val1, val2 = policy['magnitude_range']
                policies_str += f' ({val1}, {val2})'
        repr_str = self.__class__.__name__
        repr_str += f'(num_policies={self.num_policies}, '
        repr_str += f'magnitude_level={self.magnitude_level}, '
        repr_str += f'total_level={self.total_level}, '
        repr_str += f'policies:{policies_str}\n)'
        return repr_str
class BaseAugTransform(BaseTransform):
    r"""The base class of augmentation transform for RandAugment.

    This class provides several common attributions and methods to support the
    magnitude level mapping and magnitude level randomness in
    :class:`RandAugment`.

    Args:
        magnitude_level (int | float): Magnitude level.
        magnitude_range (Sequence[number], optional): For augmentation have
            magnitude argument, maybe "magnitude", "angle" or other, you can
            specify the magnitude level mapping range to generate the magnitude
            argument. For example, assume ``total_level`` is 10,
            ``magnitude_level=3`` specify magnitude is 3 if
            ``magnitude_range=(0, 10)`` while specify magnitude is 7 if
            ``magnitude_range=(10, 0)``. Defaults to None.
        magnitude_std (Number | str): Deviation of magnitude noise applied.

            - If positive number, the magnitude obeys normal distribution
              :math:`\mathcal{N}(magnitude, magnitude_std)`.
            - If 0 or negative number, magnitude remains unchanged.
            - If str "inf", the magnitude obeys uniform distribution
              :math:`Uniform(min, magnitude)`.

            Defaults to 0.
        total_level (int | float): Total level for the magnitude. Defaults to
            10.
        prob (float): The probability for performing transformation therefore
            should be in range [0, 1]. Defaults to 0.5.
        random_negative_prob (float): The probability that turns the magnitude
            negative, which should be in range [0,1]. Defaults to 0.5.
    """
    def __init__(self,
                 magnitude_level: int = 10,
                 magnitude_range: Optional[Tuple[float, float]] = None,
                 magnitude_std: Union[str, float] = 0.,
                 total_level: int = 10,
                 prob: float = 0.5,
                 random_negative_prob: float = 0.5):
        self.magnitude_level = magnitude_level
        self.magnitude_range = magnitude_range
        self.magnitude_std = magnitude_std
        self.total_level = total_level
        self.prob = prob
        self.random_negative_prob = random_negative_prob
    @cache_randomness
    def random_disable(self):
        """Randomly disable the transform."""
        return np.random.rand() > self.prob
    @cache_randomness
    def random_magnitude(self):
        """Randomly generate magnitude.

        NOTE(review): this reads ``self.magnitude_range`` unconditionally, so
        it must only be called when a range was configured — the subclasses
        in this file guard for that in their ``transform``.
        """
        magnitude = self.magnitude_level
        # if magnitude_std is positive number or 'inf', move
        # magnitude_value randomly.
        if self.magnitude_std == 'inf':
            magnitude = np.random.uniform(0, magnitude)
        elif self.magnitude_std > 0:
            magnitude = np.random.normal(magnitude, self.magnitude_std)
            magnitude = np.clip(magnitude, 0, self.total_level)
        # Linearly map the level in [0, total_level] onto the configured
        # magnitude range (the range may be descending, e.g. (10, 0)).
        val1, val2 = self.magnitude_range
        magnitude = (magnitude / self.total_level) * (val2 - val1) + val1
        return magnitude
    @cache_randomness
    def random_negative(self, value):
        """Randomly negative the value."""
        if np.random.rand() < self.random_negative_prob:
            return -value
        else:
            return value
    def extra_repr(self):
        """Extra repr string when auto-generating magnitude is enabled."""
        if self.magnitude_range is not None:
            repr_str = f', magnitude_level={self.magnitude_level}, '
            repr_str += f'magnitude_range={self.magnitude_range}, '
            repr_str += f'magnitude_std={self.magnitude_std}, '
            repr_str += f'total_level={self.total_level}, '
            return repr_str
        else:
            return ''
@TRANSFORMS.register_module()
class Shear(BaseAugTransform):
    """Shear images.

    Args:
        magnitude (int | float | None): The magnitude used for shear. If None,
            generate from ``magnitude_range``, see :class:`BaseAugTransform`.
            Defaults to None.
        pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
            If a sequence of length 3, it is used to pad_val R, G, B channels
            respectively. Defaults to 128.
        prob (float): The probability for performing shear therefore should be
            in range [0, 1]. Defaults to 0.5.
        direction (str): The shearing direction. Options are 'horizontal' and
            'vertical'. Defaults to 'horizontal'.
        random_negative_prob (float): The probability that turns the magnitude
            negative, which should be in range [0,1]. Defaults to 0.5.
        interpolation (str): Interpolation method. Options are 'nearest',
            'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'bicubic'.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 magnitude: Union[int, float, None] = None,
                 pad_val: Union[int, Sequence[int]] = 128,
                 prob: float = 0.5,
                 direction: str = 'horizontal',
                 random_negative_prob: float = 0.5,
                 interpolation: str = 'bicubic',
                 **kwargs):
        super().__init__(
            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
        assert (magnitude is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `magnitude` and `magnitude_range`.'
        self.magnitude = magnitude
        # Normalize sequences to tuples for mmcv's border_value argument.
        self.pad_val = tuple(pad_val) if isinstance(pad_val,
                                                    Sequence) else pad_val
        assert direction in ('horizontal', 'vertical'), 'direction must be ' \
            f'either "horizontal" or "vertical", got "{direction}" instead.'
        self.direction = direction
        self.interpolation = interpolation

    def transform(self, results):
        """Apply transform to results."""
        # Probabilistically skip the augmentation entirely.
        if self.random_disable():
            return results

        if self.magnitude is not None:
            level = self.random_negative(self.magnitude)
        else:
            level = self.random_negative(self.random_magnitude())

        src = results['img']
        sheared = mmcv.imshear(
            src,
            level,
            direction=self.direction,
            border_value=self.pad_val,
            interpolation=self.interpolation)
        # mmcv may promote the dtype; restore the input image's dtype.
        results['img'] = sheared.astype(src.dtype)
        return results

    def __repr__(self):
        detail = (f'(magnitude={self.magnitude}, '
                  f'pad_val={self.pad_val}, '
                  f'prob={self.prob}, '
                  f'direction={self.direction}, '
                  f'random_negative_prob={self.random_negative_prob}, '
                  f'interpolation={self.interpolation}{self.extra_repr()})')
        return self.__class__.__name__ + detail
@TRANSFORMS.register_module()
class Translate(BaseAugTransform):
    """Translate images horizontally or vertically.

    Args:
        magnitude (int | float | None): The magnitude used for translate. The
            actual pixel offset is ``magnitude * size`` along the chosen
            direction, so a magnitude of 1 moves the whole image out of view.
            If None, generate from ``magnitude_range``, see
            :class:`BaseAugTransform`.
        pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
            If a sequence of length 3, it is used to pad_val R, G, B channels
            respectively. Defaults to 128.
        prob (float): The probability for performing translate therefore should
            be in range [0, 1]. Defaults to 0.5.
        direction (str): The translating direction. Options are 'horizontal'
            and 'vertical'. Defaults to 'horizontal'.
        random_negative_prob (float): The probability that turns the magnitude
            negative, which should be in range [0,1]. Defaults to 0.5.
        interpolation (str): Interpolation method. Options are 'nearest',
            'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 magnitude: Union[int, float, None] = None,
                 pad_val: Union[int, Sequence[int]] = 128,
                 prob: float = 0.5,
                 direction: str = 'horizontal',
                 random_negative_prob: float = 0.5,
                 interpolation: str = 'nearest',
                 **kwargs):
        super().__init__(
            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
        # Exactly one of a fixed magnitude or a sampling range is allowed.
        assert (magnitude is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `magnitude` and `magnitude_range`.'
        self.magnitude = magnitude
        self.pad_val = tuple(pad_val) if isinstance(pad_val,
                                                    Sequence) else pad_val
        assert direction in ('horizontal', 'vertical'), 'direction must be ' \
            f'either "horizontal" or "vertical", got "{direction}" instead.'
        self.direction = direction
        self.interpolation = interpolation

    def transform(self, results):
        """Translate ``results['img']`` in place and return ``results``."""
        if self.random_disable():
            return results

        magnitude = self.magnitude if self.magnitude is not None \
            else self.random_magnitude()
        magnitude = self.random_negative(magnitude)

        img = results['img']
        height, width = img.shape[:2]
        # The magnitude is relative; scale it by the image extent in the
        # chosen direction to get a pixel offset.
        size = width if self.direction == 'horizontal' else height
        results['img'] = mmcv.imtranslate(
            img,
            magnitude * size,
            direction=self.direction,
            border_value=self.pad_val,
            interpolation=self.interpolation).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        fields = (f'magnitude={self.magnitude}',
                  f'pad_val={self.pad_val}',
                  f'prob={self.prob}',
                  f'direction={self.direction}',
                  f'random_negative_prob={self.random_negative_prob}',
                  f'interpolation={self.interpolation}')
        return (f"{self.__class__.__name__}({', '.join(fields)}"
                f'{self.extra_repr()})')
@TRANSFORMS.register_module()
class Rotate(BaseAugTransform):
    """Rotate images.

    Args:
        angle (float, optional): The rotation angle; positive values rotate
            clockwise. If None, generate from ``magnitude_range``, see
            :class:`BaseAugTransform`. Defaults to None.
        center (tuple[float], optional): Center point (w, h) of the rotation
            in the source image. If None, the center of the image will be
            used. Defaults to None.
        scale (float): Isotropic scale factor. Defaults to 1.0.
        pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
            If a sequence of length 3, it is used to pad_val R, G, B channels
            respectively. Defaults to 128.
        prob (float): The probability for performing rotate therefore should be
            in range [0, 1]. Defaults to 0.5.
        random_negative_prob (float): The probability that turns the angle
            negative, which should be in range [0,1]. Defaults to 0.5.
        interpolation (str): Interpolation method. Options are 'nearest',
            'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 angle: Optional[float] = None,
                 center: Optional[Tuple[float]] = None,
                 scale: float = 1.0,
                 pad_val: Union[int, Sequence[int]] = 128,
                 prob: float = 0.5,
                 random_negative_prob: float = 0.5,
                 interpolation: str = 'nearest',
                 **kwargs):
        super().__init__(
            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
        # Exactly one of a fixed angle or a sampling range is allowed.
        assert (angle is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `angle` and `magnitude_range`.'
        self.angle = angle
        self.center = center
        self.scale = scale
        self.pad_val = tuple(pad_val) if isinstance(pad_val,
                                                    Sequence) else pad_val
        self.interpolation = interpolation

    def transform(self, results):
        """Rotate ``results['img']`` in place and return ``results``."""
        if self.random_disable():
            return results

        angle = self.angle if self.angle is not None \
            else self.random_magnitude()
        angle = self.random_negative(angle)

        img = results['img']
        results['img'] = mmcv.imrotate(
            img,
            angle,
            center=self.center,
            scale=self.scale,
            border_value=self.pad_val,
            interpolation=self.interpolation).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        fields = (f'angle={self.angle}',
                  f'center={self.center}',
                  f'scale={self.scale}',
                  f'pad_val={self.pad_val}',
                  f'prob={self.prob}',
                  f'random_negative_prob={self.random_negative_prob}',
                  f'interpolation={self.interpolation}')
        return (f"{self.__class__.__name__}({', '.join(fields)}"
                f'{self.extra_repr()})')
@TRANSFORMS.register_module()
class AutoContrast(BaseAugTransform):
    """Auto adjust image contrast.

    Args:
        prob (float): The probability for performing auto contrast
            therefore should be in range [0, 1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self, prob: float = 0.5, **kwargs):
        super().__init__(prob=prob, **kwargs)

    def transform(self, results):
        """Auto-contrast ``results['img']`` in place and return ``results``."""
        if self.random_disable():
            return results

        img = results['img']
        results['img'] = mmcv.auto_contrast(img).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        return f'{self.__class__.__name__}(prob={self.prob})'
@TRANSFORMS.register_module()
class Invert(BaseAugTransform):
    """Invert images.

    Args:
        prob (float): The probability for performing invert therefore should
            be in range [0, 1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self, prob: float = 0.5, **kwargs):
        super().__init__(prob=prob, **kwargs)

    def transform(self, results):
        """Invert ``results['img']`` in place and return ``results``."""
        if self.random_disable():
            return results

        img = results['img']
        results['img'] = mmcv.iminvert(img).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        return f'{self.__class__.__name__}(prob={self.prob})'
@TRANSFORMS.register_module()
class Equalize(BaseAugTransform):
    """Equalize the image histogram.

    Args:
        prob (float): The probability for performing equalize therefore should
            be in range [0, 1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self, prob: float = 0.5, **kwargs):
        super().__init__(prob=prob, **kwargs)

    def transform(self, results):
        """Equalize ``results['img']`` in place and return ``results``."""
        if self.random_disable():
            return results

        img = results['img']
        results['img'] = mmcv.imequalize(img).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        return f'{self.__class__.__name__}(prob={self.prob})'
@TRANSFORMS.register_module()
class Solarize(BaseAugTransform):
    """Solarize images (invert all pixel values above a threshold).

    Args:
        thr (int | float | None): The threshold above which the pixels value
            will be inverted. If None, generate from ``magnitude_range``,
            see :class:`BaseAugTransform`. Defaults to None.
        prob (float): The probability for solarizing therefore should be in
            range [0, 1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 thr: Union[int, float, None] = None,
                 prob: float = 0.5,
                 **kwargs):
        # Solarize has no meaningful signed magnitude; disable negation.
        super().__init__(prob=prob, random_negative_prob=0., **kwargs)
        assert (thr is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `thr` and `magnitude_range`.'
        self.thr = thr

    def transform(self, results):
        """Apply transform to results."""
        if self.random_disable():
            return results

        # Use the fixed threshold when given, otherwise sample one.
        if self.thr is not None:
            thr = self.thr
        else:
            thr = self.random_magnitude()

        img = results['img']
        img_solarized = mmcv.solarize(img, thr=thr)
        results['img'] = img_solarized.astype(img.dtype)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(thr={self.thr}, '
        # Fix: the repr used to emit a stray duplicated ')' at the end.
        repr_str += f'prob={self.prob}{self.extra_repr()})'
        return repr_str
@TRANSFORMS.register_module()
class SolarizeAdd(BaseAugTransform):
    """SolarizeAdd images (add a certain value to pixels below a threshold).

    Args:
        magnitude (int | float | None): The value to be added to pixels below
            the thr. If None, generate from ``magnitude_range``, see
            :class:`BaseAugTransform`. Defaults to None.
        thr (int | float): The threshold below which the pixels value will be
            adjusted. Defaults to 128.
        prob (float): The probability for solarizing therefore should be in
            range [0, 1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 magnitude: Union[int, float, None] = None,
                 thr: Union[int, float] = 128,
                 prob: float = 0.5,
                 **kwargs):
        # The additive magnitude is non-negative; disable negation.
        super().__init__(prob=prob, random_negative_prob=0., **kwargs)
        assert (magnitude is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `magnitude` and `magnitude_range`.'
        self.magnitude = magnitude

        assert isinstance(thr, (int, float)), 'The thr type must '\
            f'be int or float, but got {type(thr)} instead.'
        self.thr = thr

    def transform(self, results):
        """Apply transform to results."""
        if self.random_disable():
            return results

        if self.magnitude is not None:
            magnitude = self.magnitude
        else:
            magnitude = self.random_magnitude()

        img = results['img']
        # Fix: widen integer images before adding so that uint8 pixels cannot
        # wrap around (e.g. 200 + 110 -> 54) *before* the clamp to 255.
        if np.issubdtype(img.dtype, np.integer):
            added = np.minimum(img.astype(np.int32) + magnitude, 255)
        else:
            added = np.minimum(img + magnitude, 255)
        img_solarized = np.where(img < self.thr, added, img)
        results['img'] = img_solarized.astype(img.dtype)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(magnitude={self.magnitude}, '
        repr_str += f'thr={self.thr}, '
        repr_str += f'prob={self.prob}{self.extra_repr()})'
        return repr_str
@TRANSFORMS.register_module()
class Posterize(BaseAugTransform):
    """Posterize images (reduce the number of bits for each color channel).

    Args:
        bits (int, optional): Number of bits for each pixel in the output img,
            which should be less or equal to 8. If None, generate from
            ``magnitude_range``, see :class:`BaseAugTransform`.
            Defaults to None.
        prob (float): The probability for posterizing therefore should be in
            range [0, 1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 bits: Optional[int] = None,
                 prob: float = 0.5,
                 **kwargs):
        # Bit depth has no sign; disable negation.
        super().__init__(prob=prob, random_negative_prob=0., **kwargs)
        assert (bits is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `bits` and `magnitude_range`.'
        if bits is not None:
            # Fix: the message now matches the inclusive `<=` check (and the
            # docstring), instead of claiming strictly "less than 8".
            assert bits <= 8, \
                f'The bits must be less or equal to 8, got {bits} instead.'
        self.bits = bits

    def transform(self, results):
        """Apply transform to results."""
        if self.random_disable():
            return results

        if self.bits is not None:
            bits = self.bits
        else:
            bits = self.random_magnitude()

        # To align timm version, we need to round up to integer here.
        bits = ceil(bits)

        img = results['img']
        img_posterized = mmcv.posterize(img, bits=bits)
        results['img'] = img_posterized.astype(img.dtype)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(bits={self.bits}, '
        repr_str += f'prob={self.prob}{self.extra_repr()})'
        return repr_str
@TRANSFORMS.register_module()
class Contrast(BaseAugTransform):
    """Adjust images contrast.

    Args:
        magnitude (int | float | None): The magnitude used for adjusting
            contrast. A positive magnitude enhances the contrast and a
            negative magnitude makes the image grayer; a magnitude of 0
            keeps the original image. If None, generate from
            ``magnitude_range``, see :class:`BaseAugTransform`.
            Defaults to None.
        prob (float): The probability for performing contrast adjusting
            therefore should be in range [0, 1]. Defaults to 0.5.
        random_negative_prob (float): The probability that turns the magnitude
            negative, which should be in range [0,1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 magnitude: Union[int, float, None] = None,
                 prob: float = 0.5,
                 random_negative_prob: float = 0.5,
                 **kwargs):
        super().__init__(
            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
        assert (magnitude is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `magnitude` and `magnitude_range`.'
        self.magnitude = magnitude

    def transform(self, results):
        """Adjust contrast of ``results['img']`` and return ``results``."""
        if self.random_disable():
            return results

        magnitude = self.magnitude if self.magnitude is not None \
            else self.random_magnitude()
        magnitude = self.random_negative(magnitude)

        img = results['img']
        results['img'] = mmcv.adjust_contrast(
            img, factor=1 + magnitude).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        fields = (f'magnitude={self.magnitude}',
                  f'prob={self.prob}',
                  f'random_negative_prob={self.random_negative_prob}')
        return (f"{self.__class__.__name__}({', '.join(fields)}"
                f'{self.extra_repr()})')
@TRANSFORMS.register_module()
class ColorTransform(BaseAugTransform):
    """Adjust images color balance.

    Args:
        magnitude (int | float | None): The magnitude used for color transform.
            A positive magnitude enhances the color and a negative magnitude
            makes the image grayer; a magnitude of 0 keeps the original image.
            If None, generate from ``magnitude_range``, see
            :class:`BaseAugTransform`. Defaults to None.
        prob (float): The probability for performing ColorTransform therefore
            should be in range [0, 1]. Defaults to 0.5.
        random_negative_prob (float): The probability that turns the magnitude
            negative, which should be in range [0,1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 magnitude: Union[int, float, None] = None,
                 prob: float = 0.5,
                 random_negative_prob: float = 0.5,
                 **kwargs):
        super().__init__(
            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
        assert (magnitude is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `magnitude` and `magnitude_range`.'
        self.magnitude = magnitude

    def transform(self, results):
        """Adjust color of ``results['img']`` and return ``results``."""
        if self.random_disable():
            return results

        magnitude = self.magnitude if self.magnitude is not None \
            else self.random_magnitude()
        magnitude = self.random_negative(magnitude)

        img = results['img']
        results['img'] = mmcv.adjust_color(
            img, alpha=1 + magnitude).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        fields = (f'magnitude={self.magnitude}',
                  f'prob={self.prob}',
                  f'random_negative_prob={self.random_negative_prob}')
        return (f"{self.__class__.__name__}({', '.join(fields)}"
                f'{self.extra_repr()})')
@TRANSFORMS.register_module()
class Brightness(BaseAugTransform):
    """Adjust images brightness.

    Args:
        magnitude (int | float | None): The magnitude used for adjusting
            brightness. A positive magnitude enhances the brightness and a
            negative magnitude makes the image darker; a magnitude of 0 keeps
            the original image. If None, generate from ``magnitude_range``,
            see :class:`BaseAugTransform`. Defaults to None.
        prob (float): The probability for performing brightness adjusting
            therefore should be in range [0, 1]. Defaults to 0.5.
        random_negative_prob (float): The probability that turns the magnitude
            negative, which should be in range [0,1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 magnitude: Union[int, float, None] = None,
                 prob: float = 0.5,
                 random_negative_prob: float = 0.5,
                 **kwargs):
        super().__init__(
            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
        assert (magnitude is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `magnitude` and `magnitude_range`.'
        self.magnitude = magnitude

    def transform(self, results):
        """Adjust brightness of ``results['img']`` and return ``results``."""
        if self.random_disable():
            return results

        magnitude = self.magnitude if self.magnitude is not None \
            else self.random_magnitude()
        magnitude = self.random_negative(magnitude)

        img = results['img']
        results['img'] = mmcv.adjust_brightness(
            img, factor=1 + magnitude).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        fields = (f'magnitude={self.magnitude}',
                  f'prob={self.prob}',
                  f'random_negative_prob={self.random_negative_prob}')
        return (f"{self.__class__.__name__}({', '.join(fields)}"
                f'{self.extra_repr()})')
@TRANSFORMS.register_module()
class Sharpness(BaseAugTransform):
    """Adjust images sharpness.

    Args:
        magnitude (int | float | None): The magnitude used for adjusting
            sharpness. A positive magnitude enhances the sharpness and a
            negative magnitude makes the image blurrier; a magnitude of 0
            keeps the original image. If None, generate from
            ``magnitude_range``, see :class:`BaseAugTransform`.
            Defaults to None.
        prob (float): The probability for performing sharpness adjusting
            therefore should be in range [0, 1]. Defaults to 0.5.
        random_negative_prob (float): The probability that turns the magnitude
            negative, which should be in range [0,1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 magnitude: Union[int, float, None] = None,
                 prob: float = 0.5,
                 random_negative_prob: float = 0.5,
                 **kwargs):
        super().__init__(
            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
        assert (magnitude is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `magnitude` and `magnitude_range`.'
        self.magnitude = magnitude

    def transform(self, results):
        """Adjust sharpness of ``results['img']`` and return ``results``."""
        if self.random_disable():
            return results

        magnitude = self.magnitude if self.magnitude is not None \
            else self.random_magnitude()
        magnitude = self.random_negative(magnitude)

        img = results['img']
        results['img'] = mmcv.adjust_sharpness(
            img, factor=1 + magnitude).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        fields = (f'magnitude={self.magnitude}',
                  f'prob={self.prob}',
                  f'random_negative_prob={self.random_negative_prob}')
        return (f"{self.__class__.__name__}({', '.join(fields)}"
                f'{self.extra_repr()})')
@TRANSFORMS.register_module()
class Cutout(BaseAugTransform):
    """Cutout images.

    Args:
        shape (int | tuple(int) | None): Expected cutout shape (h, w).
            If given as a single value, the value will be used for both h and
            w. If None, generate from ``magnitude_range``, see
            :class:`BaseAugTransform`. Defaults to None.
        pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
            If it is a sequence, it must have the same length with the image
            channels. Defaults to 128.
        prob (float): The probability for performing cutout therefore should
            be in range [0, 1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 shape: Union[int, Tuple[int], None] = None,
                 pad_val: Union[int, Sequence[int]] = 128,
                 prob: float = 0.5,
                 **kwargs):
        # The cutout size has no sign; disable negation.
        super().__init__(prob=prob, random_negative_prob=0., **kwargs)
        assert (shape is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `shape` and `magnitude_range`.'
        self.shape = shape
        self.pad_val = tuple(pad_val) if isinstance(pad_val,
                                                    Sequence) else pad_val

    def transform(self, results):
        """Cut out a region of ``results['img']`` and return ``results``."""
        if self.random_disable():
            return results

        # A fixed shape wins; otherwise sample an edge length from the range.
        cut_shape = self.shape if self.shape is not None \
            else int(self.random_magnitude())

        img = results['img']
        results['img'] = mmcv.cutout(
            img, cut_shape, pad_val=self.pad_val).astype(img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        fields = (f'shape={self.shape}',
                  f'pad_val={self.pad_val}',
                  f'prob={self.prob}')
        return (f"{self.__class__.__name__}({', '.join(fields)}"
                f'{self.extra_repr()})')
@TRANSFORMS.register_module()
class GaussianBlur(BaseAugTransform):
    """Gaussian blur images.

    Args:
        radius (int, float, optional): The blur radius. If None, generate from
            ``magnitude_range``, see :class:`BaseAugTransform`.
            Defaults to None.
        prob (float): The probability for blurring therefore should be in
            range [0, 1]. Defaults to 0.5.
        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
    """

    def __init__(self,
                 radius: Union[int, float, None] = None,
                 prob: float = 0.5,
                 **kwargs):
        # The blur radius has no sign; disable negation.
        super().__init__(prob=prob, random_negative_prob=0., **kwargs)
        assert (radius is None) ^ (self.magnitude_range is None), \
            'Please specify only one of `radius` and `magnitude_range`.'
        self.radius = radius

    def transform(self, results):
        """Blur ``results['img']`` and return ``results``."""
        if self.random_disable():
            return results

        blur_radius = self.radius if self.radius is not None \
            else self.random_magnitude()

        img = results['img']
        # PIL implements the Gaussian filter; round-trip through PIL keeping
        # the original numpy dtype.
        blurred = Image.fromarray(img).filter(
            ImageFilter.GaussianBlur(radius=blur_radius))
        results['img'] = np.array(blurred, dtype=img.dtype)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        return (f'{self.__class__.__name__}(radius={self.radius}, '
                f'prob={self.prob}{self.extra_repr()})')
# yapf: disable
# flake8: noqa
# Predefined AutoAugment policies. Each entry of a policy is a pair of
# transform configs; the two transforms of a chosen pair are applied in
# sequence.
AUTOAUG_POLICIES = {
    # Policy for ImageNet, refers to
    # https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py
    'imagenet': [
        [dict(type='Posterize', bits=4, prob=0.4), dict(type='Rotate', angle=30., prob=0.6)],
        [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)],
        [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)],
        [dict(type='Posterize', bits=5, prob=0.6), dict(type='Posterize', bits=5, prob=0.6)],
        [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)],
        [dict(type='Equalize', prob=0.4), dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)],
        [dict(type='Solarize', thr=256 / 9 * 6, prob=0.6), dict(type='Equalize', prob=0.6)],
        [dict(type='Posterize', bits=6, prob=0.8), dict(type='Equalize', prob=1.)],
        [dict(type='Rotate', angle=10., prob=0.2), dict(type='Solarize', thr=256 / 9, prob=0.6)],
        [dict(type='Equalize', prob=0.6), dict(type='Posterize', bits=5, prob=0.4)],
        [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0., prob=0.4)],
        [dict(type='Rotate', angle=30., prob=0.4), dict(type='Equalize', prob=0.6)],
        [dict(type='Equalize', prob=0.0), dict(type='Equalize', prob=0.8)],
        [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)],
        [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)],
        [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0.2, prob=1.)],
        [dict(type='ColorTransform', magnitude=0.8, prob=0.8), dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)],
        [dict(type='Sharpness', magnitude=0.7, prob=0.4), dict(type='Invert', prob=0.6)],
        [dict(type='Shear', magnitude=0.3 / 9 * 5, prob=0.6, direction='horizontal'), dict(type='Equalize', prob=1.)],
        [dict(type='ColorTransform', magnitude=0., prob=0.4), dict(type='Equalize', prob=0.6)],
        [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)],
        [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)],
        [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)],
        [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)],
        [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)],
    ],
}
# Predefined RandAugment transform pools. Each entry is a single transform
# config; entries with a `magnitude_range` are sampled by magnitude level.
RANDAUG_POLICIES = {
    # Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models
    'timm_increasing': [
        dict(type='AutoContrast'),
        dict(type='Equalize'),
        dict(type='Invert'),
        dict(type='Rotate', magnitude_range=(0, 30)),
        dict(type='Posterize', magnitude_range=(4, 0)),
        dict(type='Solarize', magnitude_range=(256, 0)),
        dict(type='SolarizeAdd', magnitude_range=(0, 110)),
        dict(type='ColorTransform', magnitude_range=(0, 0.9)),
        dict(type='Contrast', magnitude_range=(0, 0.9)),
        dict(type='Brightness', magnitude_range=(0, 0.9)),
        dict(type='Sharpness', magnitude_range=(0, 0.9)),
        dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'),
        dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'),
        dict(type='Translate', magnitude_range=(0, 0.45), direction='horizontal'),
        dict(type='Translate', magnitude_range=(0, 0.45), direction='vertical'),
    ],
    # A reduced pool with only geometric / histogram transforms.
    'simple_increasing': [
        dict(type='AutoContrast'),
        dict(type='Equalize'),
        dict(type='Rotate', magnitude_range=(0, 30)),
        dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'),
        dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'),
    ],
}
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from collections.abc import Sequence
import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from mmcv.transforms import BaseTransform
from mmengine.utils import is_str
from PIL import Image
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample, MultiTaskDataSample
def to_tensor(data):
    """Convert ``data`` to a :obj:`torch.Tensor`.

    Supported types are: :class:`torch.Tensor` (returned unchanged),
    :class:`numpy.ndarray`, non-string :class:`Sequence`, :class:`int`
    and :class:`float`.
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    if isinstance(data, Sequence) and not is_str(data):
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(
        f'Type {type(data)} cannot be converted to tensor.'
        'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
        '`Sequence`, `int` and `float`')
@TRANSFORMS.register_module()
class PackInputs(BaseTransform):
    """Pack the inputs data.

    **Required Keys:**

    - ``input_key``
    - ``*algorithm_keys``
    - ``*meta_keys``

    **Deleted Keys:**

    All other keys in the dict.

    **Added Keys:**

    - inputs (:obj:`torch.Tensor`): The forward data of models.
    - data_samples (:obj:`~mmpretrain.structures.DataSample`): The
      annotation info of the sample.

    Args:
        input_key (str): The key of element to feed into the model forwarding.
            Defaults to 'img'.
        algorithm_keys (Sequence[str]): The keys of custom elements to be used
            in the algorithm. Defaults to an empty tuple.
        meta_keys (Sequence[str]): The keys of meta information to be saved in
            the data sample. Defaults to :attr:`PackInputs.DEFAULT_META_KEYS`.

    .. admonition:: Default algorithm keys

        Besides the specified ``algorithm_keys``, we will set some default keys
        into the output data sample and do some formatting. Therefore, you
        don't need to set these keys in the ``algorithm_keys``.

        - ``gt_label``: The ground-truth label. The value will be converted
          into a 1-D tensor.
        - ``gt_score``: The ground-truth score. The value will be converted
          into a 1-D tensor.
        - ``mask``: The mask for some self-supervise tasks. The value will
          be converted into a tensor.

    .. admonition:: Default meta keys

        - ``sample_idx``: The id of the image sample.
        - ``img_path``: The path to the image file.
        - ``ori_shape``: The original shape of the image as a tuple (H, W).
        - ``img_shape``: The shape of the image after the pipeline as a
          tuple (H, W).
        - ``scale_factor``: The scale factor between the resized image and
          the original image.
        - ``flip``: A boolean indicating if image flip transform was used.
        - ``flip_direction``: The flipping direction.
    """

    # Meta info keys copied into the data sample when present in `results`.
    DEFAULT_META_KEYS = ('sample_idx', 'img_path', 'ori_shape', 'img_shape',
                         'scale_factor', 'flip', 'flip_direction')

    def __init__(self,
                 input_key='img',
                 algorithm_keys=(),
                 meta_keys=DEFAULT_META_KEYS):
        self.input_key = input_key
        self.algorithm_keys = algorithm_keys
        self.meta_keys = meta_keys

    @staticmethod
    def format_input(input_):
        # Recursively convert one input (or a nested list of inputs) to
        # channel-first tensors. Branch order matters: list -> ndarray ->
        # PIL image -> already-a-tensor -> error.
        if isinstance(input_, list):
            return [PackInputs.format_input(item) for item in input_]
        elif isinstance(input_, np.ndarray):
            if input_.ndim == 2:  # For grayscale image.
                input_ = np.expand_dims(input_, -1)
            if input_.ndim == 3 and not input_.flags.c_contiguous:
                # Non-contiguous arrays are transposed in numpy first so the
                # resulting tensor is contiguous without an extra copy.
                input_ = np.ascontiguousarray(input_.transpose(2, 0, 1))
                input_ = to_tensor(input_)
            elif input_.ndim == 3:
                # convert to tensor first to accelerate, see
                # https://github.com/open-mmlab/mmdetection/pull/9533
                input_ = to_tensor(input_).permute(2, 0, 1).contiguous()
            else:
                # convert input with other shape to tensor without permute,
                # like video input (num_crops, C, T, H, W).
                input_ = to_tensor(input_)
        elif isinstance(input_, Image.Image):
            input_ = F.pil_to_tensor(input_)
        elif not isinstance(input_, torch.Tensor):
            raise TypeError(f'Unsupported input type {type(input_)}.')
        return input_

    def transform(self, results: dict) -> dict:
        """Method to pack the input data."""

        packed_results = dict()
        if self.input_key in results:
            input_ = results[self.input_key]
            packed_results['inputs'] = self.format_input(input_)

        data_sample = DataSample()

        # Set default keys
        if 'gt_label' in results:
            data_sample.set_gt_label(results['gt_label'])
        if 'gt_score' in results:
            data_sample.set_gt_score(results['gt_score'])
        if 'mask' in results:
            data_sample.set_mask(results['mask'])

        # Set custom algorithm keys
        for key in self.algorithm_keys:
            if key in results:
                data_sample.set_field(results[key], key)

        # Set meta keys
        for key in self.meta_keys:
            if key in results:
                data_sample.set_field(results[key], key, field_type='metainfo')

        packed_results['data_samples'] = data_sample
        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f"(input_key='{self.input_key}', "
        repr_str += f'algorithm_keys={self.algorithm_keys}, '
        repr_str += f'meta_keys={self.meta_keys})'
        return repr_str
@TRANSFORMS.register_module()
class PackMultiTaskInputs(BaseTransform):
    """Convert all image labels of multi-task dataset to a dict of tensor.

    Args:
        multi_task_fields (Sequence[str]): The fields (e.g. ``gt_label``)
            whose values are dicts mapping task names to per-task values;
            they are split out and packed separately for every task.
        input_key (str): The key of element to feed into the model
            forwarding. Defaults to 'img'.
        task_handlers (dict, optional): Mapping from task name to the pack
            transform config used for that task. Tasks without an explicit
            entry fall back to a default :class:`PackInputs`.
            Defaults to None.
    """

    def __init__(self,
                 multi_task_fields,
                 input_key='img',
                 task_handlers=None):
        self.multi_task_fields = multi_task_fields
        self.input_key = input_key
        # Unregistered tasks fall back to a plain ``PackInputs`` handler.
        self.task_handlers = defaultdict(PackInputs)
        # ``None`` replaces the previous mutable ``dict()`` default argument
        # (shared-instance pitfall); behavior is unchanged for callers.
        for task_name, task_handler in (task_handlers or {}).items():
            self.task_handlers[task_name] = TRANSFORMS.build(task_handler)

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        result = {'img_path': 'a.png', 'gt_label': {'task1': 1, 'task3': 3},
            'img': array([[[ 0, 0, 0])
        """
        packed_results = dict()
        results = results.copy()

        # Pack the shared input (usually the image) once for all tasks.
        if self.input_key in results:
            input_ = results[self.input_key]
            packed_results['inputs'] = PackInputs.format_input(input_)

        # Regroup {field: {task: value}} into {task: {field: value}}.
        task_results = defaultdict(dict)
        for field in self.multi_task_fields:
            if field in results:
                value = results.pop(field)
                for k, v in value.items():
                    task_results[k].update({field: v})

        # Delegate per-task packing to each task's handler and collect the
        # resulting data samples under the task name.
        data_sample = MultiTaskDataSample()
        for task_name, task_result in task_results.items():
            task_handler = self.task_handlers[task_name]
            task_pack_result = task_handler({**results, **task_result})
            data_sample.set_field(task_pack_result['data_samples'], task_name)

        packed_results['data_samples'] = data_sample
        return packed_results

    def __repr__(self):
        # ``repr_str`` instead of ``repr`` to avoid shadowing the builtin.
        repr_str = self.__class__.__name__
        task_handlers = ', '.join(
            f"'{name}': {handler.__class__.__name__}"
            for name, handler in self.task_handlers.items())
        repr_str += f'(multi_task_fields={self.multi_task_fields}, '
        repr_str += f"input_key='{self.input_key}', "
        repr_str += f'task_handlers={{{task_handlers}}})'
        return repr_str
@TRANSFORMS.register_module()
class Transpose(BaseTransform):
    """Transpose numpy arrays stored under the given result keys.

    **Required Keys:**

    - ``*keys``

    **Modified Keys:**

    - ``*keys``

    Args:
        keys (List[str]): The fields to transpose.
        order (List[int]): The output dimensions order.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def transform(self, results):
        """Transpose each configured field in place."""
        for field in self.keys:
            results[field] = results[field].transpose(self.order)
        return results

    def __repr__(self):
        """Return a readable representation of the transform settings."""
        return (f'{self.__class__.__name__}'
                f'(keys={self.keys}, order={self.order})')
@TRANSFORMS.register_module(('NumpyToPIL', 'ToPIL'))
class NumpyToPIL(BaseTransform):
    """Convert the image from OpenCV format to :obj:`PIL.Image.Image`.

    **Required Keys:**

    - ``img``

    **Modified Keys:**

    - ``img``

    Args:
        to_rgb (bool): Whether to convert the image from BGR to RGB before
            building the PIL image. Defaults to False.
    """

    def __init__(self, to_rgb: bool = False) -> None:
        self.to_rgb = to_rgb

    def transform(self, results: dict) -> dict:
        """Method to convert images to :obj:`PIL.Image.Image`."""
        img = results['img']
        # OpenCV arrays are BGR; optionally reorder channels for RGB output.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img

        results['img'] = Image.fromarray(img)
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(to_rgb={self.to_rgb})'
@TRANSFORMS.register_module(('PILToNumpy', 'ToNumpy'))
class PILToNumpy(BaseTransform):
    """Convert img to :obj:`numpy.ndarray`.

    **Required Keys:**

    - ``img``

    **Modified Keys:**

    - ``img``

    Args:
        to_bgr (bool): Whether to convert the image from RGB to BGR after
            converting to a numpy array. Defaults to False.
        dtype (str, optional): The dtype of the converted numpy array.
            Defaults to None.
    """

    def __init__(self, to_bgr: bool = False, dtype=None) -> None:
        self.to_bgr = to_bgr
        self.dtype = dtype

    def transform(self, results: dict) -> dict:
        """Method to convert img to :obj:`numpy.ndarray`."""
        img = np.array(results['img'], dtype=self.dtype)
        # PIL images are RGB; optionally reorder channels for OpenCV (BGR).
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img

        results['img'] = img
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + \
            f'(to_bgr={self.to_bgr}, dtype={self.dtype})'
@TRANSFORMS.register_module()
class Collect(BaseTransform):
    """Collect and only reserve the specified fields.

    **Required Keys:**

    - ``*keys``

    **Deleted Keys:**

    All keys except those in the argument ``*keys``.

    Args:
        keys (Sequence[str]): The keys of the fields to be collected.
    """

    def __init__(self, keys):
        self.keys = keys

    def transform(self, results):
        """Return a new dict containing only the collected fields."""
        return {key: results[key] for key in self.keys}

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment