Commit 1ac2e802 authored by limm's avatar limm
Browse files

add tools code

parent b6df0d33
Pipeline #2803 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import re
from itertools import groupby
import matplotlib.pyplot as plt
import numpy as np
from mmpretrain.utils import load_json_log
def cal_train_time(log_dicts, args):
    """Print per-epoch and overall average iteration-time statistics.

    Args:
        log_dicts (list[dict]): Parsed json logs, one dict per log file,
            each holding a ``'train'`` list of per-iteration records.
        args (argparse.Namespace): Needs ``json_logs`` and
            ``include_outliers`` attributes.
    """
    for log_path, log_dict in zip(args.json_logs, log_dicts):
        print(f'{"-" * 5}Analyze train time of {log_path}{"-" * 5}')
        train_logs = log_dict['train']
        if 'epoch' in train_logs[0]:
            per_epoch_means = []
            for _, epoch_logs in groupby(train_logs, key=lambda r: r['epoch']):
                times = np.array([r['time'] for r in epoch_logs])
                if not args.include_outliers:
                    # Drop the first iteration of each epoch (warm-up outlier).
                    times = times[1:]
                per_epoch_means.append(times.mean())
            per_epoch_means = np.array(per_epoch_means)
            slowest = per_epoch_means.argmax()
            fastest = per_epoch_means.argmin()
            print(f'slowest epoch {slowest + 1}, '
                  f'average time is {per_epoch_means[slowest]:.4f}')
            print(f'fastest epoch {fastest + 1}, '
                  f'average time is {per_epoch_means[fastest]:.4f}')
            print(f'time std over epochs is {per_epoch_means.std():.4f}')
        avg_iter_time = np.mean([r['time'] for r in train_logs])
        print(f'average iter time: {avg_iter_time:.4f} s/iter')
        print()
def get_legends(args):
    """Return plot legends; default to ``<basename>_<metric>`` per curve."""
    legend = args.legend
    if legend is None:
        legend = []
        for json_log in args.json_logs:
            # Strip the trailing '.json' (and a residual '.log' before it).
            stem = os.path.basename(json_log)[:-5]
            if stem.endswith('.log'):
                stem = stem[:-4]
            legend.extend(f'{stem}_{metric}' for metric in args.keys)
    assert len(legend) == len(args.json_logs) * len(args.keys)
    return legend
def plot_phase_train(metric, train_logs, curve_label):
    """Plot one training-phase curve for ``metric``.

    When the log records epochs, the x axis is rescaled from steps to
    epochs using the last record's steps-per-epoch ratio.
    """
    steps = np.array([entry['step'] for entry in train_logs])
    values = np.array([entry[metric] for entry in train_logs])
    if 'epoch' in train_logs[0]:
        steps = steps / (train_logs[-1]['step'] / train_logs[-1]['epoch'])
        plt.xlabel('Epochs')
    else:
        plt.xlabel('Iters')
    plt.plot(steps, values, label=curve_label, linewidth=0.75)
def plot_phase_val(metric, val_logs, curve_label):
    """Plot one validation-phase curve for ``metric`` against steps."""
    steps = np.array([entry['step'] for entry in val_logs])
    values = np.array([entry[metric] for entry in val_logs])
    plt.xlabel('Steps')
    plt.plot(steps, values, label=curve_label, linewidth=0.75)
def plot_curve_helper(log_dicts, metrics, args, legend):
    """Plot curves from log_dicts by metrics.

    Args:
        log_dicts (list[dict]): Parsed logs with 'train' and 'val' lists.
        metrics (list[str]): Metric names to plot.
        args (argparse.Namespace): Needs ``json_logs``.
        legend (list[str]): One label per (log, metric) pair.

    Raises:
        ValueError: If a metric is in neither the train nor the val keys.
    """
    num_metrics = len(metrics)
    for i, log_dict in enumerate(log_dicts):
        for j, key in enumerate(metrics):
            json_log = args.json_logs[i]
            print(f'plot curve of {json_log}, metric is {key}')
            curve_label = legend[i * num_metrics + j]
            # FIX: the empty fallbacks were `{}` (an empty *dict*), not an
            # empty set; use set() so both branches yield the same type.
            train_keys = set() if len(log_dict['train']) == 0 else set(
                log_dict['train'][0].keys()) - {'step', 'epoch'}
            val_keys = set() if len(log_dict['val']) == 0 else set(
                log_dict['val'][0].keys()) - {'step'}
            # Validation metrics take priority over training metrics.
            if key in val_keys:
                plot_phase_val(key, log_dict['val'], curve_label)
            elif key in train_keys:
                plot_phase_train(key, log_dict['train'], curve_label)
            else:
                raise ValueError(
                    f'Invalid key "{key}", please choose from '
                    f'{set.union(set(train_keys), set(val_keys))}.')
    plt.legend()
def plot_curve(log_dicts, args):
    """Plot metric curves and either show them or save to ``args.out``."""
    # Apply the seaborn style when the package is available.
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        pass

    # The figure size arrives as a 'W*H' string.
    width, height = (int(v) for v in args.window_size.split('*'))
    plt.figure(figsize=(width, height))

    legends = get_legends(args)
    plot_curve_helper(log_dicts, args.keys, args, legends)

    if args.title is not None:
        plt.title(args.title)
    if args.out is None:
        plt.show()
    else:
        print(f'save curve to: {args.out}')
        plt.savefig(args.out)
        plt.cla()
def add_plot_parser(subparsers):
    """Register the ``plot_curve`` sub-command on ``subparsers``."""
    parser_plt = subparsers.add_parser(
        'plot_curve', help='parser for plotting curves')
    parser_plt.add_argument(
        'json_logs', type=str, nargs='+',
        help='path of train log in json format')
    parser_plt.add_argument(
        '--keys', type=str, nargs='+', default=['loss'],
        help='the metric that you want to plot')
    parser_plt.add_argument('--title', type=str, help='title of figure')
    parser_plt.add_argument(
        '--legend', type=str, nargs='+', default=None,
        help='legend of each plot')
    parser_plt.add_argument(
        '--style', type=str, default='whitegrid',
        help='style of the figure, need `seaborn` package.')
    parser_plt.add_argument('--out', type=str, default=None)
    parser_plt.add_argument(
        '--window-size', default='12*7',
        help='size of the window to display images, in format of "$W*$H".')
