Commit 1baf0566 authored by limm

add tests part

parent 495d9ed9
# Copyright (c) OpenMMLab. All rights reserved.
import math
from unittest import TestCase
from unittest.mock import patch

import torch
from mmengine.logging import MMLogger

from mmpretrain.datasets import RepeatAugSampler

file = 'mmpretrain.datasets.samplers.repeat_aug.'


class MockDist:

    def __init__(self, dist_info=(0, 1), seed=7):
        self.dist_info = dist_info
        self.seed = seed

    def get_dist_info(self):
        return self.dist_info

    def sync_random_seed(self):
        return self.seed

    def is_main_process(self):
        return self.dist_info[0] == 0


class TestRepeatAugSampler(TestCase):

    def setUp(self):
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @patch(file + 'get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock):
        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)
        self.assertEqual(sampler.total_size, self.data_length * 3)
        self.assertEqual(sampler.num_samples, self.data_length * 3)
        self.assertEqual(sampler.num_selected_samples, self.data_length)
        self.assertEqual(len(sampler), sampler.num_selected_samples)
        indices = [x for x in range(self.data_length) for _ in range(3)]
        self.assertEqual(list(sampler), indices[:self.data_length])

        logger = MMLogger.get_current_instance()
        with self.assertLogs(logger, 'WARN') as log:
            sampler = RepeatAugSampler(self.dataset, shuffle=False)
        self.assertIn('always picks a fixed part', log.output[0])

    @patch(file + 'get_dist_info', return_value=(2, 3))
    @patch(file + 'is_main_process', return_value=False)
    def test_dist(self, mock1, mock2):
        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)
        self.assertEqual(sampler.num_samples, self.data_length)
        self.assertEqual(sampler.total_size, self.data_length * 3)
        self.assertEqual(sampler.num_selected_samples,
                         math.ceil(self.data_length / 3))
        self.assertEqual(len(sampler), sampler.num_selected_samples)
        indices = [x for x in range(self.data_length) for _ in range(3)]
        self.assertEqual(
            list(sampler), indices[2::3][:sampler.num_selected_samples])

        logger = MMLogger.get_current_instance()
        with patch.object(logger, 'warning') as mock_log:
            sampler = RepeatAugSampler(self.dataset, shuffle=False)
            mock_log.assert_not_called()

    @patch(file + 'get_dist_info', return_value=(0, 1))
    @patch(file + 'sync_random_seed', return_value=7)
    def test_shuffle(self, mock1, mock2):
        # test seed=None
        sampler = RepeatAugSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # test random seed
        sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0)
        sampler.set_epoch(10)
        g = torch.Generator()
        g.manual_seed(10)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
        indices = [x for x in indices
                   for _ in range(3)][:sampler.num_selected_samples]
        self.assertEqual(list(sampler), indices)

        sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42)
        sampler.set_epoch(10)
        g = torch.Generator()
        g.manual_seed(42 + 10)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
        indices = [x for x in indices
                   for _ in range(3)][:sampler.num_selected_samples]
        self.assertEqual(list(sampler), indices)
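For context, the index pattern these assertions encode can be sketched in a few lines. This is a minimal sketch only; the real sampler is mmpretrain.datasets.samplers.repeat_aug.RepeatAugSampler, and the helper name below is illustrative.

import math

def repeat_aug_indices(indices, num_repeats, rank, world_size):
    """Sketch: repeat every index, shard round-robin, keep ~len/world_size."""
    repeated = [x for x in indices for _ in range(num_repeats)]
    per_rank = repeated[rank::world_size]
    return per_rank[:math.ceil(len(indices) / world_size)]

# test_non_dist: rank 0 of 1 -> repeated[:100], i.e. indices[:self.data_length]
# test_dist:     rank 2 of 3 -> repeated[2::3][:math.ceil(100 / 3)]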
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import unittest

import mmcv
import numpy as np
import torch
from PIL import Image

from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample, MultiTaskDataSample


class TestPackInputs(unittest.TestCase):

    def test_transform(self):
        img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
        data = {
            'sample_idx': 1,
            'img_path': img_path,
            'ori_shape': (300, 400),
            'img_shape': (300, 400),
            'scale_factor': 1.0,
            'flip': False,
            'img': mmcv.imread(img_path),
            'gt_label': 2,
            'custom_key': torch.tensor([1, 2, 3])
        }

        cfg = dict(type='PackInputs', algorithm_keys=['custom_key'])
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['data_samples'], DataSample)
        self.assertIn('flip', results['data_samples'].metainfo_keys())
        self.assertIsInstance(results['data_samples'].gt_label, torch.Tensor)
        self.assertIsInstance(results['data_samples'].custom_key, torch.Tensor)

        # Test grayscale image
        data['img'] = data['img'].mean(-1)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (1, 300, 400))

        # Test video input
        data['img'] = np.random.randint(
            0, 256, (10, 3, 1, 224, 224), dtype=np.uint8)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (10, 3, 1, 224, 224))

        # Test Pillow input
        data['img'] = Image.open(img_path)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (3, 300, 400))

        # Test without `img` and `gt_label`
        del data['img']
        del data['gt_label']
        results = transform(copy.deepcopy(data))
        self.assertNotIn('gt_label', results['data_samples'])

    def test_repr(self):
        cfg = dict(type='PackInputs', meta_keys=['flip', 'img_shape'])
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), "PackInputs(input_key='img', algorithm_keys=(), "
            "meta_keys=['flip', 'img_shape'])")


class TestTranspose(unittest.TestCase):

    def test_transform(self):
        cfg = dict(type='Transpose', keys=['img'], order=[2, 0, 1])
        transform = TRANSFORMS.build(cfg)
        data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}
        results = transform(copy.deepcopy(data))
        self.assertEqual(results['img'].shape, (3, 224, 224))

    def test_repr(self):
        cfg = dict(type='Transpose', keys=['img'], order=(2, 0, 1))
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), "Transpose(keys=['img'], order=(2, 0, 1))")


class TestToPIL(unittest.TestCase):

    def test_transform(self):
        cfg = dict(type='ToPIL')
        transform = TRANSFORMS.build(cfg)
        data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], Image.Image)

        cfg = dict(type='ToPIL', to_rgb=True)
        transform = TRANSFORMS.build(cfg)
        data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], Image.Image)
        self.assertTrue(
            np.array_equal(np.array(results['img']), data['img'][:, :, ::-1]))
    def test_repr(self):
        cfg = dict(type='ToPIL', to_rgb=True)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(repr(transform), 'NumpyToPIL(to_rgb=True)')


class TestToNumpy(unittest.TestCase):

    def test_transform(self):
        img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
        data = {
            'img': Image.open(img_path),
        }

        cfg = dict(type='ToNumpy')
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], np.ndarray)
        self.assertEqual(results['img'].dtype, 'uint8')

        cfg = dict(type='ToNumpy', to_bgr=True)
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], np.ndarray)
        self.assertEqual(results['img'].dtype, 'uint8')
        self.assertTrue(
            np.array_equal(results['img'], np.array(data['img'])[:, :, ::-1]))
    def test_repr(self):
        cfg = dict(type='ToNumpy', to_bgr=True)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'PILToNumpy(to_bgr=True, dtype=None)')


class TestCollect(unittest.TestCase):

    def test_transform(self):
        data = {'img': [1, 2, 3], 'gt_label': 1}

        cfg = dict(type='Collect', keys=['img'])
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIn('img', results)
        self.assertNotIn('gt_label', results)

    def test_repr(self):
        cfg = dict(type='Collect', keys=['img'])
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(repr(transform), "Collect(keys=['img'])")


class TestPackMultiTaskInputs(unittest.TestCase):

    def test_transform(self):
        img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
        data = {
            'sample_idx': 1,
            'img_path': img_path,
            'ori_shape': (300, 400),
            'img_shape': (300, 400),
            'scale_factor': 1.0,
            'flip': False,
            'img': mmcv.imread(img_path),
            'gt_label': {
                'task1': 1,
                'task3': 3
            },
        }

        cfg = dict(type='PackMultiTaskInputs', multi_task_fields=['gt_label'])
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['data_samples'], MultiTaskDataSample)
        self.assertIn('flip', results['data_samples'].task1.metainfo_keys())
        self.assertIsInstance(results['data_samples'].task1.gt_label,
                              torch.Tensor)

        # Test grayscale image
        data['img'] = data['img'].mean(-1)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (1, 300, 400))

        # Test without `img` and `gt_label`
        del data['img']
        del data['gt_label']
        results = transform(copy.deepcopy(data))
        self.assertNotIn('gt_label', results['data_samples'])

    def test_repr(self):
        cfg = dict(
            type='PackMultiTaskInputs',
            multi_task_fields=['gt_label'],
            task_handlers=dict(task1=dict(type='PackInputs')),
        )
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform),
            "PackMultiTaskInputs(multi_task_fields=['gt_label'], "
            "input_key='img', task_handlers={'task1': PackInputs})")
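The shape assertions above pin down the packing convention: 2-D grayscale arrays gain a leading channel axis, HWC color images become CHW tensors, and higher-dimensional video arrays pass through unchanged. A minimal sketch of that conversion, assuming nothing beyond what the assertions show (the real logic lives inside PackInputs, and the helper name is illustrative):

import numpy as np
import torch

def to_input_tensor(img):
    """Sketch of the array-to-tensor packing the tests above encode."""
    img = np.array(img)      # also covers PIL.Image inputs
    if img.ndim == 2:        # grayscale (H, W) -> (1, H, W)
        img = img[None, ...]
    elif img.ndim == 3:      # color (H, W, C) -> (C, H, W)
        img = img.transpose(2, 0, 1)
    # higher dims (e.g. video (T, C, ..., H, W)) are kept as-is
    return torch.from_numpy(np.ascontiguousarray(img))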
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.transforms import Resize

from mmpretrain.datasets import GaussianBlur, MultiView, Solarize


def test_multi_view():
    original_img = np.ones((4, 4, 3), dtype=np.uint8)

    # test 1 pipeline with 2 views
    pipeline1 = [
        Resize(2),
        GaussianBlur(magnitude_range=(0.1, 2), magnitude_std='inf')
    ]
    transform = MultiView([pipeline1], 2)
    results = dict(img=original_img)
    results = transform(results)
    assert len(results['img']) == 2
    assert results['img'][0].shape == (2, 2, 3)

    transform = MultiView([pipeline1], [2])
    results = dict(img=original_img)
    results = transform(results)
    assert len(results['img']) == 2
    assert results['img'][0].shape == (2, 2, 3)

    # test 2 pipelines with 3 views
    pipeline2 = [
        Solarize(thr=128),
        GaussianBlur(magnitude_range=(0.1, 2), magnitude_std='inf')
    ]
    transform = MultiView([pipeline1, pipeline2], [1, 2])
    results = dict(img=original_img)
    results = transform(results)
    assert len(results['img']) == 3
    assert results['img'][0].shape == (2, 2, 3)
    assert results['img'][1].shape == (4, 4, 3)

    # test repr
    assert isinstance(str(transform), str)
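A minimal sketch of the view expansion this test exercises; illustrative only, the real wrapper is mmpretrain.datasets.MultiView (the two int/list cases above show an int view count is equivalent to a one-element list):

def multi_view(img, pipelines, num_views):
    # apply pipeline i num_views[i] times and collect the outputs
    views = []
    for pipeline, n in zip(pipelines, num_views):
        for _ in range(n):
            results = dict(img=img)
            for t in pipeline:
                results = t(results)
            views.append(results['img'])
    return views

# multi_view(img, [pipeline1, pipeline2], [1, 2]) therefore yields 3 views.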
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase

import numpy as np
import torch
from mmengine.runner import Runner
from torch.utils.data import DataLoader, Dataset


class ExampleDataset(Dataset):

    def __init__(self):
        self.index = 0
        self.metainfo = None

    def __getitem__(self, idx):
        results = dict(imgs=torch.rand((224, 224, 3)).float())
        return results

    def get_gt_labels(self):
        gt_labels = np.array([0, 1, 2, 4, 0, 4, 1, 2, 2, 1])
        return gt_labels

    def __len__(self):
        return 10


class TestSetAdaptiveMarginsHook(TestCase):
    DEFAULT_HOOK_CFG = dict(type='SetAdaptiveMarginsHook')

    DEFAULT_MODEL = dict(
        type='ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=34,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(type='ArcFaceClsHead', in_channels=512, num_classes=5))
    def test_before_train(self):
        default_hooks = dict(
            timer=dict(type='IterTimerHook'),
            logger=None,
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(type='CheckpointHook', interval=1),
            sampler_seed=dict(type='DistSamplerSeedHook'),
            visualization=dict(type='VisualizationHook', enable=False),
        )
        tmpdir = tempfile.TemporaryDirectory()
        loader = DataLoader(ExampleDataset(), batch_size=2)
        self.runner = Runner(
            model=self.DEFAULT_MODEL,
            work_dir=tmpdir.name,
            train_dataloader=loader,
            train_cfg=dict(by_epoch=True, max_epochs=1),
            log_level='WARNING',
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
            param_scheduler=dict(
                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
            default_scope='mmpretrain',
            default_hooks=default_hooks,
            experiment_name='test_construct_with_arcface',
            custom_hooks=[self.DEFAULT_HOOK_CFG])

        default_margins = torch.tensor([0.5] * 5)
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           default_margins))

        self.runner.call_hook('before_train')
        # counts = [2, 3, 3, 0, 2] -> [2, 3, 3, 1, 2] (each class occurs at
        # least once)
        # frequency**-0.25 = [0.84089642, 0.75983569, 0.75983569, 1., 0.84089642]
        # normalized = [0.33752196, 0., 0., 1., 0.33752196]
        # margins = [0.20188488, 0.05, 0.05, 0.5, 0.20188488]
        expected_margins = torch.tensor(
            [0.20188488, 0.05, 0.05, 0.5, 0.20188488])
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           expected_margins))

        model_cfg = {**self.DEFAULT_MODEL}
        model_cfg['head'] = dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        )
        self.runner = Runner(
            model=model_cfg,
            work_dir=tmpdir.name,
            train_dataloader=loader,
            train_cfg=dict(by_epoch=True, max_epochs=1),
            log_level='WARNING',
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
            param_scheduler=dict(
                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
            default_scope='mmpretrain',
            default_hooks=default_hooks,
            experiment_name='test_construct_wo_arcface',
            custom_hooks=[self.DEFAULT_HOOK_CFG])

        # The hook should reject a model whose head is not an ArcFace head.
        with self.assertRaises(ValueError):
            self.runner.call_hook('before_train')
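The expected margins follow the inverse-fourth-root frequency rule spelled out in the comments above; here is a short worked check. The 0.05 and 0.5 bounds are read off the expected values, and the real computation lives in SetAdaptiveMarginsHook:

import torch

counts = torch.tensor([2., 3., 3., 0., 2.]).clamp(min=1)  # each class >= 1
freq_inv = counts**-0.25
norm = (freq_inv - freq_inv.min()) / (freq_inv.max() - freq_inv.min())
margins = 0.05 + norm * (0.5 - 0.05)
# -> tensor([0.2019, 0.0500, 0.0500, 0.5000, 0.2019]), matching the test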
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, patch

from mmpretrain.engine import ClassNumCheckHook


class TestClassNumCheckHook(TestCase):

    def setUp(self):
        self.runner = MagicMock()
        self.dataset = MagicMock()
        self.hook = ClassNumCheckHook()

    def test_check_head(self):
        # check sequence of string
        with self.assertRaises(AssertionError):
            self.hook._check_head(self.runner, self.dataset)

        # check no CLASSES
        with patch.object(self.runner.logger, 'warning') as mock:
            self.dataset.CLASSES = None
            self.hook._check_head(self.runner, self.dataset)
            mock.assert_called_once()

        # check no modules
        self.dataset.CLASSES = ['str'] * 10
        self.hook._check_head(self.runner, self.dataset)

        # check number of classes not match
        self.dataset.CLASSES = ['str'] * 10
        module1 = MagicMock(spec_set=True)
        module2 = MagicMock(num_classes=5)
        self.runner.model.named_modules.return_value = iter([(None, module1),
                                                             (None, module2)])
        with self.assertRaises(AssertionError):
            self.hook._check_head(self.runner, self.dataset)

    def test_before_train(self):
        with patch.object(self.hook, '_check_head') as mock:
            self.hook.before_train(self.runner)
            mock.assert_called_once()

    def test_before_val(self):
        with patch.object(self.hook, '_check_head') as mock:
            self.hook.before_val(self.runner)
            mock.assert_called_once()

    def test_before_test(self):
        with patch.object(self.hook, '_check_head') as mock:
            self.hook.before_test(self.runner)
            mock.assert_called_once()
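A minimal sketch of the consistency check these cases exercise; the function name and message are illustrative, the real logic is ClassNumCheckHook._check_head:

from collections.abc import Sequence

def check_head(model, dataset, logger):
    """Sketch: every head exposing num_classes must match len(CLASSES)."""
    if dataset.CLASSES is None:
        logger.warning('dataset has no CLASSES, skip the class number check')
        return
    # CLASSES must be a sequence of strings
    assert isinstance(dataset.CLASSES, Sequence) and all(
        isinstance(name, str) for name in dataset.CLASSES)
    for _, module in model.named_modules():
        if hasattr(module, 'num_classes'):
            assert module.num_classes == len(dataset.CLASSES)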
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
from mmengine.structures import LabelData
from torch.utils.data import Dataset

from mmpretrain.engine import DenseCLHook
from mmpretrain.models.selfsup import BaseSelfSupervisor
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.utils import get_ori_model


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = DataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=[self.data[index]], data_samples=data_sample)


@MODELS.register_module()
class DenseCLDummyLayer(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)


class ToyModel(BaseSelfSupervisor):

    def __init__(self):
        super().__init__(backbone=dict(type='DenseCLDummyLayer'))
        self.loss_lambda = 0.5

    def loss(self, inputs, data_samples):
        labels = []
        for x in data_samples:
            labels.append(x.gt_label.value)
        labels = torch.stack(labels)
        outputs = self.backbone(inputs[0])
        loss = (labels - outputs).sum()
        outputs = dict(loss=loss)
        return outputs


class TestDenseCLHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_densecl_hook(self):
        device = get_device()
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        densecl_hook = DenseCLHook(start_iters=1)

        # test DenseCLHook with model wrapper
        runner = Runner(
            model=toy_model,
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                collate_fn=dict(type='default_collate'),
                batch_size=1,
                num_workers=0),
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(toy_model.parameters())),
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[densecl_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_densecl_hook',
            default_scope='mmpretrain')
        runner.train()

        if runner.iter >= 1:
            assert get_ori_model(runner.model).loss_lambda == 0.5
        else:
            assert get_ori_model(runner.model).loss_lambda == 0.
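The final assertion encodes the hook's contract: loss_lambda is zeroed for the first start_iters iterations and restored afterwards. A minimal sketch of that behaviour (illustrative only; the real hook is mmpretrain.engine.DenseCLHook):

from mmpretrain.utils import get_ori_model

class DenseCLHookSketch:

    def __init__(self, start_iters=1000):
        self.start_iters = start_iters

    def before_train(self, runner):
        # remember the configured value before the warm-up zeroes it
        self.loss_lambda = get_ori_model(runner.model).loss_lambda

    def before_train_iter(self, runner, batch_idx, data_batch=None):
        model = get_ori_model(runner.model)
        model.loss_lambda = (
            self.loss_lambda if runner.iter >= self.start_iters else 0.)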
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
import tempfile
from collections import OrderedDict
from unittest import TestCase
from unittest.mock import ANY, MagicMock, call

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.evaluator import Evaluator
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
from mmengine.testing import assert_allclose
from torch.utils.data import Dataset

from mmpretrain.engine import EMAHook


class SimpleModel(BaseModel):

    def __init__(self):
        super().__init__()
        self.para = nn.Parameter(torch.zeros(1))

    def forward(self, *args, mode='tensor', **kwargs):
        if mode == 'predict':
            return self.para.clone()
        elif mode == 'loss':
            return {'loss': self.para.mean()}


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(6, 2)
    label = torch.ones(6)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])


class TestEMAHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        state_dict = OrderedDict(
            meta=dict(epoch=1, iter=2),
            # The actual ema para
            state_dict={'para': torch.tensor([1.])},
            # The actual original para
            ema_state_dict={'module.para': torch.tensor([2.])},
        )
        self.ckpt = osp.join(self.temp_dir.name, 'ema.pth')
        torch.save(state_dict, self.ckpt)

    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_load_state_dict(self):
        device = get_device()
        model = SimpleModel().to(device)
        ema_hook = EMAHook()
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            work_dir=self.temp_dir.name,
            resume=False,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[ema_hook],
            default_scope='mmpretrain',
            experiment_name='load_state_dict')
        runner.train()
        assert_allclose(runner.model.para, torch.tensor([1.], device=device))

    def test_evaluate_on_ema(self):
        device = get_device()
        model = SimpleModel().to(device)

        # Test validate on ema model
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            val_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            val_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            default_scope='mmpretrain',
            experiment_name='validate_on_ema')
        runner.val()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
        ])
        self.assertNotIn(
            call(ANY, [torch.tensor([2.]).to(device)]),
            evaluator.metrics[0].process.mock_calls)

        # Test test on ema model
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            test_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            default_scope='mmpretrain',
            experiment_name='test_on_ema')
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
        ])
        self.assertNotIn(
            call(ANY, [torch.tensor([2.]).to(device)]),
            evaluator.metrics[0].process.mock_calls)

        # Test validate on both models
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            val_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            val_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook', evaluate_on_origin=True)],
            default_scope='mmpretrain',
            experiment_name='validate_on_ema_false',
        )
        runner.val()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
            call(ANY, [torch.tensor([2.]).to(device)]),
        ])

        # Test test on both models
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            test_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook', evaluate_on_origin=True)],
            default_scope='mmpretrain',
            experiment_name='test_on_ema_false',
        )
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
            call(ANY, [torch.tensor([2.]).to(device)]),
        ])

        # Test evaluate_on_ema=False
        evaluator = Evaluator([MagicMock()])
        with self.assertWarnsRegex(UserWarning, 'evaluate_on_origin'):
            runner = Runner(
                model=model,
                test_dataloader=dict(
                    dataset=DummyDataset(),
                    sampler=dict(type='DefaultSampler', shuffle=False),
                    batch_size=3,
                    num_workers=0),
                test_evaluator=evaluator,
                test_cfg=dict(),
                work_dir=self.temp_dir.name,
                load_from=self.ckpt,
                default_hooks=dict(logger=None),
                custom_hooks=[dict(type='EMAHook', evaluate_on_ema=False)],
                default_scope='mmpretrain',
                experiment_name='not_test_on_ema')
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([2.]).to(device)]),
        ])
        self.assertNotIn(
            call(ANY, [torch.tensor([1.]).to(device)]),
            evaluator.metrics[0].process.mock_calls)
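These cases all rely on the checkpoint layout written in setUp: the plain state_dict entry holds the averaged (EMA) parameters while ema_state_dict keeps the raw ones, which is why a plain load_from evaluates to para == 1. A short sketch, assuming exactly that layout:

import torch

ckpt = torch.load('ema.pth')          # the file written in setUp above
ema_weights = ckpt['state_dict']      # {'para': tensor([1.])}, used by default
raw_weights = ckpt['ema_state_dict']  # {'module.para': tensor([2.])}, used
                                      # when evaluate_on_origin=True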
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock

import torch

from mmpretrain.engine import PrepareProtoBeforeValLoopHook
from mmpretrain.models.retrievers import BaseRetriever


class ToyRetriever(BaseRetriever):

    def forward(self, inputs, data_samples=None, mode: str = 'loss'):
        pass

    def prepare_prototype(self):
        """Preprocess the prototype before predicting."""
        self.prototype_vecs = torch.tensor([0])
        self.prototype_inited = True


class TestPrepareProtoBeforeValLoopHook(TestCase):

    def setUp(self):
        self.hook = PrepareProtoBeforeValLoopHook()
        self.runner = MagicMock()
        self.runner.model = ToyRetriever()

    def test_before_val(self):
        self.runner.model.prepare_prototype()
        self.assertTrue(self.runner.model.prototype_inited)
        self.hook.before_val(self.runner)
        self.assertIsNotNone(self.runner.model.prototype_vecs)
        self.assertTrue(self.runner.model.prototype_inited)
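A minimal sketch of the hook under test; illustrative only, and the real hook (mmpretrain.engine.PrepareProtoBeforeValLoopHook) may unwrap model wrappers first:

class PrepareProtoSketch:

    def before_val(self, runner):
        # make sure the retriever's prototype is ready before validation
        if hasattr(runner.model, 'prepare_prototype'):
            runner.model.prepare_prototype()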
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch import nn

from mmpretrain.engine import LearningRateDecayOptimWrapperConstructor
from mmpretrain.models import ImageClassifier, VisionTransformer


class ToyViTBackbone(nn.Module):
    get_layer_depth = VisionTransformer.get_layer_depth

    def __init__(self, num_layers=2):
        super().__init__()
        self.cls_token = nn.Parameter(torch.ones(1))
        self.pos_embed = nn.Parameter(torch.ones(1))
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            layer = nn.Conv2d(3, 3, 1)
            self.layers.append(layer)


class ToyViT(nn.Module):
    get_layer_depth = ImageClassifier.get_layer_depth

    def __init__(self):
        super().__init__()
        # add some variables to meet the unit test coverage rate
        self.backbone = ToyViTBackbone()
        self.head = nn.Linear(1, 1)
class TestLearningRateDecayOptimWrapperConstructor(TestCase):
    base_lr = 1.0
    base_wd = 0.05

    def test_add_params(self):
        model = ToyViT()
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='AdamW',
                lr=self.base_lr,
                betas=(0.9, 0.999),
                weight_decay=self.base_wd))
        paramwise_cfg = dict(
            layer_decay_rate=2.0,
            bias_decay_mult=0.,
            custom_keys={
                '.cls_token': dict(decay_mult=0.0),
                '.pos_embed': dict(decay_mult=0.0),
            })
        constructor = LearningRateDecayOptimWrapperConstructor(
            optim_wrapper_cfg=optim_wrapper_cfg,
            paramwise_cfg=paramwise_cfg,
        )
        optimizer_wrapper = constructor(model)

        expected_groups = [{
            'weight_decay': 0.0,
            'lr': 8 * self.base_lr,
            'param_name': 'backbone.cls_token',
        }, {
            'weight_decay': 0.0,
            'lr': 8 * self.base_lr,
            'param_name': 'backbone.pos_embed',
        }, {
            'weight_decay': self.base_wd,
            'lr': 4 * self.base_lr,
            'param_name': 'backbone.layers.0.weight',
        }, {
            'weight_decay': 0.0,
            'lr': 4 * self.base_lr,
            'param_name': 'backbone.layers.0.bias',
        }, {
            'weight_decay': self.base_wd,
            'lr': 2 * self.base_lr,
            'param_name': 'backbone.layers.1.weight',
        }, {
            'weight_decay': 0.0,
            'lr': 2 * self.base_lr,
            'param_name': 'backbone.layers.1.bias',
        }, {
            'weight_decay': self.base_wd,
            'lr': 1 * self.base_lr,
            'param_name': 'head.weight',
        }, {
            'weight_decay': 0.0,
            'lr': 1 * self.base_lr,
            'param_name': 'head.bias',
        }]

        self.assertIsInstance(optimizer_wrapper.optimizer, torch.optim.AdamW)
        self.assertEqual(optimizer_wrapper.optimizer.defaults['lr'],
                         self.base_lr)
        self.assertEqual(optimizer_wrapper.optimizer.defaults['weight_decay'],
                         self.base_wd)
        param_groups = optimizer_wrapper.optimizer.param_groups
        self.assertEqual(len(param_groups), len(expected_groups))
        for expect, param in zip(expected_groups, param_groups):
            self.assertEqual(param['param_name'], expect['param_name'])
            self.assertEqual(param['lr'], expect['lr'])
            self.assertEqual(param['weight_decay'], expect['weight_decay'])
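The lr multipliers in expected_groups follow the layer-wise decay rule scale = layer_decay_rate ** (max_depth - depth), where the depth of each parameter comes from get_layer_depth. A short worked check with layer_decay_rate = 2.0 and max_depth = num_layers + 1 = 3 (the depth assignment below is inferred from the expected groups):

# cls_token / pos_embed  depth 0 -> 2.0 ** 3 = 8 * base_lr
# layers.0               depth 1 -> 2.0 ** 2 = 4 * base_lr
# layers.1               depth 2 -> 2.0 ** 1 = 2 * base_lr
# head                   depth 3 -> 2.0 ** 0 = 1 * base_lr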