Commit 1baf0566 authored by limm

add tests part

parent 495d9ed9
Pipeline #2800 canceled

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmpretrain.structures import (batch_label_to_onehot, cat_batch_labels,
                                   tensor_split)


class TestStructureUtils(TestCase):

    def test_tensor_split(self):
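        # tensor_split cuts the tensor at each split index (similar to
        # torch.tensor_split), so n indices yield n + 1 chunks; repeated
        # or boundary indices produce empty chunks.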
        tensor = torch.tensor([0, 1, 2, 3, 4, 5, 6])
        split_indices = [0, 2, 6, 6]
        outs = tensor_split(tensor, split_indices)
        self.assertEqual(len(outs), len(split_indices) + 1)
        self.assertEqual(outs[0].size(0), 0)
        self.assertEqual(outs[1].size(0), 2)
        self.assertEqual(outs[2].size(0), 4)
        self.assertEqual(outs[3].size(0), 0)
        self.assertEqual(outs[4].size(0), 1)

        tensor = torch.tensor([])
        split_indices = [0, 0, 0, 0]
        outs = tensor_split(tensor, split_indices)
        self.assertEqual(len(outs), len(split_indices) + 1)

    def test_cat_batch_labels(self):
        labels = [
            torch.tensor([1]),
            torch.tensor([3, 2]),
            torch.tensor([0, 1, 4]),
            torch.tensor([], dtype=torch.int64),
            torch.tensor([], dtype=torch.int64),
        ]
        batch_label, split_indices = cat_batch_labels(labels)
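        # The labels are flattened into a single tensor and the boundaries
        # are recorded, so tensor_split can recover the per-sample grouping.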
        self.assertEqual(split_indices, [1, 3, 6, 6])
        self.assertEqual(len(batch_label), 6)

        labels = tensor_split(batch_label, split_indices)
        self.assertEqual(labels[0].tolist(), [1])
        self.assertEqual(labels[1].tolist(), [3, 2])
        self.assertEqual(labels[2].tolist(), [0, 1, 4])
        self.assertEqual(labels[3].tolist(), [])
        self.assertEqual(labels[4].tolist(), [])

    def test_batch_label_to_onehot(self):
        labels = [
            torch.tensor([1]),
            torch.tensor([3, 2]),
            torch.tensor([0, 1, 4]),
            torch.tensor([], dtype=torch.int64),
            torch.tensor([], dtype=torch.int64),
        ]
        batch_label, split_indices = cat_batch_labels(labels)
        batch_score = batch_label_to_onehot(
            batch_label, split_indices, num_classes=5)
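        # Each sample becomes a multi-hot row over num_classes entries;
        # samples without labels map to all-zero rows.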
        self.assertEqual(batch_score[0].tolist(), [0, 1, 0, 0, 0])
        self.assertEqual(batch_score[1].tolist(), [0, 0, 1, 1, 0])
        self.assertEqual(batch_score[2].tolist(), [1, 1, 0, 0, 1])
        self.assertEqual(batch_score[3].tolist(), [0, 0, 0, 0, 0])
        self.assertEqual(batch_score[4].tolist(), [0, 0, 0, 0, 0])


# Copyright (c) OpenMMLab. All rights reserved.
import re
import tempfile
from collections import OrderedDict
from pathlib import Path
from subprocess import PIPE, Popen
from unittest import TestCase

import mmengine
import torch
from mmengine.config import Config

from mmpretrain import ModelHub, get_model
from mmpretrain.structures import DataSample

MMPRE_ROOT = Path(__file__).parent.parent
ASSETS_ROOT = Path(__file__).parent / 'data'
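
# These tests run the repository's command-line tools as subprocesses
# from the repo root and assert on their console output and the files
# they generate.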


class TestImageDemo(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.dir = Path(self.tmpdir.name)

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'demo/image_demo.py',
            'demo/demo.JPEG',
            'mobilevit-xxsmall_3rdparty_in1k',
            '--device',
            'cpu',
        ]
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        out, _ = p.communicate()
        self.assertIn('sea snake', out.decode())


class TestAnalyzeLogs(TestCase):

    def setUp(self):
        self.log_file = ASSETS_ROOT / 'vis_data.json'
        self.tmpdir = tempfile.TemporaryDirectory()
        self.out_file = Path(self.tmpdir.name) / 'out.png'

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/analysis_tools/analyze_logs.py',
            'cal_train_time',
            str(self.log_file),
        ]
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        out, _ = p.communicate()
        self.assertIn('slowest epoch 2, average time is 0.0219', out.decode())

        command = [
            'python',
            'tools/analysis_tools/analyze_logs.py',
            'plot_curve',
            str(self.log_file),
            '--keys',
            'accuracy/top1',
            '--out',
            str(self.out_file),
        ]
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        out, _ = p.communicate()
        self.assertIn(str(self.log_file), out.decode())
        self.assertIn(str(self.out_file), out.decode())
        self.assertTrue(self.out_file.exists())


class TestAnalyzeResults(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.dir = Path(self.tmpdir.name)

        dataset_cfg = dict(
            type='CustomDataset',
            data_root=str(ASSETS_ROOT / 'dataset'),
        )
        config = Config(dict(test_dataloader=dict(dataset=dataset_cfg)))
        self.config_file = self.dir / 'config.py'
        config.dump(self.config_file)

        results = [{
            'gt_label': 1,
            'pred_label': 0,
            'pred_score': [0.9, 0.1],
            'sample_idx': 0,
        }, {
            'gt_label': 0,
            'pred_label': 0,
            'pred_score': [0.9, 0.1],
            'sample_idx': 1,
        }]
        self.result_file = self.dir / 'results.pkl'
        mmengine.dump(results, self.result_file)

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/analysis_tools/analyze_results.py',
            str(self.config_file),
            str(self.result_file),
            '--out-dir',
            str(self.tmpdir.name),
        ]
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        p.communicate()
        self.assertTrue((self.dir / 'success/2.jpeg.png').exists())
        self.assertTrue((self.dir / 'fail/1.JPG.png').exists())


class TestPrintConfig(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.config_file = MMPRE_ROOT / 'configs/resnet/resnet18_8xb32_in1k.py'

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/misc/print_config.py',
            str(self.config_file),
        ]
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        out, _ = p.communicate()
        out = out.decode().strip().replace('\r\n', '\n')
        self.assertEqual(out,
                         Config.fromfile(self.config_file).pretty_text.strip())


class TestVerifyDataset(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.dir = Path(self.tmpdir.name)

        dataset_cfg = dict(
            type='CustomDataset',
            ann_file=str(self.dir / 'ann.txt'),
            pipeline=[dict(type='LoadImageFromFile')],
            data_root=str(ASSETS_ROOT / 'dataset'),
        )
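        # Each annotation line is "<relative image path> <class index>".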
        ann_file = '\n'.join(['a/2.JPG 0', 'b/2.jpeg 1', 'b/subb/3.jpg 1'])
        (self.dir / 'ann.txt').write_text(ann_file)
        config = Config(dict(train_dataloader=dict(dataset=dataset_cfg)))
        self.config_file = Path(self.tmpdir.name) / 'config.py'
        config.dump(self.config_file)

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/misc/verify_dataset.py',
            str(self.config_file),
            '--out-path',
            str(self.dir / 'log.log'),
        ]
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        out, _ = p.communicate()
        self.assertIn(
            f"{ASSETS_ROOT/'dataset/a/2.JPG'} cannot be read correctly",
            out.decode().strip())
        self.assertTrue((self.dir / 'log.log').exists())


class TestEvalMetric(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.dir = Path(self.tmpdir.name)

        results = [
            DataSample().set_gt_label(1).set_pred_label(0).set_pred_score(
                [0.6, 0.3, 0.1]).to_dict(),
            DataSample().set_gt_label(0).set_pred_label(0).set_pred_score(
                [0.6, 0.3, 0.1]).to_dict(),
        ]
        self.result_file = self.dir / 'results.pkl'
        mmengine.dump(results, self.result_file)

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/analysis_tools/eval_metric.py',
            str(self.result_file),
            '--metric',
            'type=Accuracy',
            'topk=1,2',
        ]
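        # --metric takes "key=value" options that form the metric config,
        # here evaluating Accuracy at top-1 and top-2.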
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        out, _ = p.communicate()
        self.assertIn('accuracy/top1', out.decode())
        self.assertIn('accuracy/top2', out.decode())


class TestVisScheduler(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.dir = Path(self.tmpdir.name)

        config = Config.fromfile(MMPRE_ROOT /
                                 'configs/resnet/resnet18_8xb32_in1k.py')
        config.param_scheduler = [
            dict(
                type='LinearLR',
                start_factor=0.01,
                by_epoch=True,
                end=1,
                convert_to_iter_based=True),
            dict(type='CosineAnnealingLR', by_epoch=True, begin=1),
        ]
        config.work_dir = str(self.dir)
        config.train_cfg.max_epochs = 2
        self.config_file = Path(self.tmpdir.name) / 'config.py'
        config.dump(self.config_file)

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/visualization/vis_scheduler.py',
            str(self.config_file),
            '--dataset-size',
            '100',
            '--not-show',
            '--save-path',
            str(self.dir / 'out.png'),
        ]
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        p.communicate()
        self.assertTrue((self.dir / 'out.png').exists())


class TestPublishModel(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.dir = Path(self.tmpdir.name)

        ckpt = dict(
            state_dict=OrderedDict({
                'a': torch.tensor(1.),
            }),
            ema_state_dict=OrderedDict({
                'step': 1,
                'module.a': torch.tensor(2.),
            }))
        self.ckpt_file = self.dir / 'ckpt.pth'
        torch.save(ckpt, self.ckpt_file)

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/model_converters/publish_model.py',
            str(self.ckpt_file),
            str(self.ckpt_file),
            '--dataset-type',
            'ImageNet',
            '--no-ema',
        ]
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        out, _ = p.communicate()
        self.assertIn('and drop the EMA weights.', out.decode())
        self.assertIn('Successfully generated', out.decode())
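        # The published file gets a "<name>_<YYYYMMDD>-<8-char hash>.pth"
        # name, which the regex below matches.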
        output_ckpt = re.findall(r'ckpt_\d{8}-\w{8}.pth', out.decode())
        self.assertGreater(len(output_ckpt), 0)
        output_ckpt = output_ckpt[0]
        self.assertTrue((self.dir / output_ckpt).exists())
        # The input file won't be overridden.
        self.assertTrue(self.ckpt_file.exists())


class TestVisCam(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.dir = Path(self.tmpdir.name)

        model = get_model('mobilevit-xxsmall_3rdparty_in1k')
        self.config_file = self.dir / 'config.py'
        model._config.dump(self.config_file)
        self.ckpt_file = self.dir / 'ckpt.pth'
        torch.save(model.state_dict(), self.ckpt_file)

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/visualization/vis_cam.py',
            str(ASSETS_ROOT / 'color.jpg'),
            str(self.config_file),
            str(self.ckpt_file),
            '--save-path',
            str(self.dir / 'cam.jpg'),
        ]
        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
        out, _ = p.communicate()
        self.assertIn('backbone.conv_1x1_exp.bn', out.decode())
        self.assertTrue((self.dir / 'cam.jpg').exists())


class TestConfusionMatrix(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.dir = Path(self.tmpdir.name)
        self.config_file = MMPRE_ROOT / 'configs/resnet/resnet18_8xb32_in1k.py'

        results = [
            DataSample().set_gt_label(1).set_pred_label(0).set_pred_score(
                [0.6, 0.3, 0.1]).to_dict(),
            DataSample().set_gt_label(0).set_pred_label(0).set_pred_score(
                [0.6, 0.3, 0.1]).to_dict(),
        ]
        self.result_file = self.dir / 'results.pkl'
        mmengine.dump(results, self.result_file)

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/analysis_tools/confusion_matrix.py',
            str(self.config_file),
            str(self.result_file),
            '--out',
            str(self.dir / 'result.pkl'),
        ]
        Popen(command, cwd=MMPRE_ROOT, stdout=PIPE).wait()
        result = mmengine.load(self.dir / 'result.pkl')
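        # Rows index the ground-truth class and columns the prediction:
        # both samples were predicted as class 0, one of them wrongly.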
        torch.testing.assert_allclose(
            result, torch.tensor([
                [1, 0, 0],
                [1, 0, 0],
                [0, 0, 0],
            ]))


class TestVisTsne(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.dir = Path(self.tmpdir.name)

        config = ModelHub.get('mobilevit-xxsmall_3rdparty_in1k').config
        test_dataloader = dict(
            batch_size=1,
            dataset=dict(
                type='CustomDataset',
                data_root=str(ASSETS_ROOT / 'dataset'),
                pipeline=config.test_dataloader.dataset.pipeline,
            ),
            sampler=dict(type='DefaultSampler', shuffle=False),
        )
        config.test_dataloader = mmengine.ConfigDict(test_dataloader)
        self.config_file = self.dir / 'config.py'
        config.dump(self.config_file)

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_run(self):
        command = [
            'python',
            'tools/visualization/vis_tsne.py',
            str(self.config_file),
            '--work-dir',
            str(self.dir),
            '--perplexity',
            '2',
        ]
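        # t-SNE requires perplexity < number of samples, so the tiny test
        # dataset forces a small value.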
        Popen(command, cwd=MMPRE_ROOT, stdout=PIPE).wait()
        self.assertTrue(len(list(self.dir.glob('tsne_*/feat_*.png'))) > 0)


class TestGetFlops(TestCase):

    def test_run(self):
        command = [
            'python',
            'tools/analysis_tools/get_flops.py',
            'mobilevit-xxsmall_3rdparty_in1k',
        ]
        ret_code = Popen(command, cwd=MMPRE_ROOT).wait()
        self.assertEqual(ret_code, 0)


# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

from mmpretrain.utils import load_json_log


def test_load_json_log():
demo_log = """\
{"lr": 0.0001, "data_time": 0.003, "loss": 2.29, "time": 0.010, "epoch": 1, "step": 150}
{"lr": 0.0001, "data_time": 0.002, "loss": 2.28, "time": 0.007, "epoch": 1, "step": 300}
{"lr": 0.0001, "data_time": 0.001, "loss": 2.27, "time": 0.008, "epoch": 1, "step": 450}
{"accuracy/top1": 23.98, "accuracy/top5": 66.05, "step": 1}
{"lr": 0.0001, "data_time": 0.001, "loss": 2.25, "time": 0.014, "epoch": 2, "step": 619}
{"lr": 0.0001, "data_time": 0.000, "loss": 2.24, "time": 0.012, "epoch": 2, "step": 769}
{"lr": 0.0001, "data_time": 0.003, "loss": 2.23, "time": 0.009, "epoch": 2, "step": 919}
{"accuracy/top1": 41.82, "accuracy/top5": 81.26, "step": 2}
{"lr": 0.0001, "data_time": 0.002, "loss": 2.21, "time": 0.007, "epoch": 3, "step": 1088}
{"lr": 0.0001, "data_time": 0.005, "loss": 2.18, "time": 0.009, "epoch": 3, "step": 1238}
{"lr": 0.0001, "data_time": 0.002, "loss": 2.16, "time": 0.008, "epoch": 3, "step": 1388}
{"accuracy/top1": 54.07, "accuracy/top5": 89.80, "step": 3}
""" # noqa: E501
    with tempfile.TemporaryDirectory() as tmpdir:
        json_log = osp.join(tmpdir, 'scalars.json')
        with open(json_log, 'w') as f:
            f.write(demo_log)

        log_dict = load_json_log(json_log)
        assert log_dict.keys() == {'train', 'val'}
        assert log_dict['train'][3] == {
            'lr': 0.0001,
            'data_time': 0.001,
            'loss': 2.25,
            'time': 0.014,
            'epoch': 2,
            'step': 619
        }
        assert log_dict['val'][2] == {
            'accuracy/top1': 54.07,
            'accuracy/top5': 89.80,
            'step': 3
        }


# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import sys
from unittest import TestCase

from mmengine import DefaultScope

from mmpretrain.utils import register_all_modules


class TestSetupEnv(TestCase):

    def test_register_all_modules(self):
        from mmpretrain.registry import DATASETS

        # not init default scope
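        # (pop the imported dataset modules and the registry entry to
        # simulate a fresh process, so re-registration can be observed)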
        sys.modules.pop('mmpretrain.datasets', None)
        sys.modules.pop('mmpretrain.datasets.custom', None)
        DATASETS._module_dict.pop('CustomDataset', None)
        self.assertFalse('CustomDataset' in DATASETS.module_dict)
        register_all_modules(init_default_scope=False)
        self.assertTrue('CustomDataset' in DATASETS.module_dict)

        # init default scope
        sys.modules.pop('mmpretrain.datasets')
        sys.modules.pop('mmpretrain.datasets.custom')
        DATASETS._module_dict.pop('CustomDataset', None)
        self.assertFalse('CustomDataset' in DATASETS.module_dict)
        register_all_modules(init_default_scope=True)
        self.assertTrue('CustomDataset' in DATASETS.module_dict)
        self.assertEqual(DefaultScope.get_current_instance().scope_name,
                         'mmpretrain')

        # init default scope when another scope is init
        name = f'test-{datetime.datetime.now()}'
        DefaultScope.get_instance(name, scope_name='test')
        with self.assertWarnsRegex(
                Warning,
                'The current default scope "test" is not "mmpretrain"'):
            register_all_modules(init_default_scope=True)


# Copyright (c) OpenMMLab. All rights reserved.
from mmpretrain import digit_version


def test_digit_version():
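    # digit_version normalizes a version string into a 6-element tuple so
    # versions compare correctly; pre-release tags (dev/a/b/rc) map to
    # negative values that sort before the final release.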
    assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
    assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
    assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
    assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
    assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
    assert digit_version('1.0') == digit_version('1.0.0')
    assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
    assert digit_version('1.0.0dev') < digit_version('1.0.0a')
    assert digit_version('1.0.0a') < digit_version('1.0.0a1')
    assert digit_version('1.0.0a') < digit_version('1.0.0b')
    assert digit_version('1.0.0b') < digit_version('1.0.0rc')
    assert digit_version('1.0.0rc1') < digit_version('1.0.0')
    assert digit_version('1.0.0') < digit_version('1.0.0post')
    assert digit_version('1.0.0post') < digit_version('1.0.0post1')
    assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
    assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)


# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import patch

import numpy as np
import torch

from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer


class TestUniversalVisualizer(TestCase):

    def setUp(self) -> None:
        super().setUp()
        tmpdir = tempfile.TemporaryDirectory()
        self.tmpdir = tmpdir
        self.vis = UniversalVisualizer(
            save_dir=tmpdir.name,
            vis_backends=[dict(type='LocalVisBackend')],
        )

    def test_visualize_cls(self):
        image = np.ones((10, 10, 3), np.uint8)
        data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
            set_pred_score(torch.tensor([0.1, 0.8, 0.1]))

        # Test show
        def mock_show(drawn_img, win_name, wait_time):
            self.assertFalse((image == drawn_img).all())
            self.assertEqual(win_name, 'test_cls')
            self.assertEqual(wait_time, 0)
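
        # Patching the visualizer's show method intercepts the GUI display
        # so the drawn image can be checked headlessly.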
        with patch.object(self.vis, 'show', mock_show):
            self.vis.visualize_cls(
                image=image,
                data_sample=data_sample,
                show=True,
                name='test_cls',
                step=1)

        # Test storage backend.
        save_file = osp.join(self.tmpdir.name,
                             'vis_data/vis_image/test_cls_1.png')
        self.assertTrue(osp.exists(save_file))

        # Test out_file
        out_file = osp.join(self.tmpdir.name, 'results.png')
        self.vis.visualize_cls(
            image=image, data_sample=data_sample, out_file=out_file)
        self.assertTrue(osp.exists(out_file))

        # Test with dataset_meta
        self.vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}

        def patch_texts(text, *_, **__):
            self.assertEqual(
                text, '\n'.join([
                    'Ground truth: 1 (bird)',
                    'Prediction: 1, 0.80 (bird)',
                ]))

        with patch.object(self.vis, 'draw_texts', patch_texts):
            self.vis.visualize_cls(image, data_sample)

        # Test without pred_label
        def patch_texts(text, *_, **__):
            self.assertEqual(text, '\n'.join([
                'Ground truth: 1 (bird)',
            ]))

        with patch.object(self.vis, 'draw_texts', patch_texts):
            self.vis.visualize_cls(image, data_sample, draw_pred=False)

        # Test without gt_label
        def patch_texts(text, *_, **__):
            self.assertEqual(text, '\n'.join([
                'Prediction: 1, 0.80 (bird)',
            ]))

        with patch.object(self.vis, 'draw_texts', patch_texts):
            self.vis.visualize_cls(image, data_sample, draw_gt=False)

        # Test without score
        del data_sample.pred_score

        def patch_texts(text, *_, **__):
            self.assertEqual(
                text, '\n'.join([
                    'Ground truth: 1 (bird)',
                    'Prediction: 1 (bird)',
                ]))

        with patch.object(self.vis, 'draw_texts', patch_texts):
            self.vis.visualize_cls(image, data_sample)

        # Test adaptive font size
        def assert_font_size(target_size):

            def draw_texts(text, font_sizes, *_, **__):
                self.assertEqual(font_sizes, target_size)

            return draw_texts
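
        # The checks below pin the adaptive behaviour: the font size grows
        # with the image resolution so annotations stay legible.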
        with patch.object(self.vis, 'draw_texts', assert_font_size(7)):
            self.vis.visualize_cls(
                np.ones((224, 384, 3), np.uint8), data_sample)

        with patch.object(self.vis, 'draw_texts', assert_font_size(2)):
            self.vis.visualize_cls(
                np.ones((10, 384, 3), np.uint8), data_sample)

        with patch.object(self.vis, 'draw_texts', assert_font_size(21)):
            self.vis.visualize_cls(
                np.ones((1000, 1000, 3), np.uint8), data_sample)

        # Test rescale image
        with patch.object(self.vis, 'draw_texts', assert_font_size(14)):
            self.vis.visualize_cls(
                np.ones((224, 384, 3), np.uint8),
                data_sample,
                rescale_factor=2.)

    def test_visualize_image_retrieval(self):
        image = np.ones((10, 10, 3), np.uint8)
        data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])

        class ToyPrototype:

            def get_data_info(self, idx):
                img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
                return {'img_path': img_path, 'sample_idx': idx}

        prototype_dataset = ToyPrototype()

        # Test show
        def mock_show(drawn_img, win_name, wait_time):
            if image.shape == drawn_img.shape:
                self.assertFalse((image == drawn_img).all())
            self.assertEqual(win_name, 'test_retrieval')
            self.assertEqual(wait_time, 0)

        with patch.object(self.vis, 'show', mock_show):
            self.vis.visualize_image_retrieval(
                image,
                data_sample,
                prototype_dataset,
                show=True,
                name='test_retrieval',
                step=1)

        # Test storage backend.
        save_file = osp.join(self.tmpdir.name,
                             'vis_data/vis_image/test_retrieval_1.png')
        self.assertTrue(osp.exists(save_file))

        # Test out_file
        out_file = osp.join(self.tmpdir.name, 'results.png')
        self.vis.visualize_image_retrieval(
            image,
            data_sample,
            prototype_dataset,
            out_file=out_file,
        )
        self.assertTrue(osp.exists(out_file))

    def test_visualize_masked_image(self):
        image = np.ones((10, 10, 3), np.uint8)
        data_sample = DataSample().set_mask(
            torch.tensor([
                [0, 0, 1, 1],
                [0, 1, 1, 0],
                [1, 1, 0, 0],
                [1, 0, 0, 1],
            ]))
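
        # The 4x4 mask marks masked patches; the drawn output is rendered
        # at 224x224 regardless of the 10x10 input.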
        # Test show
        def mock_show(drawn_img, win_name, wait_time):
            self.assertTupleEqual(drawn_img.shape, (224, 224, 3))
            self.assertEqual(win_name, 'test_mask')
            self.assertEqual(wait_time, 0)

        with patch.object(self.vis, 'show', mock_show):
            self.vis.visualize_masked_image(
                image, data_sample, show=True, name='test_mask', step=1)

        # Test storage backend.
        save_file = osp.join(self.tmpdir.name,
                             'vis_data/vis_image/test_mask_1.png')
        self.assertTrue(osp.exists(save_file))

        # Test out_file
        out_file = osp.join(self.tmpdir.name, 'results.png')
        self.vis.visualize_masked_image(image, data_sample, out_file=out_file)
        self.assertTrue(osp.exists(out_file))

    def tearDown(self):
        self.tmpdir.cleanup()