def add_time_parser(subparsers):
    """Register the ``cal_train_time`` sub-command on ``subparsers``."""
    parser_time = subparsers.add_parser(
        'cal_train_time',
        help='parser for computing the average time per training iteration')
    parser_time.add_argument(
        'json_logs', type=str, nargs='+',
        help='path of train log in json format')
    parser_time.add_argument(
        '--include-outliers', action='store_true',
        help='include the first value of every epoch when computing '
        'the average time')
def parse_args():
    """Parse command-line arguments for the log-analysis tasks."""
    parser = argparse.ArgumentParser(description='Analyze Json Log')
    # currently only support plot curve and calculate average train time
    subparsers = parser.add_subparsers(dest='task', help='task parser')
    for register in (add_plot_parser, add_time_parser):
        register(subparsers)
    args = parser.parse_args()
    # Validate the 'W*H' window-size format when the option is present.
    window_size = getattr(args, 'window_size', '')
    if window_size != '':
        assert re.match(r'\d+\*\d+', window_size), \
            "'window-size' must be in format 'W*H'."
    return args
def main():
    """Entry point: load json logs and dispatch to the selected task."""
    args = parse_args()
    for json_log in args.json_logs:
        assert json_log.endswith('.json')
    log_dicts = [load_json_log(path) for path in args.json_logs]
    # Dispatch table keyed by sub-command name.
    handlers = {
        'cal_train_time': cal_train_time,
        'plot_curve': plot_curve,
    }
    handler = handlers.get(args.task)
    if handler is not None:
        handler(log_dicts, args)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from pathlib import Path
import mmcv
import mmengine
import torch
from mmengine import DictAction
from mmpretrain.datasets import build_dataset
from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer
def parse_args():
    """Parse command-line arguments for the success/fail visualization."""
    parser = argparse.ArgumentParser(
        description='MMPreTrain evaluate prediction success/fail')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('result', help='test result json/pkl file')
    parser.add_argument(
        '--out-dir', required=True, help='dir to store output files')
    parser.add_argument(
        '--topk', default=20, type=int,
        help='Number of images to select for success/fail')
    parser.add_argument(
        '--rescale-factor', '-r', type=float,
        help='image rescale factor, which is useful if the output is too '
        'large or too small.')
    parser.add_argument(
        '--cfg-options', nargs='+', action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    return parser.parse_args()
def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
    """Visualize samples and dump their fields to ``<folder_name>.json``.

    Args:
        result_dir (str): Root output directory.
        folder_name (str): Sub-directory name ('success' or 'fail').
        results (list): Data samples to visualize.
        dataset: Dataset used to look up image data by sample index.
        rescale_factor (float | None): Optional image rescale factor.
    """
    full_dir = osp.join(result_dir, folder_name)
    vis = UniversalVisualizer()
    vis.dataset_meta = {'classes': dataset.CLASSES}

    dump_infos = []
    for data_sample in results:
        data_info = dataset.get_data_info(data_sample.sample_idx)
        # Prefer an in-memory image; otherwise load the image from disk.
        if 'img' in data_info:
            img = data_info['img']
            name = str(data_sample.sample_idx)
        elif 'img_path' in data_info:
            img = mmcv.imread(data_info['img_path'], channel_order='rgb')
            name = Path(data_info['img_path']).name
        else:
            raise ValueError('Cannot load images from the dataset infos.')
        if rescale_factor is not None:
            img = mmcv.imrescale(img, rescale_factor)
        vis.visualize_cls(
            img, data_sample, out_file=osp.join(full_dir, name + '.png'))

        # Tensors are converted to plain lists so the dump is JSON-friendly.
        dump_infos.append({
            k: v.tolist() if isinstance(v, torch.Tensor) else v
            for k, v in data_sample.items()
        })

    mmengine.dump(dump_infos, osp.join(full_dir, folder_name + '.json'))
def main():
    """Select and visualize the top-k success/fail predictions."""
    args = parse_args()

    cfg = mmengine.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # Build the dataset without any pipeline so raw data infos are usable.
    cfg.test_dataloader.dataset.pipeline = []
    dataset = build_dataset(cfg.test_dataloader.dataset)

    results = []
    for result in mmengine.load(args.result):
        data_sample = DataSample()
        data_sample.set_metainfo({'sample_idx': result['sample_idx']})
        data_sample.set_gt_label(result['gt_label'])
        data_sample.set_pred_label(result['pred_label'])
        data_sample.set_pred_score(result['pred_score'])
        results.append(data_sample)

    # Least confident predictions come first.
    results.sort(key=lambda sample: torch.max(sample.pred_score))

    success, fail = [], []
    for data_sample in results:
        correct = (data_sample.pred_label == data_sample.gt_label).all()
        (success if correct else fail).append(data_sample)

    save_imgs(args.out_dir, 'success', success[:args.topk], dataset,
              args.rescale_factor)
    save_imgs(args.out_dir, 'fail', fail[:args.topk], dataset,
              args.rescale_factor)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile
import mmengine
from mmengine.config import Config, DictAction
from mmengine.evaluator import Evaluator
from mmengine.runner import Runner
from mmpretrain.evaluation import ConfusionMatrix
from mmpretrain.registry import DATASETS
from mmpretrain.utils import register_all_modules
def parse_args():
    """Parse command-line arguments for the confusion-matrix tool."""
    parser = argparse.ArgumentParser(
        description='Eval a checkpoint and draw the confusion matrix.')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        'ckpt_or_result',
        type=str,
        # FIX: corrected help-text typo "dumpped" -> "dumped".
        help='The checkpoint file (.pth) or '
        'dumped predictions pickle file (.pkl).')
    parser.add_argument('--out', help='the file to save the confusion matrix.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether to display the metric result by matplotlib if supports.')
    parser.add_argument(
        '--show-path', type=str, help='Path to save the visualization image.')
    parser.add_argument(
        '--include-values',
        action='store_true',
        help='To draw the values in the figure.')
    parser.add_argument(
        '--cmap',
        type=str,
        default='viridis',
        help='The color map to use. Defaults to "viridis".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args
def main():
    """Compute the confusion matrix from a checkpoint (full test run) or a
    dumped predictions file (offline evaluation), then save and/or plot it."""
    args = parse_args()

    # register all modules in mmpretrain into the registries
    # do not init the default scope here because it will be init in the runner
    register_all_modules(init_default_scope=False)

    # load config
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if args.ckpt_or_result.endswith('.pth'):
        # Checkpoint input: run a full test loop with the confusion matrix
        # as the only metric, in a throwaway work dir.
        # Set confusion matrix as the metric.
        cfg.test_evaluator = dict(type='ConfusionMatrix')
        cfg.load_from = str(args.ckpt_or_result)
        with tempfile.TemporaryDirectory() as tmpdir:
            cfg.work_dir = tmpdir
            runner = Runner.from_cfg(cfg)
            classes = runner.test_loop.dataloader.dataset.metainfo.get(
                'classes')
            cm = runner.test()['confusion_matrix/result']
    else:
        # Pre-dumped predictions: evaluate offline without a runner.
        predictions = mmengine.load(args.ckpt_or_result)
        evaluator = Evaluator(ConfusionMatrix())
        metrics = evaluator.offline_evaluate(predictions, None)
        cm = metrics['confusion_matrix/result']
        try:
            # Try to build the dataset.
            dataset = DATASETS.build({
                **cfg.test_dataloader.dataset, 'pipeline': []
            })
            classes = dataset.metainfo.get('classes')
        except Exception:
            # Class names are optional for plotting; fall back to indices.
            classes = None

    if args.out is not None:
        mmengine.dump(cm, args.out)

    if args.show or args.show_path is not None:
        fig = ConfusionMatrix.plot(
            cm,
            show=args.show,
            classes=classes,
            include_values=args.include_values,
            cmap=args.cmap)
        if args.show_path is not None:
            fig.savefig(args.show_path)
            print(f'The confusion matrix is saved at {args.show_path}.')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import mmengine
import rich
from mmengine import DictAction
from mmengine.evaluator import Evaluator
from mmpretrain.registry import METRICS
HELP_URL = (
'https://mmpretrain.readthedocs.io/en/latest/useful_tools/'
'log_result_analysis.html#how-to-conduct-offline-metric-evaluation')
prog_description = f"""\
Evaluate metric of the results saved in pkl format.
The detailed usage can be found in {HELP_URL}
"""
def parse_args():
    """Parse command-line arguments for offline metric evaluation."""
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('pkl_results', help='Results in pickle format')
    # Each `--metric` occurrence appends one list of key=value items.
    parser.add_argument(
        '--metric',
        nargs='+',
        action='append',
        dest='metric_options',
        help='The metric config, the key-value pair in xxx=yyy format will be '
        'parsed as the metric config items. You can specify multiple metrics '
        'by use multiple `--metric`. For list type value, you can use '
        '"key=[a,b]" or "key=a,b", and it also allows nested list/tuple '
        'values, e.g. "key=[(a,b),(c,d)]".')
    return parser.parse_args()
def main():
    """Build metrics from `--metric` options and evaluate dumped predictions."""
    args = parse_args()

    if args.metric_options is None:
        # FIX: corrected error-message typo "speicfy" -> "specify".
        raise ValueError('Please specify at least one `--metric`. '
                         f'The detailed usage can be found in {HELP_URL}')

    # Build one metric instance per `--metric` group of key=value items.
    test_metrics = []
    for metric_option in args.metric_options:
        metric_cfg = {}
        for kv in metric_option:
            k, v = kv.split('=', maxsplit=1)
            # NOTE(review): relies on mmengine's private value parser;
            # re-check when upgrading mmengine.
            metric_cfg[k] = DictAction._parse_iterable(v)
        test_metrics.append(METRICS.build(metric_cfg))

    predictions = mmengine.load(args.pkl_results)
    evaluator = Evaluator(test_metrics)
    eval_results = evaluator.offline_evaluate(predictions, None)
    rich.print(eval_results)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmengine.analysis import get_model_complexity_info
from mmpretrain import get_model
def parse_args():
    """Parse command-line arguments for the FLOPs/params tool."""
    parser = argparse.ArgumentParser(description='Get model flops and params')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--shape', type=int, nargs='+', default=[224, 224],
        help='input image size')
    return parser.parse_args()
def main():
    """Compute and print FLOPs/params of the model defined by the config."""
    args = parse_args()

    # --shape accepts one value (square input) or two (H W); channels fixed
    # at 3.
    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    model = get_model(args.config)
    model.eval()

    # Only the backbone forward (extract_feat) is profiled.
    if hasattr(model, 'extract_feat'):
        model.forward = model.extract_feat
    else:
        # FIX: removed the duplicated "currently" in the error message.
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    analysis_results = get_model_complexity_info(
        model,
        input_shape,
    )
    flops = analysis_results['flops_str']
    params = analysis_results['params_str']
    activations = analysis_results['activations_str']
    print(analysis_results['out_arch'])
    print(analysis_results['out_table'])
    split_line = '=' * 30
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n'
          f'Activation: {activations}\n{split_line}')
    print('!!!Only the backbone network is counted in FLOPs analysis.')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
import argparse
import os
import os.path as osp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mmengine.logging import MMLogger
from utils import FormatStrFormatter, ShapeBias
# global default boundary settings for thin gray transparent
# boundaries to avoid not being able to see the difference
# between two partially overlapping datapoints of the same color:
PLOTTING_EDGE_COLOR = (0.3, 0.3, 0.3, 0.3)  # RGBA edge color for markers
PLOTTING_EDGE_WIDTH = 0.02

# Directory holding the category icons drawn beside the y axis; they are
# downloaded on demand in the __main__ block below.
ICONS_DIR = osp.join(
    osp.dirname(__file__), '..', '..', 'resources', 'shape_bias_icons')

# Command-line interface of this plotting script.
parser = argparse.ArgumentParser()
parser.add_argument('--csv-dir', type=str, help='directory of csv files')
parser.add_argument(
    '--result-dir', type=str, help='directory to save plotting results')
parser.add_argument('--model-names', nargs='+', default=[], help='model name')
parser.add_argument(
    '--colors',
    nargs='+',
    type=float,
    default=[],
    help= # noqa
    'the colors for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
)
parser.add_argument(
    '--markers',
    nargs='+',
    type=str,
    default=[],
    help= # noqa
    'the markers for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
)
parser.add_argument(
    '--plotting-names',
    nargs='+',
    default=[],
    help= # noqa
    'the plotting names for the plots of each model, and they should be in the same order as model_names'  # noqa: E501
)
parser.add_argument(
    '--delete-icons',
    action='store_true',
    help='whether to delete the icons after plotting')

# Human observer ids used as the reference group in the csv data.
humans = [
    'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
    'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
]

# Icon files expected in ICONS_DIR.
icon_names = [
    'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
    'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png',
    'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
    'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
    'knife.png', 'response_icons_vertical.png', 'truck.png'
]
def read_csvs(csv_dir: str) -> pd.DataFrame:
    """Read all csv files in a directory and return a single dataframe.

    Column names are lower-cased and the ``condition`` column is cast
    to ``str``.

    Args:
        csv_dir (str): directory of csv files.

    Returns:
        pd.DataFrame: dataframe containing all csv files
    """
    frames = []
    for csv in os.listdir(csv_dir):
        if csv.endswith('.csv'):
            cur_df = pd.read_csv(osp.join(csv_dir, csv))
            cur_df.columns = [c.lower() for c in cur_df.columns]
            frames.append(cur_df)
    # FIX: ``DataFrame.append`` was removed in pandas 2.0; collect the
    # frames and concatenate once (also avoids quadratic re-copying).
    df = pd.concat(frames) if frames else pd.DataFrame()
    df.condition = df.condition.astype(str)
    return df
def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
    """Plots a matrixplot of shape bias.

    Args:
        args (argparse.Namespace): arguments.
        analysis (ShapeBias): shape bias analysis. Defaults to ShapeBias().
    """
    # NOTE(review): the ShapeBias() default is evaluated once at definition
    # time and shared across calls; harmless only while ShapeBias is
    # stateless — confirm before adding state to ShapeBias.
    mpl.rcParams['font.family'] = ['serif']
    mpl.rcParams['font.serif'] = ['Times New Roman']

    plt.figure(figsize=(9, 7))
    df = read_csvs(args.csv_dir)

    fontsize = 15
    ticklength = 10
    markersize = 250
    label_size = 20

    classes = df['category'].unique()
    num_classes = len(classes)

    # plot setup
    fig = plt.figure(1, figsize=(12, 12), dpi=300.)
    ax = plt.gca()

    ax.set_xlim([0, 1])
    ax.set_ylim([-.5, num_classes - 0.5])

    # secondary reversed x axis
    ax_top = ax.secondary_xaxis(
        'top', functions=(lambda x: 1 - x, lambda x: 1 - x))

    # labels, ticks
    plt.tick_params(
        axis='y', which='both', left=False, right=False, labelleft=False)
    ax.set_ylabel('Shape categories', labelpad=60, fontsize=label_size)
    ax.set_xlabel(
        "Fraction of 'texture' decisions", fontsize=label_size, labelpad=25)
    ax_top.set_xlabel(
        "Fraction of 'shape' decisions", fontsize=label_size, labelpad=25)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax_top.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.get_xaxis().set_ticks(np.arange(0, 1.1, 0.1))
    ax_top.set_ticks(np.arange(0, 1.1, 0.1))
    ax.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)
    ax_top.tick_params(
        axis='both', which='major', labelsize=fontsize, length=ticklength)

    # arrows on x axes
    plt.arrow(
        x=0,
        y=-1.75,
        dx=1,
        dy=0,
        fc='black',
        head_width=0.4,
        head_length=0.03,
        clip_on=False,
        length_includes_head=True,
        overhang=0.5)
    plt.arrow(
        x=1,
        y=num_classes + 0.75,
        dx=-1,
        dy=0,
        fc='black',
        head_width=0.4,
        head_length=0.03,
        clip_on=False,
        length_includes_head=True,
        overhang=0.5)

    # icons besides y axis
    # determine order of icons
    # Categories are ordered by the human observers' average texture-decision
    # fraction (1 - shape-bias), ascending.
    df_selection = df.loc[(df['subj'].isin(humans))]
    class_avgs = []
    for cl in classes:
        df_class_selection = df_selection.query("category == '{}'".format(cl))
        class_avgs.append(1 - analysis.analysis(
            df=df_class_selection)['shape-bias'])
    sorted_indices = np.argsort(class_avgs)
    classes = classes[sorted_indices]

    # icon placement is calculated in axis coordinates
    WIDTH = 1 / num_classes
    # placement left of yaxis (-WIDTH) plus some spacing (-.25*WIDTH)
    XPOS = -1.25 * WIDTH
    YPOS = -0.5
    HEIGHT = 1
    MARGINX = 1 / 10 * WIDTH  # vertical whitespace between icons
    MARGINY = 1 / 10 * HEIGHT  # horizontal whitespace between icons
    left = XPOS + MARGINX
    right = XPOS + WIDTH - MARGINX
    for i in range(num_classes):
        bottom = i + MARGINY + YPOS
        top = (i + 1) - MARGINY + YPOS
        iconpath = osp.join(ICONS_DIR, '{}.png'.format(classes[i]))
        plt.imshow(
            plt.imread(iconpath),
            extent=[left, right, bottom, top],
            aspect='auto',
            clip_on=False)

    # plot horizontal intersection lines
    for i in range(num_classes - 1):
        plt.plot([0, 1], [i + .5, i + .5],
                 c='gray',
                 linestyle='dotted',
                 alpha=0.4)

    # plot average shapebias + scatter points
    for i in range(len(args.model_names)):
        df_selection = df.loc[(df['subj'].isin(args.model_names[i]))]
        result_df = analysis.analysis(df=df_selection)
        # Vertical line at the model's overall texture-decision fraction.
        avg = 1 - result_df['shape-bias']
        ax.plot([avg, avg], [-1, num_classes], color=args.colors[i])
        class_avgs = []
        for cl in classes:
            df_class_selection = df_selection.query(
                "category == '{}'".format(cl))
            class_avgs.append(1 - analysis.analysis(
                df=df_class_selection)['shape-bias'])
        ax.scatter(
            class_avgs,
            classes,
            color=args.colors[i],
            marker=args.markers[i],
            label=args.plotting_names[i],
            s=markersize,
            clip_on=False,
            edgecolors=PLOTTING_EDGE_COLOR,
            linewidths=PLOTTING_EDGE_WIDTH,
            zorder=3)

    plt.legend(frameon=True, labelspacing=1, loc=9)

    figure_path = osp.join(args.result_dir,
                           'cue-conflict_shape-bias_matrixplot.pdf')
    fig.savefig(figure_path, bbox_inches='tight')
    plt.close()
def check_icons() -> bool:
    """Check if icons are present, if not download them."""
    if not osp.exists(ICONS_DIR):
        return False
    # True only when every expected icon file exists.
    return all(
        osp.exists(osp.join(ICONS_DIR, name)) for name in icon_names)
if __name__ == '__main__':
    # Download the category icons on first use.
    if not check_icons():
        root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons'  # noqa: E501
        os.makedirs(ICONS_DIR, exist_ok=True)
        MMLogger.get_current_instance().info(
            f'Downloading icons to {ICONS_DIR}')
        for icon_name in icon_names:
            url = osp.join(root_url, icon_name)
            # NOTE(review): shells out to wget with unquoted paths; fine for
            # these fixed names, but consider urllib.request for portability.
            os.system('wget -O {} {}'.format(
                osp.join(ICONS_DIR, icon_name), url))

    args = parser.parse_args()
    assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
must be 3 times the number of models. Every three colors are the RGB \
values for one model.'

    # preprocess colors
    # Scale 0-255 RGB values to [0, 1] and group them into one triple per
    # model; a fixed color is appended for the human reference group.
    args.colors = [c / 255. for c in args.colors]
    colors = []
    for i in range(len(args.model_names)):
        colors.append(args.colors[3 * i:3 * i + 3])
    args.colors = colors
    args.colors.append([165 / 255., 30 / 255., 55 / 255.])  # human color

    # if plotting names are not specified, use model names
    if len(args.plotting_names) == 0:
        args.plotting_names = args.model_names

    # preprocess markers
    args.markers.append('D')  # human marker

    # preprocess model names
    # Each model becomes a single-element list of subject ids; the human
    # subject ids are appended as one extra group.
    args.model_names = [[m] for m in args.model_names]
    args.model_names.append(humans)

    # preprocess plotting names
    args.plotting_names.append('Humans')

    plot_shape_bias_matrixplot(args)
    if args.delete_icons:
        os.system('rm -rf {}'.format(ICONS_DIR))
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
from typing import Any, Dict, List, Optional
import matplotlib as mpl
import pandas as pd
from matplotlib import _api
from matplotlib import transforms as mtransforms
class _DummyAxis:
    """Define the minimal interface for a dummy axis.

    Args:
        minpos (float): The minimum positive value for the axis. Defaults to 0.
    """
    __name__ = 'dummy'

    # Once the deprecation elapses, replace dataLim and viewLim by plain
    # _view_interval and _data_interval private tuples.
    dataLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_data_interval() and set_data_interval()')
    viewLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_view_interval() and set_view_interval()')

    def __init__(self, minpos: float = 0) -> None:
        # Unit bounding boxes stand in for real data/view limits.
        self._dataLim = mtransforms.Bbox.unit()
        self._viewLim = mtransforms.Bbox.unit()
        self._minpos = minpos

    def get_view_interval(self) -> Dict:
        """Return the view interval as a tuple (*vmin*, *vmax*)."""
        # NOTE(review): actually returns the Bbox's intervalx (array-like),
        # not a Dict as annotated — annotation looks wrong; confirm callers.
        return self._viewLim.intervalx

    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self._viewLim.intervalx = vmin, vmax

    def get_minpos(self) -> float:
        """Return the minimum positive value for the axis."""
        return self._minpos

    def get_data_interval(self) -> Dict:
        """Return the data interval as a tuple (*vmin*, *vmax*)."""
        # NOTE(review): returns intervalx, not a Dict; see get_view_interval.
        return self._dataLim.intervalx

    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self._dataLim.intervalx = vmin, vmax

    def get_tick_space(self) -> int:
        """Return the number of ticks to use."""
        # Just use the long-standing default of nbins==9
        return 9
class TickHelper:
    """A helper class for ticks and tick labels."""
    # The axis instance this helper operates on (None until set_axis or
    # create_dummy_axis is called).
    axis = None

    def set_axis(self, axis: Any) -> None:
        """Set the axis instance."""
        self.axis = axis

    def create_dummy_axis(self, **kwargs) -> None:
        """Create a dummy axis if no axis is set."""
        if self.axis is None:
            self.axis = _DummyAxis(**kwargs)

    @_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self.axis.set_view_interval(vmin, vmax)

    @_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self.axis.set_data_interval(vmin, vmax)

    @_api.deprecated(
        '3.5',
        alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
    def set_bounds(self, vmin: float, vmax: float) -> None:
        """Set the view and data interval to (*vmin*, *vmax*)."""
        self.set_view_interval(vmin, vmax)
        self.set_data_interval(vmin, vmax)
class Formatter(TickHelper):
    """Create a string based on a tick value and location."""
    # some classes want to see all the locs to help format
    # individual ones
    locs = []

    def __call__(self, x: str, pos: Optional[Any] = None) -> str:
        """Return the format for tick value *x* at position pos.

        ``pos=None`` indicates an unspecified location. This method must be
        overridden in the derived class.

        Args:
            x (str): The tick value.
            pos (Optional[Any]): The tick position. Defaults to None.
        """
        raise NotImplementedError('Derived must override')

    def format_ticks(self, values: pd.Series) -> List[str]:
        """Return the tick labels for all the ticks at once.

        Args:
            values (pd.Series): The tick values.

        Returns:
            List[str]: The tick labels.
        """
        self.set_locs(values)
        return [self(val, idx) for idx, val in enumerate(values)]

    def format_data(self, value: Any) -> str:
        """Return the full string representation of *value* with the
        position unspecified.

        Args:
            value (Any): The tick value.

        Returns:
            str: The full string representation of the value.
        """
        return self.__call__(value)

    def format_data_short(self, value: Any) -> str:
        """Return a short string version of the tick value.

        Defaults to the position-independent long value.

        Args:
            value (Any): The tick value.

        Returns:
            str: The short string representation of the value.
        """
        return self.format_data(value)

    def get_offset(self) -> str:
        """Return the offset string."""
        return ''

    def set_locs(self, locs: List[Any]) -> None:
        """Remember all tick locations before computing the labels.

        Some formatters need to know every tick location to format any
        single one.
        """
        self.locs = locs

    @staticmethod
    def fix_minus(s: str) -> str:
        """Replace hyphens with the proper Unicode minus (U+2212) when
        enabled via :rc:`axes.unicode_minus`.

        Args:
            s (str): The string to replace the hyphen with the Unicode symbol.
        """
        if not mpl.rcParams['axes.unicode_minus']:
            return s
        return s.replace('-', '\N{MINUS SIGN}')

    def _set_locator(self, locator: Any) -> None:
        """Subclasses may want to override this to set a locator."""
        pass
class FormatStrFormatter(Formatter):
    """Use an old-style ('%' operator) format string to format the tick.

    The format string should have a single variable format (%) in it.
    It will be applied to the value (not the position) of the tick.

    Negative numeric values will use a dash, not a Unicode minus; use mathtext
    to get a Unicode minus by wrapping the format specifier with $ (e.g.
    "$%g$").

    Args:
        fmt (str): Format string.
    """

    def __init__(self, fmt: str) -> None:
        self.fmt = fmt

    def __call__(self, x: str, pos: Optional[Any]) -> str:
        """Return the formatted label string.

        Only the value *x* is formatted. The position is ignored.

        Args:
            x (str): The value to format.
            pos (Any): The position of the tick. Ignored.
        """
        label = self.fmt % x
        return label
class ShapeBias:
    """Compute the shape bias of a model.

    Reference: `ImageNet-trained CNNs are biased towards texture;
    increasing shape bias improves accuracy and robustness
    <https://arxiv.org/abs/1811.12231>`_.
    """
    # Number of models whose responses are analyzed together.
    num_input_models = 1

    def __init__(self) -> None:
        super().__init__()
        self.plotting_name = 'shape-bias'

    @staticmethod
    def _check_dataframe(df: pd.DataFrame) -> None:
        """Check that the dataframe is valid."""
        assert len(df) > 0, 'empty dataframe'

    def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
        """Compute the shape bias of a model.

        Args:
            df (pd.DataFrame): The dataframe containing the data. Must have
                'imagename', 'category' and 'object_response' columns.

        Returns:
            Dict[str, float]: Fractions of correct shape/texture decisions
            and the resulting shape bias.
        """
        self._check_dataframe(df)
        df = df.copy()
        df['correct_texture'] = df['imagename'].apply(
            self.get_texture_category)
        df['correct_shape'] = df['category']

        # remove those rows where shape = texture, i.e. no cue conflict present
        df2 = df.loc[df.correct_shape != df.correct_texture]
        # NOTE(review): fractions are normalized by len(df) (all rows), not
        # len(df2) (cue-conflict rows only) — matches upstream behavior; the
        # shape-bias ratio below is unaffected since the denominator cancels.
        fraction_correct_shape = len(
            df2.loc[df2.object_response == df2.correct_shape]) / len(df)
        fraction_correct_texture = len(
            df2.loc[df2.object_response == df2.correct_texture]) / len(df)
        shape_bias = fraction_correct_shape / (
            fraction_correct_shape + fraction_correct_texture)

        result_dict = {
            'fraction-correct-shape': fraction_correct_shape,
            'fraction-correct-texture': fraction_correct_texture,
            'shape-bias': shape_bias
        }
        return result_dict

    def get_texture_category(self, imagename: str) -> str:
        """Return texture category from imagename.

        e.g. 'XXX_dog10-bird2.png' -> 'bird'

        Args:
            imagename (str): Name of the image.

        Returns:
            str: Texture category.
        """
        # Idiomatic type check (was `type(imagename) is str`).
        assert isinstance(imagename, str)
        # remove unnecessary words
        a = imagename.split('_')[-1]
        # remove .png etc.
        b = a.split('.')[0]
        # get texture category (last word)
        c = b.split('-')[-1]
        # remove number, e.g. 'bird2' -> 'bird'
        d = ''.join([i for i in c if not i.isdigit()])
        return d
#!/usr/bin/env bash
# Distributed test of an mmdet model with the PyTorch launcher.
# Usage: bash <this script> CONFIG CHECKPOINT GPUS [extra mim args...]
set -x

CFG=$1
CHECKPOINT=$2
GPUS=$3
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim test mmdet \
    "$CFG" \
    --checkpoint "$CHECKPOINT" \
    --launcher pytorch \
    -G "$GPUS" \
    "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Distributed training of an mmdet model (backbone + shared RoI head)
# initialized from a pretrained checkpoint, PyTorch launcher.
# Usage: bash <this script> CONFIG PRETRAIN GPUS [extra mim args...]
set -x

CFG=$1
PRETRAIN=$2  # pretrained model
GPUS=$3
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim train mmdet "$CFG" \
    --launcher pytorch -G "$GPUS" \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint="$PRETRAIN" \
    model.backbone.init_cfg.prefix="backbone." \
    model.roi_head.shared_head.init_cfg.type=Pretrained \
    model.roi_head.shared_head.init_cfg.checkpoint="$PRETRAIN" \
    model.roi_head.shared_head.init_cfg.prefix="backbone." \
    "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Distributed training of an mmdet model initialized from a pretrained
# backbone checkpoint, PyTorch launcher.
# Usage: bash <this script> CONFIG PRETRAIN GPUS [extra mim args...]
set -x

CFG=$1
PRETRAIN=$2  # pretrained model
GPUS=$3
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim train mmdet "$CFG" \
    --launcher pytorch -G "$GPUS" \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint="$PRETRAIN" \
    model.backbone.init_cfg.prefix="backbone." \
    "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Test an mmdet model on a slurm cluster.
# Usage: bash <this script> PARTITION CONFIG CHECKPOINT [extra mim args...]
# Env overrides: GPUS (8), GPUS_PER_NODE (8), CPUS_PER_TASK (5), SRUN_ARGS.
set -x

PARTITION=$1
CFG=$2
CHECKPOINT=$3
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim test mmdet \
    "$CFG" \
    --checkpoint "$CHECKPOINT" \
    --launcher slurm -G "$GPUS" \
    --gpus-per-node "$GPUS_PER_NODE" \
    --cpus-per-task "$CPUS_PER_TASK" \
    --partition "$PARTITION" \
    --srun-args "$SRUN_ARGS" \
    "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Train an mmdet model (backbone + shared RoI head) initialized from a
# pretrained checkpoint on a slurm cluster.
# Usage: bash <this script> PARTITION CONFIG PRETRAIN [extra mim args...]
# Env overrides: GPUS (8), GPUS_PER_NODE (8), CPUS_PER_TASK (5), SRUN_ARGS.
set -x

PARTITION=$1
CFG=$2
PRETRAIN=$3  # pretrained model
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim train mmdet "$CFG" \
    --launcher slurm -G "$GPUS" \
    --gpus-per-node "$GPUS_PER_NODE" \
    --cpus-per-task "$CPUS_PER_TASK" \
    --partition "$PARTITION" \
    --srun-args "$SRUN_ARGS" \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint="$PRETRAIN" \
    model.backbone.init_cfg.prefix="backbone." \
    model.roi_head.shared_head.init_cfg.type=Pretrained \
    model.roi_head.shared_head.init_cfg.checkpoint="$PRETRAIN" \
    model.roi_head.shared_head.init_cfg.prefix="backbone." \
    "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Train an mmdet model initialized from a pretrained backbone checkpoint
# on a slurm cluster.
# Usage: bash <this script> PARTITION CONFIG PRETRAIN [extra mim args...]
# Env overrides: GPUS (8), GPUS_PER_NODE (8), CPUS_PER_TASK (5), SRUN_ARGS.
set -x

PARTITION=$1
CFG=$2
PRETRAIN=$3  # pretrained model
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim train mmdet "$CFG" \
    --launcher slurm -G "$GPUS" \
    --gpus-per-node "$GPUS_PER_NODE" \
    --cpus-per-task "$CPUS_PER_TASK" \
    --partition "$PARTITION" \
    --srun-args "$SRUN_ARGS" \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint="$PRETRAIN" \
    model.backbone.init_cfg.prefix="backbone." \
    "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Distributed test of an mmseg model with the PyTorch launcher.
# Usage: bash <this script> CONFIG CHECKPOINT GPUS [extra mim args...]
set -x

CFG=$1
CHECKPOINT=$2
GPUS=$3
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim test mmseg \
    "$CFG" \
    --checkpoint "$CHECKPOINT" \
    --launcher pytorch \
    -G "$GPUS" \
    "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Distributed training of an mmseg model initialized from a pretrained
# backbone checkpoint, PyTorch launcher.
# Usage: bash <this script> CONFIG PRETRAIN GPUS [extra mim args...]
set -x

CFG=$1
PRETRAIN=$2  # pretrained model
GPUS=$3
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim train mmseg "$CFG" \
    --launcher pytorch -G "$GPUS" \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint="$PRETRAIN" \
    model.backbone.init_cfg.prefix="backbone." \
    model.pretrained=None \
    "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Test an mmseg model on a slurm cluster.
# Usage: bash <this script> PARTITION CONFIG CHECKPOINT [extra mim args...]
# Env overrides: GPUS (4), GPUS_PER_NODE (4), CPUS_PER_TASK (5), SRUN_ARGS.
set -x

PARTITION=$1
CFG=$2
CHECKPOINT=$3
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim test mmseg \
    "$CFG" \
    --checkpoint "$CHECKPOINT" \
    --launcher slurm -G "$GPUS" \
    --gpus-per-node "$GPUS_PER_NODE" \
    --cpus-per-task "$CPUS_PER_TASK" \
    --partition "$PARTITION" \
    --srun-args "$SRUN_ARGS" \
    "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Train an mmseg model initialized from a pretrained backbone checkpoint
# on a slurm cluster.
# Usage: bash <this script> PARTITION CONFIG PRETRAIN [extra mim args...]
# Env overrides: GPUS (4), GPUS_PER_NODE (4), CPUS_PER_TASK (5), SRUN_ARGS.
set -x

PARTITION=$1
CFG=$2
PRETRAIN=$3  # pretrained model
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Keep remaining CLI args as an array so quoted arguments survive intact.
PY_ARGS=("${@:4}")

# Prepend the repo root so local modules are importable by mim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
mim train mmseg "$CFG" \
    --launcher slurm -G "$GPUS" \
    --gpus-per-node "$GPUS_PER_NODE" \
    --cpus-per-task "$CPUS_PER_TASK" \
    --partition "$PARTITION" \
    --srun-args "$SRUN_ARGS" \
    --cfg-options model.backbone.init_cfg.type=Pretrained \
    model.backbone.init_cfg.checkpoint="$PRETRAIN" \
    model.backbone.init_cfg.prefix="backbone." \
    model.pretrained=None \
    "${PY_ARGS[@]}"
# Copyright (c) OpenMMLab. All rights reserved.
"""Create COCO-Style GT annotations based on raw annotation of Flickr30k.
GT annotations are used for evaluation in image caption task.
"""
import json
def main(ann_file='dataset_flickr30k.json', out_prefix='flickr30k'):
    """Create COCO-style GT annotation files from raw Flickr30k data.

    One ``{out_prefix}_{split}_gt.json`` file is written per split
    (train/val/test), each containing ONLY the images and captions of
    that split. The previous implementation accumulated the lists across
    splits, so the val file contained train+val and the test file
    contained everything — fixed by resetting the accumulators per split.

    Args:
        ann_file (str): Path to the raw ``dataset_flickr30k.json`` file.
        out_prefix (str): Prefix of the generated GT json files.
    """
    with open(ann_file, 'r') as f:
        annotations = json.load(f)

    # Each raw image record looks like:
    # {
    #     "sentids": [0, 1, 2],
    #     "imgid": 0,
    #     "sentences": [
    #         {"raw": "Two men in green shirts standing in a yard.",
    #          "imgid": 0, "sentid": 0},
    #         ...
    #     ],
    #     "split": "train",
    #     "filename": "1000092795.jpg"
    # }
    for split in ['train', 'val', 'test']:
        # Reset per split so each GT file holds only its own split.
        ann_list = []
        img_list = []
        for img in annotations['images']:
            if img['split'] != split:
                continue
            img_list.append({'id': img['imgid']})
            for sentence in img['sentences']:
                ann_list.append({
                    'image_id': img['imgid'],
                    'id': sentence['sentid'],
                    'caption': sentence['raw'],
                })

        # generate {out_prefix}_{train,val,test}_gt.json
        with open(f'{out_prefix}_{split}_gt.json', 'w') as f:
            json.dump({'annotations': ann_list, 'images': img_list}, f)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
"""SimCLR provides list files for semi-supervised benchmarks
https://github.com/google-research/simclr/tree/master/imagenet_subsets/"""
import argparse
def parse_args(args=None):
    """Parse command-line arguments.

    Args:
        args (list[str] | None): Argument strings to parse. Defaults to
            ``None``, which reads ``sys.argv`` (backward compatible).

    Returns:
        argparse.Namespace: Parsed arguments with ``input`` and
        ``output`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Convert ImageNet subset lists provided by SimCLR into '
        'the required format in MMPretrain.')
    parser.add_argument(
        'input', help='Input list file, downloaded from SimCLR github repo.')
    parser.add_argument(
        'output', help='Output list file with the required format.')
    return parser.parse_args(args)
def main():
    """Convert a SimCLR ImageNet subset list into MMPretrain format.

    Builds a class-key -> label mapping from the full ImageNet train
    list, then rewrites every bare file name of the input subset list as
    a ``<key>/<filename> <label>`` line.
    """
    args = parse_args()

    # Build the class-key -> label mapping from the full annotation file.
    with open('data/imagenet/meta/train.txt', 'r') as f:
        lines = f.readlines()
    keys = [line.split('/')[0] for line in lines]
    labels = [line.strip().split()[1] for line in lines]
    mapping = {}
    for key, label in zip(keys, labels):
        if key not in mapping:
            mapping[key] = label
        else:
            # The same class key must always map to the same label.
            assert mapping[key] == label

    # Convert the subset list. Each line is a bare file name whose class
    # key is the part before the first underscore; derive keys from the
    # stripped names so a trailing newline can never leak into the key.
    with open(args.input, 'r') as f:
        fns = [line.strip() for line in f.readlines()]
    sample_keys = [fn.split('_')[0] for fn in fns]
    sample_labels = [mapping[k] for k in sample_keys]
    output_lines = [
        f'{key}/{fn} {label}\n'
        for key, fn, label in zip(sample_keys, fns, sample_labels)
    ]

    # 'w' is sufficient: the file is only written, never read back.
    with open(args.output, 'w') as f:
        f.writelines(output_lines)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import mmcv
def parse_args(args=None):
    """Parse command-line arguments.

    Args:
        args (list[str] | None): Argument strings to parse. Defaults to
            ``None``, which reads ``sys.argv`` (backward compatible).

    Returns:
        argparse.Namespace: Parsed arguments with ``input`` and
        ``output`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Convert iNaturalist2018 annotations to MMPretrain format.'
    )
    parser.add_argument('input', type=str, help='Input annotation json file.')
    parser.add_argument('output', type=str, help='Output list file.')
    return parser.parse_args(args)
def main():
    """Convert iNaturalist2018 json annotations to a plain list file.

    Writes one ``<file_name> <category_id>`` line per image to the
    output file.
    """
    args = parse_args()
    data = mmcv.load(args.input)

    # Index annotations by image id once, instead of re-scanning the
    # whole annotation list for every image (was O(images * annotations)).
    img_id_to_cat = {
        ann_item['image_id']: ann_item['category_id']
        for ann_item in data['annotations']
    }

    output_lines = [
        f"{img_item['file_name']} {img_id_to_cat[img_item['id']]}\n"
        for img_item in data['images']
        if img_item['id'] in img_id_to_cat
    ]
    # Every image is expected to have exactly one annotation.
    assert len(output_lines) == len(data['images'])

    with open(args.output, 'w') as f:
        f.writelines(output_lines)


if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment