Unverified commit e9ea1853 authored by Maze, committed by GitHub

Add official pretrained weights for Autoformer (#5100)

parent 16c7f0d0
@@ -73,6 +73,7 @@ def load_or_download_file(local_path: str, download_url: str, download: bool = F
                 sha256.update(chunk)
                 pbar.update(len(chunk))
             f.flush()
+            f.close()
     else:
         raise FileNotFoundError(
             'Download is not enabled, and file does not exist: {}. Please set download=True.'.format(local_path)
......
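For context, the helper this hunk patches streams the download in chunks, feeding a SHA-256 hash and a progress bar as it goes; the added `f.close()` releases the file handle explicitly before the file is reused. Below is a minimal sketch of the same pattern, assuming a plain urllib download and prefix-matched checksums; the function name and signature are illustrative, not NNI's exact API:

```python
import hashlib
import os
import shutil
import tempfile
import urllib.request

def download_with_checksum(url: str, local_path: str, expected_sha256_prefix: str) -> None:
    """Stream a file to disk in chunks while hashing it, then verify the checksum."""
    sha256 = hashlib.sha256()
    fd, tmp_path = tempfile.mkstemp()
    f = os.fdopen(fd, 'wb')
    with urllib.request.urlopen(url) as response:
        while True:
            chunk = response.read(8192)
            if not chunk:
                break
            f.write(chunk)
            sha256.update(chunk)
    f.flush()
    f.close()  # release the handle before validating/moving, mirroring the fix above
    if not sha256.hexdigest().startswith(expected_sha256_prefix):
        os.remove(tmp_path)
        raise RuntimeError(f'Checksum mismatch for {url}')
    shutil.move(tmp_path, local_path)
```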
@@ -398,14 +398,101 @@ class AutoformerSpace(nn.Module):
     def get_extra_mutation_hooks(cls):
         return [MixedAbsPosEmbed.mutate, MixedClsToken.mutate]
 
+    @classmethod
+    def preset(cls, name: str):
+        """Get the model space config proposed in the paper."""
+        name = name.lower()
+        assert name in ['tiny', 'small', 'base']
+        init_kwargs = {'qkv_bias': True, 'drop_rate': 0.0, 'drop_path_rate': 0.1, 'global_pool': True, 'num_classes': 1000}
+        if name == 'tiny':
+            init_kwargs.update({
+                'search_embed_dim': (192, 216, 240),
+                'search_mlp_ratio': (3.0, 3.5, 4.0),
+                'search_num_heads': (3, 4),
+                'search_depth': (12, 13, 14),
+            })
+        elif name == 'small':
+            init_kwargs.update({
+                'search_embed_dim': (320, 384, 448),
+                'search_mlp_ratio': (3.0, 3.5, 4.0),
+                'search_num_heads': (5, 6, 7),
+                'search_depth': (12, 13, 14),
+            })
+        elif name == 'base':
+            init_kwargs.update({
+                'search_embed_dim': (528, 576, 624),
+                'search_mlp_ratio': (3.0, 3.5, 4.0),
+                'search_num_heads': (8, 9, 10),
+                'search_depth': (14, 15, 16),
+            })
+        else:
+            raise ValueError(f'Unsupported architecture with name: {name}')
+        return init_kwargs
+
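With this in place, downstream code can build a space of a given size directly from the preset. A usage sketch; the `nni.nas.hub.pytorch` import path is an assumption, check the path for your NNI version:

```python
from nni.nas.hub.pytorch import AutoformerSpace  # import path assumed

# 'tiny' yields embed dims (192, 216, 240), 3-4 heads, and depth 12-14, as listed above.
space = AutoformerSpace(**AutoformerSpace.preset('tiny'))
```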
+    @classmethod
+    def load_strategy_checkpoint(cls, name: str, download: bool = True, progress: bool = True):
+        """
+        Load the RandomOneShot strategy initialized with supernet weights.
+
+        Parameters
+        ----------
+        name : str
+            Name of the checkpoint; must be one of {'random-one-shot-tiny', 'random-one-shot-small', 'random-one-shot-base'}.
+        download : bool
+            Whether to download the supernet weights. Default is ``True``.
+        progress : bool
+            Whether to display the download progress. Default is ``True``.
+
+        Returns
+        -------
+        BaseStrategy
+            The RandomOneShot strategy initialized with the supernet weights provided in the official repo.
+        """
+        legal = ['random-one-shot-tiny', 'random-one-shot-small', 'random-one-shot-base']
+        if name not in legal:
+            raise ValueError(f'Unsupported name: {name}. It should be one of {legal}.')
+        name = name[16:]  # strip the 'random-one-shot-' prefix
+
+        from nni.nas.strategy import RandomOneShot
+        init_kwargs = cls.preset(name)
+        model_space = cls(**init_kwargs)
+        strategy = RandomOneShot(mutation_hooks=cls.get_extra_mutation_hooks())
+        strategy.attach_model(model_space)
+        weight_file = load_pretrained_weight(f"autoformer-{name}-supernet", download=download, progress=progress)
+        pretrained_weights = torch.load(weight_file)
+        assert strategy.model is not None
+        strategy.model.load_state_dict(pretrained_weights)
+        return strategy
+
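Since the returned strategy already wraps the supernet, sampling and evaluating a subnet takes two calls, exactly as the updated test at the bottom of this diff exercises (`AutoformerSpace` imported as in the previous sketch):

```python
import torch

strategy = AutoformerSpace.load_strategy_checkpoint('random-one-shot-tiny')
strategy.model.resample()                            # sample one subnet from the supernet
logits = strategy.model(torch.rand(1, 3, 224, 224))  # forward pass on an ImageNet-sized input
```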
     @classmethod
     def load_searched_model(
         cls, name: str,
-        pretrained: bool = False, download: bool = False, progress: bool = True
+        pretrained: bool = False, download: bool = True, progress: bool = True
     ) -> nn.Module:
-        init_kwargs = {'qkv_bias': True, 'drop_rate': 0.0, 'drop_path_rate': 0.1, 'global_pool': True, 'num_classes': 1000}
-        if name == 'autoformer-tiny':
+        """
+        Load the searched subnet model.
+
+        Parameters
+        ----------
+        name : str
+            Name of the searched model; must be one of {'autoformer-tiny', 'autoformer-small', 'autoformer-base'}.
+        pretrained : bool
+            Whether to initialize with pre-trained weights. Default is ``False``.
+        download : bool
+            Whether to download the pre-trained weights. Default is ``True``.
+        progress : bool
+            Whether to display the download progress. Default is ``True``.
+
+        Returns
+        -------
+        nn.Module
+            The subnet model.
+        """
+        legal = ['autoformer-tiny', 'autoformer-small', 'autoformer-base']
+        if name not in legal:
+            raise ValueError(f'Unsupported name: {name}. It should be one of {legal}.')
+        name = name[11:]  # strip the 'autoformer-' prefix
+        init_kwargs = cls.preset(name)
+        if name == 'tiny':
             mlp_ratio = [3.5, 3.5, 3.0, 3.5, 3.0, 3.0, 4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 3.5] + [3.0]
             num_head = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3] + [3]
             arch: Dict[str, Any] = {
@@ -415,14 +502,7 @@ class AutoformerSpace(nn.Module):
             for i in range(14):
                 arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
                 arch[f'num_head_{i}'] = num_head[i]
-            init_kwargs.update({
-                'search_embed_dim': (240, 216, 192),
-                'search_mlp_ratio': (4.0, 3.5, 3.0),
-                'search_num_heads': (4, 3),
-                'search_depth': (14, 13, 12),
-            })
-        elif name == 'autoformer-small':
+        elif name == 'small':
             mlp_ratio = [3.0, 3.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.5, 4.0] + [3.0]
             num_head = [6, 6, 5, 7, 5, 5, 5, 6, 6, 7, 7, 6, 7] + [5]
             arch: Dict[str, Any] = {
@@ -432,15 +512,7 @@ class AutoformerSpace(nn.Module):
             for i in range(14):
                 arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
                 arch[f'num_head_{i}'] = num_head[i]
-            init_kwargs.update({
-                'search_embed_dim': (448, 384, 320),
-                'search_mlp_ratio': (4.0, 3.5, 3.0),
-                'search_num_heads': (7, 6, 5),
-                'search_depth': (14, 13, 12),
-            })
-        elif name == 'autoformer-base':
+        elif name == 'base':
             mlp_ratio = [3.5, 3.5, 4.0, 3.5, 4.0, 3.5, 3.5, 3.0, 4.0, 4.0, 3.0, 4.0, 3.0, 3.5] + [3.0, 3.0]
             num_head = [9, 9, 9, 9, 9, 10, 9, 9, 10, 9, 10, 9, 9, 10] + [8, 8]
             arch: Dict[str, Any] = {
@@ -450,13 +522,6 @@ class AutoformerSpace(nn.Module):
             for i in range(16):
                 arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
                 arch[f'num_head_{i}'] = num_head[i]
-            init_kwargs.update({
-                'search_embed_dim': (624, 576, 528),
-                'search_mlp_ratio': (4.0, 3.5, 3.0),
-                'search_num_heads': (10, 9, 8),
-                'search_depth': (16, 15, 14),
-            })
         else:
             raise ValueError(f'Unsupported architecture with name: {name}')
@@ -464,7 +529,7 @@ class AutoformerSpace(nn.Module):
         model = model_factory(**init_kwargs)
 
         if pretrained:
-            weight_file = load_pretrained_weight(name, download=download, progress=progress)
+            weight_file = load_pretrained_weight(f"autoformer-{name}-subnet", download=download, progress=progress)
             pretrained_weights = torch.load(weight_file)
             model.load_state_dict(pretrained_weights)
......
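Loading a fixed searched subnet with its official weights is the single-call counterpart, again mirroring the new test at the bottom of this diff:

```python
import torch

model = AutoformerSpace.load_searched_model('autoformer-tiny', pretrained=True, download=True)
logits = model(torch.rand(1, 3, 224, 224))  # (1, 1000) ImageNet logits
```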
@@ -38,10 +38,14 @@ PRETRAINED_WEIGHT_URLS = {
     # spos
     'spos': f'{NNI_BLOB}/nashub/spos-0b17f6fc.pth',
-    # autoformer
-    'autoformer-tiny': f'{NNI_BLOB}/nashub/autoformer-searched-tiny-1e90ebc1.pth',
-    'autoformer-small': f'{NNI_BLOB}/nashub/autoformer-searched-small-4bc5d4e5.pth',
-    'autoformer-base': f'{NNI_BLOB}/nashub/autoformer-searched-base-c417590a.pth'
+    # autoformer subnet
+    'autoformer-tiny-subnet': f'{NNI_BLOB}/nashub/autoformer-tiny-subnet-12ed42ff.pth',
+    'autoformer-small-subnet': f'{NNI_BLOB}/nashub/autoformer-small-subnet-b4e25a1b.pth',
+    'autoformer-base-subnet': f'{NNI_BLOB}/nashub/autoformer-base-subnet-85105f76.pth',
+    # autoformer supernet
+    'autoformer-tiny-supernet': f'{NNI_BLOB}/nashub/autoformer-tiny-supernet-6f107004.pth',
+    'autoformer-small-supernet': f'{NNI_BLOB}/nashub/autoformer-small-supernet-8ed79e18.pth',
+    'autoformer-base-supernet': f'{NNI_BLOB}/nashub/autoformer-base-supernet-0c6d6612.pth',
 }
......
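The key passed to `load_pretrained_weight` indexes this dict; the eight-hex-digit suffix in each filename is presumably the leading digits of the file's SHA-256 digest, verified by the download helper patched in the first hunk (an assumption based on the torch.hub naming convention). A lookup sketch:

```python
# Illustrative only; PRETRAINED_WEIGHT_URLS and NNI_BLOB live in the file patched above.
url = PRETRAINED_WEIGHT_URLS['autoformer-tiny-supernet']   # .../autoformer-tiny-supernet-6f107004.pth
sha_prefix = url.rsplit('-', 1)[-1].removesuffix('.pth')   # '6f107004'
```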
@@ -196,3 +196,12 @@ def test_shufflenet():
 def test_autoformer():
     ss = searchspace.AutoformerSpace()
     _test_searchspace_on_dataset(ss, dataset='imagenet')
+
+    import torch
+    for name in ['tiny', 'small', 'base']:
+        # Check that both the searched subnet and the supernet weights load and run.
+        model = searchspace.AutoformerSpace.load_searched_model(f'autoformer-{name}', pretrained=True, download=True)
+        model(torch.rand(1, 3, 224, 224))
+        strategy = searchspace.AutoformerSpace.load_strategy_checkpoint(f'random-one-shot-{name}')
+        strategy.model.resample()
+        strategy.model(torch.rand(1, 3, 224, 224))