Commit 1ac2e802 authored by limm's avatar limm
Browse files

add tools code

parent b6df0d33
Pipeline #2803 canceled with stages
#!/usr/bin/env bash
# Submit a distributed test job via Slurm.
# Usage:
#   [GPUS=8] [GPUS_PER_NODE=8] [CPUS_PER_TASK=5] [SRUN_ARGS=...] \
#     ./slurm_test.sh PARTITION JOB_NAME CONFIG CHECKPOINT [PY_ARGS...]
set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
# Keep the extra python args as an array so quoted arguments survive intact
# (a scalar `PY_ARGS=${@:5}` would lose argument boundaries — SC2124).
PY_ARGS=("${@:5}")
SRUN_ARGS=${SRUN_ARGS:-""}

# All expansions are quoted (SC2086) except SRUN_ARGS, which is intentionally
# left unquoted so it word-splits into multiple srun options.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p "${PARTITION}" \
  --job-name="${JOB_NAME}" \
  --gres="gpu:${GPUS_PER_NODE}" \
  --ntasks="${GPUS}" \
  --ntasks-per-node="${GPUS_PER_NODE}" \
  --cpus-per-task="${CPUS_PER_TASK}" \
  --kill-on-bad-exit=1 \
  ${SRUN_ARGS} \
  python -u tools/test.py "${CONFIG}" "${CHECKPOINT}" --launcher="slurm" "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Submit a distributed training job via Slurm.
# Usage:
#   [GPUS=8] [GPUS_PER_NODE=8] [CPUS_PER_TASK=5] [SRUN_ARGS=...] \
#     ./slurm_train.sh PARTITION JOB_NAME CONFIG WORK_DIR [PY_ARGS...]
set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Keep the extra python args as an array so quoted arguments survive intact
# (a scalar `PY_ARGS=${@:5}` would lose argument boundaries — SC2124).
PY_ARGS=("${@:5}")

# All expansions are quoted (SC2086) except SRUN_ARGS, which is intentionally
# left unquoted so it word-splits into multiple srun options.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p "${PARTITION}" \
  --job-name="${JOB_NAME}" \
  --gres="gpu:${GPUS_PER_NODE}" \
  --ntasks="${GPUS}" \
  --ntasks-per-node="${GPUS_PER_NODE}" \
  --cpus-per-task="${CPUS_PER_TASK}" \
  --kill-on-bad-exit=1 \
  ${SRUN_ARGS} \
  python -u tools/train.py "${CONFIG}" --work-dir="${WORK_DIR}" --launcher="slurm" "${PY_ARGS[@]}"
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
from copy import deepcopy
import mmengine
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.evaluator import DumpResults
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
def parse_args():
    """Parse command-line arguments for the test script.

    Also propagates ``--local_rank`` into the ``LOCAL_RANK`` environment
    variable so that launchers passing the rank on the command line still
    work with mmengine's env-var based distributed setup.
    """
    parser = argparse.ArgumentParser(
        description='MMPreTrain test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument('--out', help='the file to output results.')
    parser.add_argument(
        '--out-item',
        choices=['metrics', 'pred'],
        help='To output whether metrics or predictions. '
        'Defaults to output predictions.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--amp',
        action='store_true',
        help='enable automatic-mixed-precision test')
    parser.add_argument(
        '--show-dir',
        help='directory where the visualization images will be saved.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether to display the prediction results in a window.')
    parser.add_argument(
        '--interval',
        type=int,
        default=1,
        help='visualize per interval samples.')
    parser.add_argument(
        '--wait-time',
        type=float,
        default=2,
        help='display time of every window. (second)')
    parser.add_argument(
        '--no-pin-memory',
        action='store_true',
        help='whether to disable the pin_memory option in dataloaders.')
    parser.add_argument(
        '--tta',
        action='store_true',
        help='Whether to enable the Test-Time-Aug (TTA). If the config file '
        'has `tta_pipeline` and `tta_model` fields, use them to determine the '
        'TTA transforms and how to merge the TTA results. Otherwise, use flip '
        'TTA by averaging classification score.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
    # will pass the `--local-rank` parameter to `tools/train.py` instead
    # of `--local_rank`.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    # Expose the rank to downstream code that reads the environment.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def merge_args(cfg, args):
    """Merge CLI arguments to config.

    Mutates ``cfg`` in place (and also returns it): launcher, work_dir,
    checkpoint, AMP, visualization-hook options, optional TTA wrapping and
    default dataloader settings. ``--cfg-options`` is merged last so it can
    override anything set here.
    """
    cfg.launcher = args.launcher

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    cfg.load_from = args.checkpoint

    # enable automatic-mixed-precision test
    if args.amp:
        cfg.test_cfg.fp16 = True

    # -------------------- visualization --------------------
    if args.show or (args.show_dir is not None):
        assert 'visualization' in cfg.default_hooks, \
            'VisualizationHook is not set in the `default_hooks` field of ' \
            'config. Please set `visualization=dict(type="VisualizationHook")`'
        cfg.default_hooks.visualization.enable = True
        cfg.default_hooks.visualization.show = args.show
        cfg.default_hooks.visualization.wait_time = args.wait_time
        cfg.default_hooks.visualization.out_dir = args.show_dir
        cfg.default_hooks.visualization.interval = args.interval

    # -------------------- TTA related args --------------------
    if args.tta:
        if 'tta_model' not in cfg:
            cfg.tta_model = dict(type='mmpretrain.AverageClsScoreTTA')
        if 'tta_pipeline' not in cfg:
            test_pipeline = cfg.test_dataloader.dataset.pipeline
            cfg.tta_pipeline = deepcopy(test_pipeline)
            # Fall back to flip TTA: replace the last transform (the packing
            # step) with a TestTimeAug wrapping flip/no-flip + packing.
            flip_tta = dict(
                type='TestTimeAug',
                transforms=[
                    [
                        dict(type='RandomFlip', prob=1.),
                        dict(type='RandomFlip', prob=0.)
                    ],
                    [test_pipeline[-1]],
                ])
            cfg.tta_pipeline[-1] = flip_tta
        # Wrap the model with the TTA model and swap in the TTA pipeline.
        cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
        cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline

    # ----------------- Default dataloader args -----------------
    default_dataloader_cfg = ConfigDict(
        pin_memory=True,
        collate_fn=dict(type='default_collate'),
    )

    def set_default_dataloader_cfg(cfg, field):
        # Apply the defaults above, letting user-provided keys take
        # precedence; `--no-pin-memory` always wins over both.
        if cfg.get(field, None) is None:
            return
        dataloader_cfg = deepcopy(default_dataloader_cfg)
        dataloader_cfg.update(cfg[field])
        cfg[field] = dataloader_cfg
        if args.no_pin_memory:
            cfg[field]['pin_memory'] = False

    set_default_dataloader_cfg(cfg, 'test_dataloader')

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    return cfg
def main():
    """Entry point: build a runner from config + CLI args and run testing."""
    args = parse_args()

    # `--out-item` only makes sense together with `--out`.
    if args.out is None and args.out_item is not None:
        raise ValueError('Please use `--out` argument to specify the '
                         'path of the output file before using `--out-item`.')

    # Load the config file and fold the CLI options into it.
    cfg = merge_args(Config.fromfile(args.config), args)

    # A custom `runner_type` selects a runner class from the registry;
    # otherwise fall back to the default mmengine Runner.
    if 'runner_type' in cfg:
        runner = RUNNERS.build(cfg)
    else:
        runner = Runner.from_cfg(cfg)

    # Dump raw predictions when requested (the default for `--out`).
    if args.out and args.out_item in ['pred', None]:
        runner.test_evaluator.metrics.append(
            DumpResults(out_file_path=args.out))

    metrics = runner.test()

    # Dump evaluation metrics when requested.
    if args.out and args.out_item == 'metrics':
        mmengine.dump(metrics, args.out)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
import mmengine
# Fail fast with install instructions if torch-model-archiver is missing.
try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    # NOTE: the original message contained a stray trailing `"` inside the
    # backticked pip command; it has been removed.
    raise ImportError(
        'Please run `pip install torchserve torch-model-archiver` to '
        'install required third-party libraries.')
def mmpretrain2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: str,
    model_version: str = '1.0',
    force: bool = False,
):
    """Converts mmpretrain model (config + checkpoint) to TorchServe `.mar`.

    Args:
        config_file:
            In MMPretrain config format.
            The contents vary for each task repository.
        checkpoint_file:
            In MMPretrain checkpoint format.
            The contents vary for each task repository.
        output_folder:
            Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name:
            If not None, used for naming the `{model_name}.mar` file
            that will be created under `output_folder`.
            If None, `{Path(checkpoint_file).stem}` will be used.
        model_version:
            Model's version.
        force:
            If True, if there is an existing `{model_name}.mar`
            file under `output_folder` it will be overwritten.
    """
    mmengine.mkdir_or_exist(output_folder)

    config = mmengine.Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        # Dump the (possibly inherited/merged) config into a standalone file
        # so the archive is self-contained.
        config.dump(f'{tmpdir}/config.py')
        # Mimic the CLI namespace expected by torch-model-archiver's
        # packaging utilities; the handler script ships with this tool.
        args = Namespace(
            **{
                'model_file': f'{tmpdir}/config.py',
                'serialized_file': checkpoint_file,
                'handler': f'{Path(__file__).parent}/mmpretrain_handler.py',
                'model_name': model_name or Path(checkpoint_file).stem,
                'version': model_version,
                'export_path': output_folder,
                'force': force,
                'requirements_file': None,
                'extra_files': None,
                'runtime': 'python',
                'archive_format': 'default'
            })
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)
def parse_args():
    """Build and evaluate the CLI for the `.mar` conversion tool."""
    arg_parser = ArgumentParser(
        description='Convert mmpretrain models to TorchServe `.mar` format.')
    arg_parser.add_argument('config', type=str, help='config file path')
    arg_parser.add_argument(
        'checkpoint', type=str, help='checkpoint file path')
    arg_parser.add_argument(
        '--output-folder',
        required=True,
        type=str,
        help='Folder where `{model_name}.mar` will be created.')
    arg_parser.add_argument(
        '--model-name',
        default=None,
        type=str,
        help='If not None, used for naming the `{model_name}.mar`'
        'file that will be created under `output_folder`.'
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    arg_parser.add_argument(
        '--model-version',
        default='1.0',
        type=str,
        help='Number used for versioning.')
    arg_parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    return arg_parser.parse_args()
if __name__ == '__main__':
    args = parse_args()
    # No availability check is needed here: the module-level import of
    # `package_model` already raised ImportError with install instructions
    # when torch-model-archiver is missing, so it can never be None.
    mmpretrain2torchserve(args.config, args.checkpoint, args.output_folder,
                          args.model_name, args.model_version, args.force)
# Copyright (c) OpenMMLab. All rights reserved.
import base64
import os
import mmcv
import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler
import mmpretrain.models
from mmpretrain.apis import (ImageClassificationInferencer,
ImageRetrievalInferencer, get_model)
class MMPreHandler(BaseHandler):
    """TorchServe handler that routes requests to an mmpretrain inferencer."""

    def initialize(self, context):
        """Load the model from the TorchServe model archive.

        Picks GPU ``gpu_id`` when CUDA is available (otherwise CPU) and
        selects the inferencer that matches the loaded model type.
        """
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Conditional expression: 'cuda:<gpu_id>' when CUDA is available,
        # plain 'cpu' otherwise.
        self.device = torch.device(self.map_location + ':' +
                                   str(properties.get('gpu_id')) if torch.cuda.
                                   is_available() else self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        # `config.py` is bundled into the archive by the converter tool.
        self.config_file = os.path.join(model_dir, 'config.py')

        model = get_model(self.config_file, checkpoint, self.device)
        if isinstance(model, mmpretrain.models.ImageClassifier):
            self.inferencer = ImageClassificationInferencer(model)
        elif isinstance(model, mmpretrain.models.ImageToImageRetriever):
            self.inferencer = ImageRetrievalInferencer(model)
        else:
            raise NotImplementedError(
                f'No available inferencer for {type(model)}')
        self.initialized = True

    def preprocess(self, data):
        """Decode each request payload (raw bytes or base64 string) into a
        numpy image via mmcv."""
        images = []

        for row in data:
            image = row.get('data') or row.get('body')
            if isinstance(image, str):
                image = base64.b64decode(image)
            image = mmcv.imfrombytes(image)
            images.append(image)

        return images

    def inference(self, data, *args, **kwargs):
        """Run the inferencer on every decoded image, one at a time."""
        results = []
        for image in data:
            results.append(self.inferencer(image)[0])
        return results

    def postprocess(self, data):
        """Convert tensors/arrays in the results to JSON-serializable
        lists; other values pass through unchanged."""
        processed_data = []
        for result in data:
            processed_result = {}
            for k, v in result.items():
                if isinstance(v, (torch.Tensor, np.ndarray)):
                    processed_result[k] = v.tolist()
                else:
                    processed_result[k] = v
            processed_data.append(processed_result)
        return processed_data
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import numpy as np
import requests
from mmpretrain.apis import get_model, inference_model
def parse_args():
    """CLI for comparing local inference with a TorchServe deployment."""
    cli = ArgumentParser()
    cli.add_argument('img', help='Image file')
    cli.add_argument('config', help='Config file')
    cli.add_argument('checkpoint', help='Checkpoint file')
    cli.add_argument('model_name', help='The model name in the server')
    cli.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    return cli.parse_args()
def main(args):
    """Check that PyTorch and TorchServe produce matching predictions."""
    # Native inference through the mmpretrain APIs.
    model = get_model(args.config, args.checkpoint, device=args.device)
    model_result = inference_model(model, args.img)

    # The same image through the TorchServe REST endpoint.
    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
    with open(args.img, 'rb') as image:
        response = requests.post(url, image)
    server_result = response.json()

    assert np.allclose(model_result['pred_score'], server_result['pred_score'])
    print('Test complete, the results of PyTorch and TorchServe are the same.')


if __name__ == '__main__':
    main(parse_args())
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
from copy import deepcopy
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
def parse_args():
    """Parse command-line arguments for the training script.

    Also propagates ``--local_rank`` into the ``LOCAL_RANK`` environment
    variable so that launchers passing the rank on the command line still
    work with mmengine's env-var based distributed setup.
    """
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume',
        nargs='?',
        type=str,
        const='auto',
        help='If specify checkpoint path, resume from it, while if not '
        'specify, try to auto resume from the latest checkpoint '
        'in the work directory.')
    parser.add_argument(
        '--amp',
        action='store_true',
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    parser.add_argument(
        '--auto-scale-lr',
        action='store_true',
        help='whether to auto scale the learning rate according to the '
        'actual batch size and the original batch size.')
    parser.add_argument(
        '--no-pin-memory',
        action='store_true',
        help='whether to disable the pin_memory option in dataloaders.')
    parser.add_argument(
        '--no-persistent-workers',
        action='store_true',
        help='whether to disable the persistent_workers option in dataloaders.'
    )
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
    # will pass the `--local-rank` parameter to `tools/train.py` instead
    # of `--local_rank`.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    # Expose the rank to downstream code that reads the environment.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def merge_args(cfg, args):
    """Merge CLI arguments to config.

    Mutates ``cfg`` in place (and also returns it): validation switches,
    launcher, work_dir, AMP, resume behavior, auto-scale-lr and default
    dataloader settings. ``--cfg-options`` is merged last so it can
    override anything set here.
    """
    if args.no_validate:
        # Drop the whole validation loop.
        cfg.val_cfg = None
        cfg.val_dataloader = None
        cfg.val_evaluator = None

    cfg.launcher = args.launcher

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    # enable automatic-mixed-precision training
    if args.amp is True:
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')

    # resume training
    if args.resume == 'auto':
        # Let mmengine locate the latest checkpoint in work_dir itself.
        cfg.resume = True
        cfg.load_from = None
    elif args.resume is not None:
        cfg.resume = True
        cfg.load_from = args.resume

    # enable auto scale learning rate
    if args.auto_scale_lr:
        cfg.auto_scale_lr.enable = True

    # set dataloader args
    default_dataloader_cfg = ConfigDict(
        pin_memory=True,
        persistent_workers=True,
        collate_fn=dict(type='default_collate'),
    )
    # `persistent_workers` is only supported since PyTorch 1.8.
    if digit_version(TORCH_VERSION) < digit_version('1.8.0'):
        default_dataloader_cfg.persistent_workers = False

    def set_default_dataloader_cfg(cfg, field):
        # Apply the defaults above, letting user-provided keys take
        # precedence; the --no-* flags always win over both.
        if cfg.get(field, None) is None:
            return
        dataloader_cfg = deepcopy(default_dataloader_cfg)
        dataloader_cfg.update(cfg[field])
        cfg[field] = dataloader_cfg
        if args.no_pin_memory:
            cfg[field]['pin_memory'] = False
        if args.no_persistent_workers:
            cfg[field]['persistent_workers'] = False

    set_default_dataloader_cfg(cfg, 'train_dataloader')
    set_default_dataloader_cfg(cfg, 'val_dataloader')
    set_default_dataloader_cfg(cfg, 'test_dataloader')

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    return cfg
def main():
    """Entry point: build a runner from config + CLI args, then train."""
    args = parse_args()

    # Load the config file and fold the CLI options into it.
    cfg = merge_args(Config.fromfile(args.config), args)

    # A custom `runner_type` selects a runner class from the registry;
    # otherwise fall back to the default mmengine Runner.
    if 'runner_type' in cfg:
        runner = RUNNERS.build(cfg)
    else:
        runner = Runner.from_cfg(cfg)

    runner.train()


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import sys
import textwrap
from matplotlib import transforms
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmengine.visualization.utils import img_from_canvas
from mmpretrain.datasets.builder import build_dataset
from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer, create_figure
try:
from matplotlib._tight_bbox import adjust_bbox
except ImportError:
# To be compatible with matplotlib 3.5
from matplotlib.tight_bbox import adjust_bbox
def parse_args():
    """Parse command-line arguments for the dataset browser."""
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--output-dir',
        '-o',
        default=None,
        type=str,
        help='If there is no display interface, you can save it.')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--phase',
        '-p',
        default='train',
        type=str,
        choices=['train', 'test', 'val'],
        help='phase of dataset to visualize, accept "train" "test" and "val".'
        ' Defaults to "train".')
    parser.add_argument(
        '--show-number',
        '-n',
        type=int,
        default=sys.maxsize,
        help='number of images selected to visualize, must bigger than 0. if '
        'the number is bigger than length of dataset, show all the images in '
        'dataset; default "sys.maxsize", show all images in dataset')
    parser.add_argument(
        '--show-interval',
        '-i',
        type=float,
        default=2,
        help='the interval of show (s)')
    parser.add_argument(
        '--mode',
        '-m',
        default='transformed',
        type=str,
        choices=['original', 'transformed', 'concat', 'pipeline'],
        help='display mode; display original pictures or transformed pictures'
        ' or comparison pictures. "original" means show images load from disk'
        '; "transformed" means to show images after transformed; "concat" '
        'means show images stitched by "original" and "output" images. '
        '"pipeline" means show all the intermediate images. '
        'Defaults to "transformed".')
    parser.add_argument(
        '--rescale-factor',
        '-r',
        type=float,
        help='(For `mode=original`) Image rescale factor, which is useful if'
        'the output is too large or too small.')
    parser.add_argument(
        '--channel-order',
        '-c',
        default='BGR',
        choices=['BGR', 'RGB'],
        help='The channel order of the showing images, could be "BGR" '
        'or "RGB", Defaults to "BGR".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args
def make_grid(imgs, names):
    """Concat list of pictures into a single big picture, align height here.

    Each element of ``imgs`` may be a single image or a list of images
    (e.g. multi-view transforms); lists are stacked vertically inside their
    column. ``names`` holds one title per column.
    """
    # A large canvas to ensure all text clear.
    figure = create_figure(dpi=150, figsize=(16, 9))

    # deal with imgs: find the tallest column and record shapes for titles.
    max_nrows = 1
    img_shapes = []
    for img in imgs:
        if isinstance(img, list):
            max_nrows = max(len(img), max_nrows)
            img_shapes.append([i.shape[:2] for i in img])
        else:
            img_shapes.append(img.shape[:2])

    gs = figure.add_gridspec(max_nrows, len(imgs))
    for i, img in enumerate(imgs):
        if isinstance(img, list):
            # One subplot per sub-image, stacked in column i.
            for j in range(len(img)):
                subplot = figure.add_subplot(gs[j, i])
                subplot.axis(False)
                subplot.imshow(img[j])
                name = '\n'.join(textwrap.wrap(names[i] + str(j), width=20))
                subplot.set_title(
                    f'{name}\n{img_shapes[i][j]}',
                    fontsize=15,
                    family='monospace')
        else:
            # A single image spans the whole column.
            subplot = figure.add_subplot(gs[:, i])
            subplot.axis(False)
            subplot.imshow(img)
            name = '\n'.join(textwrap.wrap(names[i], width=20))
            subplot.set_title(
                f'{name}\n{img_shapes[i]}', fontsize=15, family='monospace')

    # Manage the gap of subplots
    figure.tight_layout()

    # Remove the white boundary (reserve 0.5 inches at the top to show label)
    points = figure.get_tightbbox(
        figure.canvas.get_renderer()).get_points() + [[0, 0], [0, 0.5]]
    adjust_bbox(figure, transforms.Bbox(points))

    return img_from_canvas(figure.canvas)
class InspectCompose(Compose):
    """Compose multiple transforms sequentially.

    And record "img" field of all results in one list, so that every
    intermediate stage of the pipeline can be visualized afterwards.
    """

    def __init__(self, transforms, intermediate_imgs, visualizer):
        super().__init__(transforms=transforms)
        # Shared list the caller drains after each pipeline run.
        self.intermediate_imgs = intermediate_imgs
        self.visualizer = visualizer

    def __call__(self, data):
        # Snapshot the raw image before any transform runs.
        if 'img' in data:
            self.intermediate_imgs.append({
                'name': 'Original',
                'img': data['img'].copy()
            })
        for t in self.transforms:
            data = t(data)
            if data is None:
                # A transform may drop the sample; stop the pipeline.
                return None
            if 'img' in data:
                img = data['img'].copy()
                if 'mask' in data:
                    # Overlay the mask on the (first) image for
                    # visualization purposes.
                    tmp_img = img[0] if isinstance(img, list) else img
                    tmp_img = self.visualizer.add_mask_to_image(
                        tmp_img,
                        DataSample().set_mask(data['mask']),
                        resize=tmp_img.shape[:2])
                    img = [tmp_img] + img[1:] if isinstance(img,
                                                            list) else tmp_img
                self.intermediate_imgs.append({
                    'name': t.__class__.__name__,
                    'img': img
                })
        return data
def main():
    """Browse dataset samples, showing/saving images per the chosen mode."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope('mmpretrain')  # Use mmpretrain as default scope.

    dataset_cfg = cfg.get(args.phase + '_dataloader').get('dataset')
    dataset = build_dataset(dataset_cfg)

    # init visualizer
    cfg.visualizer.pop('type')
    fig_cfg = dict(figsize=(16, 10))
    visualizer = UniversalVisualizer(
        **cfg.visualizer, fig_show_cfg=fig_cfg, fig_save_cfg=fig_cfg)
    visualizer.dataset_meta = dataset.metainfo

    # init inspection: wrap the pipeline so every stage's image is recorded
    # into `intermediate_imgs` as the sample passes through.
    intermediate_imgs = []
    dataset.pipeline = InspectCompose(dataset.pipeline.transforms,
                                      intermediate_imgs, visualizer)

    # init visualization image number
    display_number = min(args.show_number, len(dataset))
    progress_bar = ProgressBar(display_number)

    for i, item in zip(range(display_number), dataset):
        rescale_factor = None
        if args.mode == 'original':
            image = intermediate_imgs[0]['img']
            # Only original mode need rescale factor, `make_grid` will use
            # matplotlib to manage the size of subplots.
            rescale_factor = args.rescale_factor
        elif args.mode == 'transformed':
            image = make_grid([intermediate_imgs[-1]['img']], ['transformed'])
        elif args.mode == 'concat':
            ori_image = intermediate_imgs[0]['img']
            trans_image = intermediate_imgs[-1]['img']
            image = make_grid([ori_image, trans_image],
                              ['original', 'transformed'])
        else:
            # 'pipeline' mode: show every intermediate stage side by side.
            image = make_grid([result['img'] for result in intermediate_imgs],
                              [result['name'] for result in intermediate_imgs])

        # Reset for the next sample (the list is shared with the pipeline).
        intermediate_imgs.clear()

        data_sample = item['data_samples'].numpy()

        # get filename from dataset or just use index as filename
        if hasattr(item['data_samples'], 'img_path'):
            filename = osp.basename(item['data_samples'].img_path)
        else:
            # some dataset have not image path
            filename = f'{i}.jpg'

        out_file = osp.join(args.output_dir,
                            filename) if args.output_dir is not None else None

        visualizer.visualize_cls(
            image if args.channel_order == 'RGB' else image[..., ::-1],
            data_sample,
            rescale_factor=rescale_factor,
            show=not args.not_show,
            wait_time=args.show_interval,
            name=filename,
            out_file=out_file)
        progress_bar.update()


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import math
import pkg_resources
from functools import partial
from pathlib import Path
import mmcv
import numpy as np
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.config import Config, DictAction
from mmengine.dataset import default_collate
from mmengine.utils import to_2tuple
from mmengine.utils.dl_utils import is_norm
from mmpretrain import digit_version
from mmpretrain.apis import get_model
from mmpretrain.registry import TRANSFORMS
try:
import pytorch_grad_cam as cam
from pytorch_grad_cam.activations_and_gradients import \
ActivationsAndGradients
from pytorch_grad_cam.utils.image import show_cam_on_image
except ImportError:
raise ImportError('Please run `pip install "grad-cam>=1.3.6"` to install '
'3rd party package pytorch_grad_cam.')
# Alias name
METHOD_MAP = {
    'gradcam++': cam.GradCAMPlusPlus,
}
# Register every CAM variant shipped with pytorch_grad_cam under its
# lower-cased class name (e.g. 'gradcam', 'xgradcam', 'eigencam').
METHOD_MAP.update({
    cam_class.__name__.lower(): cam_class
    for cam_class in cam.base_cam.BaseCAM.__subclasses__()
})
def parse_args():
    """Parse command-line arguments and validate the requested CAM method."""
    parser = argparse.ArgumentParser(description='Visualize CAM')
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--target-layers',
        default=[],
        nargs='+',
        type=str,
        help='The target layers to get CAM, if not set, the tool will '
        'specify the norm layer in the last block. Backbones '
        'implemented by users are recommended to manually specify'
        ' target layers in commmad statement.')
    parser.add_argument(
        '--preview-model',
        default=False,
        action='store_true',
        help='To preview all the model layers')
    parser.add_argument(
        '--method',
        default='GradCAM',
        help='Type of method to use, supports '
        f'{", ".join(list(METHOD_MAP.keys()))}.')
    parser.add_argument(
        '--target-category',
        default=[],
        nargs='+',
        type=int,
        help='The target category to get CAM, default to use result '
        'get from given model.')
    parser.add_argument(
        '--eigen-smooth',
        default=False,
        action='store_true',
        help='Reduce noise by taking the first principle componenet of '
        '``cam_weights*activations``')
    parser.add_argument(
        '--aug-smooth',
        default=False,
        action='store_true',
        help='Wether to use test time augmentation, default not to use')
    parser.add_argument(
        '--save-path',
        type=Path,
        help='The path to save visualize cam image, default not to save.')
    parser.add_argument('--device', default='cpu', help='Device to use cpu')
    parser.add_argument(
        '--vit-like',
        action='store_true',
        help='Whether the network is a ViT-like network.')
    parser.add_argument(
        '--num-extra-tokens',
        type=int,
        help='The number of extra tokens in ViT-like backbones. Defaults to'
        ' use num_extra_tokens of the backbone.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    # Fail fast on a method name that pytorch_grad_cam does not provide.
    if args.method.lower() not in METHOD_MAP.keys():
        raise ValueError(f'invalid CAM type {args.method},'
                         f' supports {", ".join(list(METHOD_MAP.keys()))}.')
    return args
def reshape_transform(tensor, model, args):
    """Adapt feature maps for `cam.activations_and_grads`.

    CNN features of shape (B, C, H, W) pass through unchanged. ViT-like
    backbones output (B, L, C) token sequences, so the extra (cls) tokens
    are stripped and the token axis is folded back into a square H x W grid.
    """
    if tensor.ndim == 4:
        # Already (B, C, H, W) — nothing to do.
        return tensor

    if tensor.ndim != 3:
        raise ValueError(f'Unsupported tensor shape {tensor.shape}.')

    if not args.vit_like:
        raise ValueError(f"The tensor shape is {tensor.shape}, if it's a "
                         'vit-like backbone, please specify `--vit-like`.')

    # Number of non-spatial tokens (cls token etc.) to drop from the front;
    # CLI value wins over the backbone attribute, defaulting to 1.
    num_extra_tokens = args.num_extra_tokens or getattr(
        model.backbone, 'num_extra_tokens', 1)
    spatial = tensor[:, num_extra_tokens:, :]

    # The remaining tokens must form a square grid.
    heat_map_area = spatial.size(1)
    side = int(math.sqrt(heat_map_area))
    assert side * side == heat_map_area, \
        (f"The input feature's length ({heat_map_area+num_extra_tokens}) "
         f'minus num-extra-tokens ({num_extra_tokens}) is {heat_map_area},'
         ' which is not a perfect square number. Please check if you used '
         'a wrong num-extra-tokens.')

    # (B, L, C) -> (B, H, W, C) -> (B, C, H, W)
    return spatial.reshape(spatial.size(0), side, side,
                           spatial.size(2)).permute(0, 3, 1, 2)
def init_cam(method, model, target_layers, use_cuda, reshape_transform):
    """Construct the CAM object once, In order to be compatible with
    mmpretrain, here we modify the ActivationsAndGradients object."""
    cam_cls = METHOD_MAP[method.lower()]
    cam_obj = cam_cls(
        model=model, target_layers=target_layers, use_cuda=use_cuda)

    # Swap in an ActivationsAndGradients that applies `reshape_transform`,
    # after releasing the hooks installed by the constructor.
    cam_obj.activations_and_grads.release()
    cam_obj.activations_and_grads = ActivationsAndGradients(
        cam_obj.model, cam_obj.target_layers, reshape_transform)

    return cam_obj
def get_layer(layer_str, model):
    """Look up a submodule of ``model`` by its dotted module name."""
    modules = dict(model.named_modules())
    if layer_str in modules:
        return modules[layer_str]
    raise AttributeError(
        f'Cannot get the layer "{layer_str}". Please choose from: \n' +
        '\n'.join(name for name, _ in model.named_modules()))
def show_cam_grad(grayscale_cam, src_img, title, out_path=None):
    """fuse src_img and grayscale_cam and show or save."""
    # Drop the batch axis and normalize the base image to [0, 1].
    heatmap = grayscale_cam[0, :]
    base = np.float32(src_img) / 255
    blended = show_cam_on_image(base, heatmap, use_rgb=False)

    if not out_path:
        mmcv.imshow(blended, win_name=title)
    else:
        mmcv.imwrite(blended, str(out_path))
def get_default_target_layers(model, args):
    """get default target layers from given model, here choose norm type layer
    as default target layer."""
    norm_layers = [
        (name, layer)
        for name, layer in model.backbone.named_modules(prefix='backbone')
        if is_norm(layer)
    ]
    if args.vit_like:
        # For ViT models, the final classification is done on the class token.
        # And the patch tokens and class tokens won't interact each other after
        # the final attention layer. Therefore, we need to choose the norm
        # layer before the last attention layer.
        num_extra_tokens = args.num_extra_tokens or getattr(
            model.backbone, 'num_extra_tokens', 1)

        # models like swin have no attr 'out_type', set out_type to avg_featmap
        out_type = getattr(model.backbone, 'out_type', 'avg_featmap')
        if out_type == 'cls_token' or num_extra_tokens > 0:
            # Assume the backbone feature is class token.
            # NOTE(review): index -3 assumes the standard ViT layout of two
            # norms per block plus a final norm — confirm for custom backbones.
            name, layer = norm_layers[-3]
            print('Automatically choose the last norm layer before the '
                  f'final attention block "{name}" as the target layer.')
            return [layer]

    # For CNN models, use the last norm layer as the target-layer
    name, layer = norm_layers[-1]
    print('Automatically choose the last norm layer '
          f'"{name}" as the target layer.')
    return [layer]
def main():
    """Entry point: compute a CAM for one image and show/save the overlay."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # build the model from a config file and a checkpoint file
    model: nn.Module = get_model(cfg, args.checkpoint, device=args.device)
    # `--preview-model` only prints the module tree (to help the user pick
    # a target layer name) and exits without computing any CAM.
    if args.preview_model:
        print(model)
        print('\n Please remove `--preview-model` to get the CAM.')
        return
    # apply transform and prepare data
    transforms = Compose(
        [TRANSFORMS.build(t) for t in cfg.test_dataloader.dataset.pipeline])
    data = transforms({'img_path': args.img})
    # keep a CHW -> HWC copy of the transformed image for visualization
    src_img = copy.deepcopy(data['inputs']).numpy().transpose(1, 2, 0)
    data = model.data_preprocessor(default_collate([data]), False)
    # build target layers: explicit names from the CLI win over defaults
    if args.target_layers:
        target_layers = [
            get_layer(layer, model) for layer in args.target_layers
        ]
    else:
        target_layers = get_default_target_layers(model, args)
    # init a cam grad calculator
    use_cuda = ('cuda' in args.device)
    cam = init_cam(args.method, model, target_layers, use_cuda,
                   partial(reshape_transform, model=model, args=args))
    # wrap the target_category with ClassifierOutputTarget in grad_cam>=1.3.7,
    # to fix the bug in #654.
    targets = None
    if args.target_category:
        grad_cam_v = pkg_resources.get_distribution('grad_cam').version
        if digit_version(grad_cam_v) >= digit_version('1.3.7'):
            from pytorch_grad_cam.utils.model_targets import \
                ClassifierOutputTarget
            targets = [ClassifierOutputTarget(c) for c in args.target_category]
        else:
            # older grad-cam releases accept raw category indices
            targets = args.target_category
    # calculate cam grads and show|save the visualization image
    grayscale_cam = cam(
        data['inputs'],
        targets,
        eigen_smooth=args.eigen_smooth,
        aug_smooth=args.aug_smooth)
    show_cam_grad(
        grayscale_cam, src_img, title=args.method, out_path=args.save_path)
# Run only when executed as a script.
if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import os.path as osp
import re
from pathlib import Path
from unittest.mock import MagicMock
import matplotlib.pyplot as plt
import rich
import torch.nn as nn
from mmengine.config import Config, DictAction
from mmengine.hooks import Hook
from mmengine.model import BaseModel
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn
class SimpleModel(BaseModel):
    """Stub model whose training steps are no-ops.

    It exists only so a ``Runner`` can drive the parameter schedulers;
    the single conv layer gives the optimizer a parameter to manage.
    """

    def __init__(self):
        super().__init__()
        self.data_preprocessor = nn.Identity()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, inputs, data_samples, mode='tensor'):
        """No-op forward; the curve simulation never needs outputs."""
        pass

    def train_step(self, data, optim_wrapper):
        """No-op train step; only hooks and schedulers matter here."""
        pass
class ParamRecordHook(Hook):
def __init__(self, by_epoch):
super().__init__()
self.by_epoch = by_epoch
self.lr_list = []
self.momentum_list = []
self.wd_list = []
self.task_id = 0
self.progress = Progress(BarColumn(), MofNCompleteColumn(),
TextColumn('{task.description}'))
def before_train(self, runner):
if self.by_epoch:
total = runner.train_loop.max_epochs
self.task_id = self.progress.add_task(
'epochs', start=True, total=total)
else:
total = runner.train_loop.max_iters
self.task_id = self.progress.add_task(
'iters', start=True, total=total)
self.progress.start()
def after_train_epoch(self, runner):
if self.by_epoch:
self.progress.update(self.task_id, advance=1)
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
if not self.by_epoch:
self.progress.update(self.task_id, advance=1)
self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0])
self.momentum_list.append(
runner.optim_wrapper.get_momentum()['momentum'][0])
self.wd_list.append(
runner.optim_wrapper.param_groups[0]['weight_decay'])
def after_train(self, runner):
self.progress.stop()
def parse_args():
    """Parse CLI arguments for the hyper-parameter scheduler visualizer.

    Returns:
        argparse.Namespace: Parsed options; ``window_size`` is validated
        against the ``"W*H"`` pattern before returning.
    """
    parser = argparse.ArgumentParser(
        # Fixed: this tool plots lr/momentum/wd schedules, not a dataset
        # pipeline — the old description was a copy-paste mistake.
        description='Visualize a hyper-parameter scheduler')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '-p',
        '--parameter',
        type=str,
        default='lr',
        choices=['lr', 'momentum', 'wd'],
        help='The parameter to visualize its change curve, choose from'
        '"lr", "wd" and "momentum". Defaults to "lr".')
    parser.add_argument(
        '-d',
        '--dataset-size',
        type=int,
        help='The size of the dataset. If specify, `build_dataset` will '
        'be skipped and use this size as the dataset size.')
    parser.add_argument(
        '-n',
        '--ngpus',
        type=int,
        default=1,
        help='The number of GPUs used in training.')
    parser.add_argument(
        '-s',
        '--save-path',
        type=Path,
        help='The learning rate curve plot save path')
    parser.add_argument(
        '--log-level',
        default='WARNING',
        help='The log level of the handler and logger. Defaults to '
        'WARNING.')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--style',
        type=str,
        default='whitegrid',
        help='style of the figure, need `seaborn` package.')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--window-size',
        default='12*7',
        help='Size of the window to display images, in format of "$W*$H".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    # Fail fast on a malformed window size instead of crashing in plotting.
    if args.window_size != '':
        assert re.match(r'\d+\*\d+', args.window_size), \
            "'window-size' must be in format 'W*H'."
    return args
def plot_curve(lr_list, args, param_name, iters_per_epoch, by_epoch=True):
    """Draw the recorded parameter values against training iterations.

    Args:
        lr_list (list): Parameter value at every training iteration.
        args (argparse.Namespace): Parsed CLI options (``style``,
            ``window_size``, ``title`` and ``config`` are read).
        param_name (str): Y-axis label, e.g. ``'Learning Rate'``.
        iters_per_epoch (int): Conversion factor for the epoch axis.
        by_epoch (bool): If True, show iters on top and epochs below.
    """
    # Seaborn styling is a nice-to-have; skip it silently when missing.
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        pass

    width_str, height_str = args.window_size.split('*')
    plt.figure(figsize=(int(width_str), int(height_str)))
    ax: plt.Axes = plt.subplot()
    ax.plot(lr_list, linewidth=1)

    if by_epoch:
        # Primary (top) axis counts iterations; a secondary (bottom)
        # axis converts the same positions to epochs.
        ax.xaxis.tick_top()
        ax.set_xlabel('Iters')
        ax.xaxis.set_label_position('top')
        sec_ax = ax.secondary_xaxis(
            'bottom',
            functions=(lambda x: x / iters_per_epoch,
                       lambda y: y * iters_per_epoch))
        sec_ax.set_xlabel('Epochs')
    else:
        plt.xlabel('Iters')
    plt.ylabel(param_name)

    if args.title is None:
        plt.title(f'{osp.basename(args.config)} {param_name} curve')
    else:
        plt.title(args.title)
def simulate_train(data_loader, cfg, by_epoch):
    """Run a mock training loop and collect parameter values per iteration.

    A ``SimpleModel`` plus a pruned hook set keeps the run cheap: only the
    param-scheduler hook and our recording hook are active.

    Args:
        data_loader: (Fake) dataloader defining iterations per epoch.
        cfg: Loaded config providing train_cfg/optimizer/scheduler.
        by_epoch (bool): Whether training is epoch-based.

    Returns:
        dict: Lists under keys ``'lr'``, ``'momentum'`` and ``'wd'``,
        one value per training iteration.
    """
    recorder = ParamRecordHook(by_epoch=by_epoch)
    # Disable every default hook except the param scheduler; add our
    # recorder so each iteration's values are captured.
    hooks = dict(
        param_scheduler=cfg.default_hooks['param_scheduler'],
        runtime_info=None,
        timer=None,
        logger=None,
        checkpoint=None,
        sampler_seed=None,
        param_record=recorder)
    runner = Runner(
        model=SimpleModel(),
        work_dir=cfg.work_dir,
        train_dataloader=data_loader,
        train_cfg=cfg.train_cfg,
        log_level=cfg.log_level,
        optim_wrapper=cfg.optim_wrapper,
        param_scheduler=cfg.param_scheduler,
        default_scope=cfg.default_scope,
        default_hooks=hooks,
        # A mocked visualizer avoids creating any real visualization backend.
        visualizer=MagicMock(spec=Visualizer),
        custom_hooks=cfg.get('custom_hooks', None))
    runner.train()
    return dict(
        lr=recorder.lr_list,
        momentum=recorder.momentum_list,
        wd=recorder.wd_list)
def main():
    """Entry point: simulate training and plot the chosen parameter curve."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    if cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    cfg.log_level = args.log_level
    # make sure save_root exists
    if args.save_path and not args.save_path.parent.exists():
        raise FileNotFoundError(
            f'The save path is {args.save_path}, and directory '
            f"'{args.save_path.parent}' do not exist.")
    # init logger
    print('Param_scheduler :')
    rich.print_json(json.dumps(cfg.param_scheduler))
    # prepare data loader; total batch size scales with the GPU count
    batch_size = cfg.train_dataloader.batch_size * args.ngpus
    # determine whether the loop is epoch-based from the train_cfg fields
    if 'by_epoch' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('by_epoch')
    elif 'type' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('type') == 'EpochBasedTrainLoop'
    else:
        raise ValueError('please set `train_cfg`.')
    # epoch-based runs need the real dataset length unless the user gave one
    if args.dataset_size is None and by_epoch:
        from mmpretrain.datasets import build_dataset
        dataset_size = len(build_dataset(cfg.train_dataloader.dataset))
    else:
        dataset_size = args.dataset_size or batch_size
    # A list subclass is enough: Runner only needs len() and a `dataset`
    # attribute with `metainfo`; no real data is ever loaded.
    class FakeDataloader(list):
        dataset = MagicMock(metainfo=None)
    data_loader = FakeDataloader(range(dataset_size // batch_size))
    dataset_info = (
        f'\nDataset infos:'
        f'\n - Dataset size: {dataset_size}'
        f'\n - Batch size per GPU: {cfg.train_dataloader.batch_size}'
        f'\n - Number of GPUs: {args.ngpus}'
        f'\n - Total batch size: {batch_size}')
    if by_epoch:
        dataset_info += f'\n - Iterations per epoch: {len(data_loader)}'
    rich.print(dataset_info + '\n')
    # simulation training process
    param_dict = simulate_train(data_loader, cfg, by_epoch)
    param_list = param_dict[args.parameter]
    if args.parameter == 'lr':
        param_name = 'Learning Rate'
    elif args.parameter == 'momentum':
        param_name = 'Momentum'
    else:
        param_name = 'Weight Decay'
    plot_curve(param_list, args, param_name, len(data_loader), by_epoch)
    if args.save_path:
        plt.savefig(args.save_path)
        print(f'\nThe {param_name} graph is saved at {args.save_path}')
    if not args.not_show:
        plt.show()
# Run only when executed as a script.
if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import time
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import rich.progress as progress
import torch
import torch.nn.functional as F
from mmengine.config import Config, DictAction
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.runner import Runner
from mmengine.utils import mkdir_or_exist
from mmpretrain.apis import get_model
from mmpretrain.registry import DATASETS
try:
from sklearn.manifold import TSNE
except ImportError as e:
raise ImportError('Please install `sklearn` to calculate '
'TSNE by `pip install scikit-learn`') from e
def parse_args():
    """Build and parse CLI arguments for the t-SNE visualization tool."""
    parser = argparse.ArgumentParser(description='t-SNE visualization')
    parser.add_argument('config', help='tsne config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--test-cfg',
        help='tsne config file path to load config of test dataloader.')
    parser.add_argument(
        '--vis-stage',
        choices=['backbone', 'neck', 'pre_logits'],
        help='The visualization stage of the model')
    parser.add_argument(
        '--class-idx',
        nargs='+',
        type=int,
        help='The categories used to calculate t-SNE.')
    parser.add_argument(
        '--max-num-class',
        type=int,
        default=20,
        help='The first N categories to apply t-SNE algorithms. '
        'Defaults to 20.')
    parser.add_argument(
        '--max-num-samples',
        type=int,
        default=100,
        help='The maximum number of samples per category. '
        'Higher number need longer time to calculate. Defaults to 100.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument('--device', help='Device used for inference')
    parser.add_argument(
        '--legend',
        action='store_true',
        help='Show the legend of all categories.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the result in a graphical window.')
    # t-SNE settings (forwarded directly to sklearn.manifold.TSNE)
    parser.add_argument(
        '--n-components', type=int, default=2, help='the dimension of results')
    parser.add_argument(
        '--perplexity',
        type=float,
        default=30.0,
        help='The perplexity is related to the number of nearest neighbors'
        'that is used in other manifold learning algorithms.')
    parser.add_argument(
        '--early-exaggeration',
        type=float,
        default=12.0,
        help='Controls how tight natural clusters in the original space are in'
        'the embedded space and how much space will be between them.')
    parser.add_argument(
        '--learning-rate',
        type=float,
        default=200.0,
        help='The learning rate for t-SNE is usually in the range'
        '[10.0, 1000.0]. If the learning rate is too high, the data may look'
        'like a ball with any point approximately equidistant from its nearest'
        'neighbours. If the learning rate is too low, most points may look'
        'compressed in a dense cloud with few outliers.')
    parser.add_argument(
        '--n-iter',
        type=int,
        default=1000,
        help='Maximum number of iterations for the optimization. Should be at'
        'least 250.')
    parser.add_argument(
        '--n-iter-without-progress',
        type=int,
        default=300,
        help='Maximum number of iterations without progress before we abort'
        'the optimization.')
    parser.add_argument(
        '--init', type=str, default='random', help='The init method')
    args = parser.parse_args()
    return args
def main():
    """Entry point: extract features, run t-SNE and save the scatter plots."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        work_type = args.config.split('/')[1]
        cfg.work_dir = osp.join('./work_dirs', work_type,
                                osp.splitext(osp.basename(args.config))[0])
    # create a timestamped work_dir so repeated runs never overwrite
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    tsne_work_dir = osp.join(cfg.work_dir, f'tsne_{timestamp}/')
    mkdir_or_exist(osp.abspath(tsne_work_dir))
    # init the logger before other steps
    log_file = osp.join(tsne_work_dir, 'tsne.log')
    logger = MMLogger.get_instance(
        'mmpretrain',
        logger_name='mmpretrain',
        log_file=log_file,
        log_level=cfg.log_level)
    # build the model from a config file and a checkpoint file
    device = args.device or get_device()
    model = get_model(cfg, args.checkpoint, device=device)
    logger.info('Model loaded.')
    # build the dataset: `--test-cfg` overrides, else use cfg.test_dataloader
    if args.test_cfg is not None:
        dataloader_cfg = Config.fromfile(args.test_cfg).get('test_dataloader')
    elif 'test_dataloader' not in cfg:
        raise ValueError('No `test_dataloader` in the config, you can '
                         'specify another config file that includes test '
                         'dataloader settings by the `--test-cfg` option.')
    else:
        dataloader_cfg = cfg.get('test_dataloader')
    dataset = DATASETS.build(dataloader_cfg.pop('dataset'))
    classes = dataset.metainfo.get('classes')
    # default class selection: the first `max_num_class` category indices
    if args.class_idx is None:
        num_classes = args.max_num_class if classes is None else len(classes)
        args.class_idx = list(range(num_classes))[:args.max_num_class]
    if classes is not None:
        classes = [classes[idx] for idx in args.class_idx]
    else:
        classes = args.class_idx
    # subset the dataset: keep at most `max_num_samples` items per
    # selected class so t-SNE stays tractable
    subset_idx_list = []
    counter = defaultdict(int)
    for i in range(len(dataset)):
        gt_label = dataset.get_data_info(i)['gt_label']
        if (gt_label in args.class_idx
                and counter[gt_label] < args.max_num_samples):
            subset_idx_list.append(i)
            counter[gt_label] += 1
    dataset.get_subset_(subset_idx_list)
    logger.info(f'Apply t-SNE to visualize {len(subset_idx_list)} samples.')
    dataloader_cfg.dataset = dataset
    dataloader_cfg.setdefault('collate_fn', dict(type='default_collate'))
    dataloader = Runner.build_dataloader(dataloader_cfg)
    results = dict()
    features = []
    labels = []
    for data in progress.track(dataloader, description='Calculating...'):
        with torch.no_grad():
            # preprocess data
            data = model.data_preprocessor(data)
            batch_inputs, batch_data_samples = \
                data['inputs'], data['data_samples']
            batch_labels = torch.cat([i.gt_label for i in batch_data_samples])
            # extract backbone features
            extract_args = {}
            if args.vis_stage:
                extract_args['stage'] = args.vis_stage
            batch_features = model.extract_feat(batch_inputs, **extract_args)
            # post process: pool spatial/token dims down to one vector
            if batch_features[0].ndim == 4:
                # For (N, C, H, W) feature
                batch_features = [
                    F.adaptive_avg_pool2d(inputs, 1).squeeze()
                    for inputs in batch_features
                ]
            elif batch_features[0].ndim == 3:
                # For (N, L, C) feature
                batch_features = [inputs.mean(1) for inputs in batch_features]
        # save batch features
        features.append(batch_features)
        labels.extend(batch_labels.cpu().numpy())
    # concatenate per-batch features into one array per output stage
    for i in range(len(features[0])):
        key = 'feat_' + str(model.backbone.out_indices[i])
        results[key] = np.concatenate(
            [batch[i].cpu().numpy() for batch in features], axis=0)
    # save raw features to .npy files alongside the plots
    for key, val in results.items():
        output_file = f'{tsne_work_dir}{key}.npy'
        np.save(output_file, val)
    # build t-SNE model
    tsne_model = TSNE(
        n_components=args.n_components,
        perplexity=args.perplexity,
        early_exaggeration=args.early_exaggeration,
        learning_rate=args.learning_rate,
        n_iter=args.n_iter,
        n_iter_without_progress=args.n_iter_without_progress,
        init=args.init)
    # run and get results
    logger.info('Running t-SNE.')
    for key, val in results.items():
        result = tsne_model.fit_transform(val)
        # min-max normalize the embedding to [0, 1] for plotting
        res_min, res_max = result.min(0), result.max(0)
        res_norm = (result - res_min) / (res_max - res_min)
        _, ax = plt.subplots(figsize=(10, 10))
        scatter = ax.scatter(
            res_norm[:, 0],
            res_norm[:, 1],
            alpha=1.0,
            s=15,
            c=labels,
            cmap='tab20')
        if args.legend:
            legend = ax.legend(scatter.legend_elements()[0], classes)
            ax.add_artist(legend)
        plt.savefig(f'{tsne_work_dir}{key}.png')
        if args.show:
            plt.show()
    logger.info(f'Save features and results to {tsne_work_dir}')
# Run only when executed as a script.
if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment