Commit 1ac2e802 authored by limm's avatar limm
Browse files

add tools code

parent b6df0d33
Pipeline #2803 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path
import torch
from mmengine.config import Config
from mmpretrain.registry import MODELS
@torch.no_grad()
def merge_lora_weight(cfg, lora_weight):
    """Merge base weight and lora weight.

    Args:
        cfg (dict): config for LoRAModel.
        lora_weight (dict): weight dict from LoRAModel.

    Returns:
        dict: checkpoint with ``state_dict`` (LoRA branches folded into the
        base weights) and the original ``meta``.
    """
    temp = dict()
    # Maps "<layer prefix>" -> {'lora_down': tensor, 'lora_up': tensor}.
    mapping = dict()
    for name, param in lora_weight['state_dict'].items():
        # e.g. backbone.module.layers.11.attn.qkv.lora_down.weight
        if '.lora_' in name:
            lora_split = name.split('.')
            prefix = '.'.join(lora_split[:-2])
            if prefix not in mapping:
                mapping[prefix] = dict()
            lora_type = lora_split[-2]
            mapping[prefix][lora_type] = param
        else:
            temp[name] = param
    model = MODELS.build(cfg['model'])
    for name, param in model.named_parameters():
        if name in temp or '.lora_' in name:
            continue
        else:
            name_split = name.split('.')
            # FIX: was `prefix = prefix = '.'.join(...)` (duplicated
            # assignment typo).
            prefix = '.'.join(name_split[:-2])
            if prefix in mapping:
                # Drop the wrapper segment before the parameter name
                # (assumed to be LoRALinear's inner module — TODO confirm).
                name_split.pop(-2)
                if name_split[-1] == 'weight':
                    scaling = get_scaling(model, prefix)
                    lora_down = mapping[prefix]['lora_down']
                    lora_up = mapping[prefix]['lora_up']
                    # W' = W + scaling * (up @ down)
                    param += lora_up @ lora_down * scaling
            # Remove the `module` segment added by the LoRAModel wrapper.
            name_split.pop(1)
            name = '.'.join(name_split)
            temp[name] = param
    result = dict()
    result['state_dict'] = temp
    result['meta'] = lora_weight['meta']
    return result
def get_scaling(model, prefix):
    """Fetch the ``scaling`` factor of the layer addressed by ``prefix``.

    Args:
        model (LoRAModel): the model to traverse.
        prefix (str): dotted attribute path of the target layer.

    Returns:
        The ``scaling`` attribute of the resolved LoRALinear layer.
    """
    target = model
    for attr in prefix.split('.'):
        target = getattr(target, attr)
    return target.scaling
if __name__ == '__main__':
    # CLI: merge a LoRA checkpoint into plain model weights.
    arg_parser = argparse.ArgumentParser(description='Merge LoRA weight')
    arg_parser.add_argument('cfg', help='cfg path')
    arg_parser.add_argument('src', help='src lora model path')
    arg_parser.add_argument('dst', help='save path')
    cli_args = arg_parser.parse_args()
    save_path = Path(cli_args.dst)
    if save_path.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    config = Config.fromfile(cli_args.cfg)
    checkpoint = torch.load(cli_args.src, map_location='cpu')
    torch.save(merge_lora_weight(config, checkpoint), cli_args.dst)
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def correct_unfold_reduction_order(x: torch.Tensor):
    """Reorder a PatchMerging ``reduction`` weight from the interleaved
    (timm) column layout to the mmcv unfold layout."""
    out_ch, in_ch = x.shape
    grouped = x.reshape(out_ch, 4, in_ch // 4)
    reordered = grouped[:, [0, 2, 1, 3], :].transpose(1, 2)
    return reordered.reshape(out_ch, in_ch)
def correct_unfold_norm_order(x):
    """Reorder a PatchMerging ``norm`` weight/bias vector from the
    interleaved (timm) layout to the mmcv unfold layout."""
    in_ch = x.shape[0]
    grouped = x.reshape(4, in_ch // 4)
    return grouped[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_ch)
def convert_mixmim(ckpt):
    """Convert an official MixMIM state dict to mmpretrain key names.

    Args:
        ckpt (dict): state dict loaded from the official repository.

    Returns:
        OrderedDict: state dict with mmpretrain-style keys.

    Raises:
        ValueError: if a key does not match any known pattern.
    """
    new_ckpt = OrderedDict()
    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('patch_embed'):
            new_k = k.replace('proj', 'projection')
        elif k.startswith('layers'):
            if 'norm1' in k:
                new_k = k.replace('norm1', 'ln1')
            elif 'norm2' in k:
                new_k = k.replace('norm2', 'ln2')
            elif 'mlp.fc1' in k:
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in k:
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
            else:
                new_k = k
        elif k.startswith('norm') or k.startswith('absolute_pos_embed'):
            new_k = k
        elif k.startswith('head'):
            new_k = k.replace('head.', 'head.fc.')
        else:
            # FIX: was a bare `raise ValueError` with no message.
            raise ValueError(f'Unsupported conversion of key {k}')
        # Everything except the classification head lives in the backbone.
        if not new_k.startswith('head'):
            new_k = 'backbone.' + new_k
        if 'downsample' in new_k:
            # FIX: message typo "Covert" -> "Convert".
            print('Convert {} in PatchMerging from timm to mmcv format!'.
                  format(new_k))
            if 'reduction' in new_k:
                new_v = correct_unfold_reduction_order(new_v)
            elif 'norm' in new_k:
                new_v = correct_unfold_norm_order(new_v)
        new_ckpt[new_k] = new_v
    return new_ckpt
def main():
    """CLI entry: convert a pretrained MixMIM checkpoint and save it."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained mixmim '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    weight = convert_mixmim(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path
import torch
def convert_weights(weight):
    """Weight Converter.

    Converts MLP-Mixer weights from timm naming to mmpretrain naming.

    Args:
        weight (dict): weight dict from timm.

    Returns:
        dict: checkpoint with ``meta`` and converted ``state_dict`` keys.
    """
    # Substring renames applied to every key, in this order.
    mapping = {
        'stem': 'patch_embed',
        'proj': 'projection',
        'mlp_tokens.fc1': 'token_mix.layers.0.0',
        'mlp_tokens.fc2': 'token_mix.layers.1',
        'mlp_channels.fc1': 'channel_mix.layers.0.0',
        'mlp_channels.fc2': 'channel_mix.layers.1',
        'norm1': 'ln1',
        'norm2': 'ln2',
        'norm.': 'ln1.',
        'blocks': 'layers'
    }
    state_dict = dict()
    for key, value in weight.items():
        for old, new in mapping.items():
            if old in key:
                key = key.replace(old, new)
        if key.startswith('head.'):
            state_dict['head.fc.' + key[len('head.'):]] = value
        else:
            state_dict['backbone.' + key] = value
    return {'meta': dict(), 'state_dict': state_dict}
if __name__ == '__main__':
    # CLI: convert a timm MLP-Mixer checkpoint to mmpretrain format.
    cli = argparse.ArgumentParser(description='Convert model keys')
    cli.add_argument('src', help='src detectron model path')
    cli.add_argument('dst', help='save path')
    opts = cli.parse_args()
    out_path = Path(opts.dst)
    if out_path.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        convert_weights(torch.load(opts.src, map_location='cpu')), opts.dst)
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
import torch
def convert_conv1(model_key, model_weight, state_dict, converted_names):
    """Rename a torchvision stem key (``features.0.*``) and store it."""
    if 'features.0.0' in model_key:
        new_key = model_key.replace('features.0.0', 'backbone.conv1.conv')
    else:
        new_key = model_key.replace('features.0.1', 'backbone.conv1.bn')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert_conv5(model_key, model_weight, state_dict, converted_names):
    """Rename the final conv key (``features.18.*``) and store it."""
    if 'features.18.0' in model_key:
        new_key = model_key.replace('features.18.0', 'backbone.conv2.conv')
    else:
        new_key = model_key.replace('features.18.1', 'backbone.conv2.bn')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert_head(model_key, model_weight, state_dict, converted_names):
    """Rename the torchvision classifier key to mmpretrain ``head.fc``."""
    new_key = model_key.replace('classifier.1', 'head.fc')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert_block(model_key, model_weight, state_dict, converted_names):
    """Rename an inverted-residual block key from torchvision to mmpretrain.

    torchvision numbers the blocks 1..17 under ``features``; mmpretrain
    groups them into ``layer1``..``layer7`` with a per-layer sub index.
    """
    layer_id = int(model_key.split('.')[1])
    # First torchvision block index of each mmpretrain layer (1..7).
    starts = [1, 2, 4, 7, 11, 14, 17]
    new_layer_id = 0
    sub_id = 0
    for idx, start in enumerate(starts):
        stop = starts[idx + 1] if idx + 1 < len(starts) else 18
        if start <= layer_id < stop:
            new_layer_id = idx + 1
            sub_id = layer_id - start
            break
    new_key = model_key.replace(f'features.{layer_id}',
                                f'backbone.layer{new_layer_id}.{sub_id}')
    # layer1 blocks have no expansion conv, so their numbering differs.
    if new_layer_id == 1:
        rules = [('conv.0.0', 'conv.0.conv'), ('conv.0.1', 'conv.0.bn'),
                 ('conv.1', 'conv.1.conv'), ('conv.2', 'conv.1.bn')]
    else:
        rules = [('conv.0.0', 'conv.0.conv'), ('conv.0.1', 'conv.0.bn'),
                 ('conv.1.0', 'conv.1.conv'), ('conv.1.1', 'conv.1.bn'),
                 ('conv.2', 'conv.2.conv'), ('conv.3', 'conv.2.bn')]
    for old, new in rules:
        if old in new_key:
            new_key = new_key.replace(old, new)
            break
    else:
        raise ValueError(f'Unsupported conversion of key {model_key}')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert(src, dst):
    """Convert keys in torchvision pretrained MobileNetV2 models to mmpretrain
    style."""
    source_blobs = torch.load(src, map_location='cpu')
    state_dict = OrderedDict()
    converted_names = set()
    # Dispatch each key to the matching converter.
    for key, weight in source_blobs.items():
        if 'features.0' in key:
            convert_conv1(key, weight, state_dict, converted_names)
        elif 'classifier' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif 'features.18' in key:
            convert_conv5(key, weight, state_dict, converted_names)
        else:
            convert_block(key, weight, state_dict, converted_names)
    # Report any key that was not handled above.
    for key in source_blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    torch.save({'state_dict': state_dict}, dst)
def main():
    """CLI wrapper around :func:`convert`."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    cli.add_argument('src', help='src detectron model path')
    cli.add_argument('dst', help='save path')
    parsed = cli.parse_args()
    convert(parsed.src, parsed.dst)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import re
from collections import OrderedDict, namedtuple
from pathlib import Path
import torch
# CLI help text for this converter.
prog_description = """\
Convert OFA official models to MMPretrain format.
"""
# One conversion rule: `pattern` (regex) selects keys, `repl` rewrites the
# matched span, `key_action` transforms the captured groups and
# `value_action` transforms (or drops, by returning None) the value.
MapItem = namedtuple(
    'MapItem', 'pattern repl key_action value_action', defaults=[None] * 4)
def convert_by_mapdict(src_dict: dict, map_dict: list):
    """Rewrite the keys/values of ``src_dict`` with a list of ``MapItem``s.

    Every rule whose ``pattern`` matches a key is applied in order:
    ``repl`` (or ``key_action``) rewrites the matched span and
    ``value_action`` may transform or drop (return None) the value.

    Args:
        src_dict (dict): the source state dict.
        map_dict (list): list of ``MapItem`` rules, applied in order.

    Returns:
        tuple: (converted state dict, {new key: original key} mapping).
    """
    dst_dict = OrderedDict()
    convert_map_dict = dict()
    for k, v in src_dict.items():
        ori_k = k
        for item in map_dict:
            pattern = item.pattern
            assert pattern is not None
            # Only the first match of each pattern is used.
            match = next(re.finditer(pattern, k), None)
            if match is None:
                continue
            match_group = match.groups()
            repl = item.repl
            key_action = item.key_action
            if key_action is not None:
                assert callable(key_action)
                match_group = key_action(*match_group)
                if isinstance(match_group, str):
                    match_group = (match_group, )
            start, end = match.span(0)
            if repl is not None:
                # Replace the whole matched span with the formatted `repl`.
                k = k[:start] + repl.format(*match_group) + k[end:]
            else:
                # No `repl`: substitute each capture group in place.
                # NOTE(review): spans come from the original match, so a
                # length-changing substitution would shift later groups —
                # fine for the single-group rules used in this file.
                for i, sub in enumerate(match_group):
                    start, end = match.span(i + 1)
                    k = k[:start] + str(sub) + k[end:]
            value_action = item.value_action
            if value_action is not None:
                assert callable(value_action)
                v = value_action(v)
        # A value_action returning None drops the entry entirely.
        if v is not None:
            dst_dict[k] = v
            convert_map_dict[k] = ori_k
    return dst_dict, convert_map_dict
# Ordered OFA -> MMPretrain conversion rules; earlier rules rewrite keys
# before later ones run.
map_dict = [
    # Encoder modules
    MapItem(r'\.type_embedding\.', '.embed_type.'),
    MapItem(r'\.layernorm_embedding\.', '.embedding_ln.'),
    MapItem(r'\.patch_layernorm_embedding\.', '.image_embedding_ln.'),
    MapItem(r'encoder.layer_norm\.', 'encoder.final_ln.'),
    # Encoder layers
    MapItem(r'\.attn_ln\.', '.attn_mid_ln.'),
    MapItem(r'\.ffn_layernorm\.', '.ffn_mid_ln.'),
    MapItem(r'\.final_layer_norm', '.ffn_ln'),
    MapItem(r'encoder.*(\.self_attn\.)', key_action=lambda _: '.attn.'),
    MapItem(
        r'encoder.*(\.self_attn_layer_norm\.)',
        key_action=lambda _: '.attn_ln.'),
    # Decoder modules
    MapItem(r'\.code_layernorm_embedding\.', '.code_embedding_ln.'),
    MapItem(r'decoder.layer_norm\.', 'decoder.final_ln.'),
    # Decoder layers
    MapItem(r'\.self_attn_ln', '.self_attn_mid_ln'),
    MapItem(r'\.cross_attn_ln', '.cross_attn_mid_ln'),
    MapItem(r'\.encoder_attn_layer_norm', '.cross_attn_ln'),
    MapItem(r'\.encoder_attn', '.cross_attn'),
    MapItem(
        r'decoder.*(\.self_attn_layer_norm\.)',
        key_action=lambda _: '.self_attn_ln.'),
    # Remove version key
    MapItem(r'version', '', value_action=lambda _: None),
    # Add model prefix
    MapItem(r'^', 'model.'),
]
def parse_args():
    """Parse the command-line arguments."""
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('src', type=str, help='The official checkpoint path.')
    parser.add_argument('dst', type=str, help='The save path.')
    return parser.parse_args()
def main():
    """Load the official OFA checkpoint, convert its keys and save it."""
    args = parse_args()
    ckpt = torch.load(args.src)
    # Prefer the EMA weights when the checkpoint carries them.
    if 'extra_state' in ckpt and 'ema' in ckpt['extra_state']:
        print('Use EMA weights.')
        weights = ckpt['extra_state']['ema']
    else:
        weights = ckpt['model']
    converted, _ = convert_by_mapdict(weights, map_dict)
    torch.save(converted, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_clip(ckpt):
    """Convert an OpenAI CLIP visual-encoder state dict to mmpretrain keys.

    Args:
        ckpt (dict): official CLIP state dict.

    Returns:
        OrderedDict: converted state dict; positional and class embeddings
        gain leading singleton dims.
    """
    converted = OrderedDict()
    for key, value in list(ckpt.items()):
        new_value = value
        if key.startswith('visual.conv1'):
            new_key = key.replace('conv1', 'patch_embed.projection')
        elif key.startswith('visual.positional_embedding'):
            new_key = key.replace('positional_embedding', 'pos_embed')
            new_value = value.unsqueeze(dim=0)
        elif key.startswith('visual.class_embedding'):
            new_key = key.replace('class_embedding', 'cls_token')
            new_value = value.unsqueeze(dim=0).unsqueeze(dim=0)
        elif key.startswith('visual.ln_pre'):
            new_key = key.replace('ln_pre', 'pre_norm')
        elif key.startswith('visual.transformer.resblocks'):
            new_key = key.replace('transformer.resblocks', 'layers')
            # Apply at most one in-block rename, in priority order:
            # (substring to detect, text to replace, replacement).
            for needle, old, new in (
                    ('ln_1', 'ln_1', 'ln1'),
                    ('ln_2', 'ln_2', 'ln2'),
                    ('mlp.c_fc', 'mlp.c_fc', 'ffn.layers.0.0'),
                    ('mlp.c_proj', 'mlp.c_proj', 'ffn.layers.1'),
                    ('attn.in_proj_weight', 'in_proj_weight', 'qkv.weight'),
                    ('attn.in_proj_bias', 'in_proj_bias', 'qkv.bias'),
                    ('attn.out_proj', 'out_proj', 'proj')):
                if needle in key:
                    new_key = new_key.replace(old, new)
                    break
        elif key.startswith('visual.ln_post'):
            new_key = key.replace('ln_post', 'ln1')
        elif key.startswith('visual.proj'):
            new_key = key.replace('visual.proj', 'visual_proj.proj')
        else:
            new_key = key
        converted[new_key] = new_value
    return converted
def main():
    """CLI entry: load, convert and save a CLIP checkpoint."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained clip '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    ckpt = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(convert_clip(state_dict), args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import re
from collections import OrderedDict
from itertools import chain
from pathlib import Path
import torch
# CLI help text for this converter.
prog_description = """\
Convert Official Otter HF models to MMPreTrain format.
"""
def parse_args():
    """Parse the command-line arguments."""
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument(
        'name_or_dir', type=str, help='The Otter HF model name or directory.')
    return parser.parse_args()
def main():
    """Download (if needed) an Otter HF checkpoint and extract the adapter
    weights into a single ``<name>.pth`` file.

    Only perceiver / token-embedding / gated-cross-attention / rotary
    parameters are kept; the frozen encoder weights are dropped.
    """
    args = parse_args()
    if not Path(args.name_or_dir).is_dir():
        # Not a local directory: treat the argument as a HF Hub model name.
        from huggingface_hub import snapshot_download
        ckpt_dir = Path(
            snapshot_download(args.name_or_dir, allow_patterns='*.bin'))
        name = args.name_or_dir.replace('/', '_')
    else:
        ckpt_dir = Path(args.name_or_dir)
        name = ckpt_dir.name
    state_dict = OrderedDict()
    # Iterate over the items of every *.bin shard as one stream.
    for k, v in chain.from_iterable(
            torch.load(ckpt).items() for ckpt in ckpt_dir.glob('*.bin')):
        adapter_patterns = [
            r'^perceiver',
            r'lang_encoder.*embed_tokens',
            r'lang_encoder.*gated_cross_attn_layer',
            r'lang_encoder.*rotary_emb',
        ]
        if not any(re.match(pattern, k) for pattern in adapter_patterns):
            # Drop encoder parameters to decrease the size.
            continue
        # The keys are different between Open-Flamingo and Otter
        if 'gated_cross_attn_layer.feed_forward' in k:
            k = k.replace('feed_forward', 'ff')
        if 'perceiver.layers' in k:
            prefix_match = re.match(r'perceiver.layers.\d+.', k)
            prefix = k[:prefix_match.end()]
            suffix = k[prefix_match.end():]
            # Sub-module 1 is the feed-forward branch, 0 the attention.
            if 'feed_forward' in k:
                k = prefix + '1.' + suffix.replace('feed_forward.', '')
            else:
                k = prefix + '0.' + suffix
        state_dict[k] = v
    if len(state_dict) == 0:
        raise RuntimeError('No checkpoint found in the specified directory.')
    torch.save(state_dict, name + '.pth')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import datetime
import hashlib
import shutil
import warnings
from collections import OrderedDict
from pathlib import Path
import torch
import mmpretrain
def parse_args():
    """Parse the command-line arguments for checkpoint publishing."""
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename')
    parser.add_argument('out_file', help='output checkpoint filename')
    parser.add_argument(
        '--no-ema',
        action='store_true',
        help='Use keys in `ema_state_dict` (no-ema keys).')
    parser.add_argument(
        '--dataset-type',
        type=str,
        help='The type of the dataset. If the checkpoint is converted '
        'from other repository, this option is used to fill the dataset '
        'meta information to the published checkpoint, like "ImageNet", '
        '"CIFAR10" and others.')
    return parser.parse_args()
def process_checkpoint(in_file, out_file, args):
    """Strip training-only fields, fill metadata, fold in EMA weights and
    save the checkpoint under a date-and-hash suffixed name.

    Args:
        in_file (str): input checkpoint path.
        out_file (str): requested output path; the final filename gets a
            ``_{date}-{sha}`` suffix inserted before ``.pth``.
        args (argparse.Namespace): parsed CLI options (``no_ema``,
            ``dataset_type``).
    """
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove unnecessary fields for smaller file size
    for key in ['optimizer', 'param_schedulers', 'hook_msgs', 'message_hub']:
        checkpoint.pop(key, None)
    # For checkpoint converted from the official weight
    if 'state_dict' not in checkpoint:
        checkpoint = dict(state_dict=checkpoint)
    meta = checkpoint.get('meta', {})
    meta.setdefault('mmpretrain_version', mmpretrain.__version__)
    # handle dataset meta information
    if args.dataset_type is not None:
        from mmpretrain.registry import DATASETS
        dataset_class = DATASETS.get(args.dataset_type)
        dataset_meta = getattr(dataset_class, 'METAINFO', {})
    else:
        dataset_meta = {}
    meta.setdefault('dataset_meta', dataset_meta)
    if len(meta['dataset_meta']) == 0:
        warnings.warn('Missing dataset meta information.')
    checkpoint['meta'] = meta
    ema_state_dict = OrderedDict()
    if 'ema_state_dict' in checkpoint:
        for k, v in checkpoint['ema_state_dict'].items():
            # The ema static dict has some extra fields
            if k.startswith('module.'):
                origin_k = k[len('module.'):]
                # NOTE(review): assumes every `module.*` EMA key mirrors a
                # key in `state_dict` — the assert enforces it.
                assert origin_k in checkpoint['state_dict']
                ema_state_dict[origin_k] = v
        del checkpoint['ema_state_dict']
        print('The input checkpoint has EMA weights, ', end='')
        if args.no_ema:
            # The values stored in `ema_state_dict` is original values.
            print('and drop the EMA weights.')
            assert ema_state_dict.keys() <= checkpoint['state_dict'].keys()
            checkpoint['state_dict'].update(ema_state_dict)
        else:
            print('and use the EMA weights.')
    # Save to a temp file first so the content hash covers the final bytes.
    temp_out_file = Path(out_file).with_name('temp_' + Path(out_file).name)
    torch.save(checkpoint, temp_out_file)
    with open(temp_out_file, 'rb') as f:
        sha = hashlib.sha256(f.read()).hexdigest()[:8]
    if out_file.endswith('.pth'):
        out_file_name = out_file[:-4]
    else:
        out_file_name = out_file
    current_date = datetime.datetime.now().strftime('%Y%m%d')
    final_file = out_file_name + f'_{current_date}-{sha[:8]}.pth'
    shutil.move(temp_out_file, final_file)
    print(f'Successfully generated the publish-ckpt as {final_file}.')
def main():
    """Validate the output directory and publish the checkpoint."""
    args = parse_args()
    parent = Path(args.out_file).parent
    if not parent.exists():
        raise ValueError(f'Directory {parent} does not exist, '
                         'please generate it manually.')
    process_checkpoint(args.in_file, args.out_file, args)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
from copy import deepcopy
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_swin(ckpt):
    """Convert a Swin-Transformer state dict to mmpretrain naming.

    ``attn_mask`` buffers and ``head`` weights are dropped.

    Args:
        ckpt (dict): official Swin state dict.

    Returns:
        tuple: (converted state dict, {old key: new key} mapping).
    """
    new_ckpt = OrderedDict()
    convert_mapping = dict()

    def _reorder_reduction(w):
        # Undo the interleaved patch order of PatchMerging `reduction`.
        out_ch, in_ch = w.shape
        w = w.reshape(out_ch, 4, in_ch // 4)
        return w[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_ch, in_ch)

    def _reorder_norm(w):
        in_ch = w.shape[0]
        w = w.reshape(4, in_ch // 4)
        return w[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_ch)

    for old_key, value in ckpt.items():
        if 'attn_mask' in old_key:
            continue
        if old_key.startswith('head'):
            continue
        new_value = value
        if old_key.startswith('layers'):
            if 'attn.' in old_key:
                new_key = old_key.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in old_key:
                if 'mlp.fc1.' in old_key:
                    new_key = old_key.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in old_key:
                    new_key = old_key.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_key = old_key.replace('mlp.', 'ffn.')
            elif 'downsample' in old_key:
                new_key = old_key
                if 'reduction.' in old_key:
                    new_value = _reorder_reduction(value)
                elif 'norm.' in old_key:
                    new_value = _reorder_norm(value)
            else:
                new_key = old_key
            new_key = new_key.replace('layers', 'stages', 1)
        elif old_key.startswith('patch_embed'):
            if 'proj' in old_key:
                new_key = old_key.replace('proj', 'projection')
            else:
                new_key = old_key
        elif old_key.startswith('norm'):
            new_key = old_key.replace('norm', 'norm3')
        else:
            new_key = old_key
        new_ckpt[new_key] = new_value
        convert_mapping[old_key] = new_key
    return new_ckpt, convert_mapping
def main():
    """Convert an official RAM checkpoint: rewrite the Swin visual-encoder
    keys to mmpretrain naming while leaving all other weights untouched."""
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained RAM models to'
        'MMPretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # Collect the visual-encoder weights with the prefix stripped.
    visual_ckpt = OrderedDict()
    for key in state_dict:
        if key.startswith('visual_encoder.'):
            new_key = key.replace('visual_encoder.', '')
            visual_ckpt[new_key] = state_dict[key]
    new_visual_ckpt, convert_mapping = convert_swin(visual_ckpt)
    new_ckpt = deepcopy(state_dict)
    for key in state_dict:
        if key.startswith('visual_encoder.'):
            if 'attn_mask' in key:
                # attn_mask buffers are dropped by convert_swin.
                del new_ckpt[key]
                continue
            del new_ckpt[key]
            # Re-attach `visual_encoder.` to the converted (suffix) key and
            # copy the converted tensor across.
            old_key = key.replace('visual_encoder.', '')
            new_ckpt[key.replace(old_key,
                                 convert_mapping[old_key])] = deepcopy(
                                     new_visual_ckpt[key.replace(
                                         old_key,
                                         convert_mapping[old_key]).replace(
                                             'visual_encoder.', '')])
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(new_ckpt, args.dst)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path
import torch
from mmpretrain.apis import init_model
from mmpretrain.models.classifiers import ImageClassifier
def convert_classifier_to_deploy(model, checkpoint, save_path):
    """Switch ``model.backbone`` to deploy mode and save the new weights.

    Args:
        model: classifier whose backbone implements ``switch_to_deploy``.
        checkpoint (dict): loaded checkpoint; its ``state_dict`` entry is
            replaced with the reparameterized weights.
        save_path: destination path passed to ``torch.save``.
    """
    print('Converting...')
    deployable = hasattr(model, 'backbone') and hasattr(
        model.backbone, 'switch_to_deploy')
    assert deployable, \
        '`model.backbone` must has method of "switch_to_deploy".' \
        f' But {model.backbone.__class__} does not have.'
    model.backbone.switch_to_deploy()
    checkpoint['state_dict'] = model.state_dict()
    torch.save(checkpoint, save_path)
    print('Done! Save at path "{}"'.format(save_path))
def main():
    """Parse CLI args and reparameterize a classifier checkpoint."""
    parser = argparse.ArgumentParser(
        description='Convert the parameters of the repvgg block '
        'from training mode to deployment mode.')
    parser.add_argument(
        'config_path',
        help='The path to the configuration file of the network '
        'containing the repvgg block.')
    parser.add_argument(
        'checkpoint_path',
        help='The path to the checkpoint file corresponding to the model.')
    parser.add_argument(
        'save_path',
        help='The path where the converted checkpoint file is stored.')
    args = parser.parse_args()
    target = Path(args.save_path)
    if target.suffix not in ('.pth', '.tar'):
        print('The path should contain the name of the pth format file.')
        exit()
    target.parent.mkdir(parents=True, exist_ok=True)
    model = init_model(
        args.config_path, checkpoint=args.checkpoint_path, device='cpu')
    assert isinstance(model, ImageClassifier), \
        '`model` must be a `mmpretrain.classifiers.ImageClassifier` instance.'
    checkpoint = torch.load(args.checkpoint_path)
    convert_classifier_to_deploy(model, checkpoint, args.save_path)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from pathlib import Path
import torch
def convert(src, dst):
    """Convert an official checkpoint to mmpretrain key naming.

    Args:
        src (str): path of the source checkpoint.
        dst (str): path to save the converted checkpoint.
    """
    print('Converting...')
    blobs = torch.load(src, map_location='cpu')
    converted_state_dict = OrderedDict()
    # (prefix, replacement) applied in order to each dotted key component.
    rules = [
        ('stem', 'backbone.stem'),
        ('stages', 'backbone.stages'),
        ('transitions', 'backbone.transitions'),
        ('norm', 'backbone.stages.3.norm'),
        ('head', 'head.fc'),
    ]
    for key in blobs:
        # FIX: removed leftover debug `print(splited_key)` that dumped
        # every key's split to stdout.
        parts = key.split('.')
        for prefix, repl in rules:
            parts = [repl if p.startswith(prefix) else p for p in parts]
        converted_state_dict['.'.join(parts)] = blobs[key]
    torch.save(converted_state_dict, dst)
    print('Done!')
def main():
    """CLI wrapper: validate the destination path and run the conversion."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    cli.add_argument('src', help='src detectron model path')
    cli.add_argument('dst', help='save path')
    opts = cli.parse_args()
    out_path = Path(opts.dst)
    if out_path.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    convert(opts.src, opts.dst)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from pathlib import Path
import torch
def convert(src, dst):
    """Convert an official RepVGG checkpoint to mmpretrain key naming.

    Args:
        src (str): path of the source checkpoint.
        dst (str): path to save the converted checkpoint.
    """
    print('Converting...')
    blobs = torch.load(src, map_location='cpu')
    converted_state_dict = OrderedDict()
    # Components renamed only on exact match.
    exact = {
        'bn': 'norm',
        'rbr_identity': 'branch_norm',
        'rbr_1x1': 'branch_1x1',
        'rbr_dense': 'branch_3x3',
        'se': 'se_layer',
        'down': 'conv1.conv',
        'up': 'conv2.conv',
        'linear': 'head.fc',
    }

    def _rename(part):
        # `stage0` is the stem; `stageN` keeps its digit as a suffix.
        if part in exact:
            return exact[part]
        if part.startswith('stage0'):
            return 'backbone.stem'
        if part.startswith('stage'):
            return 'backbone.stage_' + part[5]
        return part

    for key in blobs:
        new_key = '.'.join(_rename(p) for p in key.split('.'))
        converted_state_dict[new_key] = blobs[key]
    torch.save(converted_state_dict, dst)
    print('Done!')
def main():
    """CLI wrapper: validate the destination path and run the conversion."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    cli.add_argument('src', help='src detectron model path')
    cli.add_argument('dst', help='save path')
    opts = cli.parse_args()
    out_path = Path(opts.dst)
    if out_path.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    convert(opts.src, opts.dst)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_revvit(ckpt):
    """Convert a RevViT checkpoint to mmpretrain naming and fuse q/k/v.

    Pass one renames the keys; pass two concatenates the separate q, k, v
    projection weights/biases into single ``qkv`` parameters (relying on
    the q, k, v keys appearing in that order).

    Args:
        ckpt (dict): official state dict.

    Returns:
        OrderedDict: the converted state dict.
    """
    renamed = OrderedDict()
    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('head.projection'):
            renamed[k.replace('head.projection', 'head.fc')] = new_v
            continue
        elif k.startswith('patch_embed'):
            if 'proj.' in k:
                new_k = k.replace('proj.', 'projection.')
            else:
                new_k = k
        elif k.startswith('rev_backbone'):
            new_k = k.replace('rev_backbone.', '')
            # Apply the first matching rename, if any.
            for old, new in (('F.norm', 'ln1'), ('G.norm', 'ln2'),
                             ('F.attn', 'attn'),
                             ('G.mlp.fc1', 'ffn.layers.0.0'),
                             ('G.mlp.fc2', 'ffn.layers.1')):
                if old in k:
                    new_k = new_k.replace(old, new)
                    break
        elif k.startswith('norm'):
            new_k = k.replace('norm', 'ln1')
        else:
            new_k = k
        if not new_k.startswith('head'):
            new_k = 'backbone.' + new_k
        renamed[new_k] = new_v
    # Second pass: merge q/k/v into fused qkv parameters.
    pending_w = []
    pending_b = []
    final_ckpt = OrderedDict()
    for k, v in list(renamed.items()):
        if 'attn.q.weight' in k or 'attn.k.weight' in k:
            pending_w.append(v)
        elif 'attn.v.weight' in k:
            pending_w.append(v)
            fused_k = k.replace('attn.v.weight', 'attn.qkv.weight')
            final_ckpt[fused_k] = torch.cat(pending_w, dim=0)
            pending_w = []
        elif 'attn.q.bias' in k or 'attn.k.bias' in k:
            pending_b.append(v)
        elif 'attn.v.bias' in k:
            pending_b.append(v)
            fused_k = k.replace('attn.v.bias', 'attn.qkv.bias')
            final_ckpt[fused_k] = torch.cat(pending_b, dim=0)
            pending_b = []
        else:
            final_ckpt[k] = v
    return final_ckpt
def main():
    """CLI entry: load, convert and save a RevViT checkpoint."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained revvit'
        ' models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    ckpt = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    state_dict = ckpt['model_state'] if 'model_state' in ckpt else ckpt
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(convert_revvit(state_dict), args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
import torch
def convert_conv1(model_key, model_weight, state_dict, converted_names):
    """Rename a torchvision ShuffleNetV2 stem key (``conv1.*``)."""
    if 'conv1.0' in model_key:
        new_key = model_key.replace('conv1.0', 'backbone.conv1.conv')
    else:
        new_key = model_key.replace('conv1.1', 'backbone.conv1.bn')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert_conv5(model_key, model_weight, state_dict, converted_names):
    """Rename the final conv key (``conv5.*``) to the last backbone layer."""
    if 'conv5.0' in model_key:
        new_key = model_key.replace('conv5.0', 'backbone.layers.3.conv')
    else:
        new_key = model_key.replace('conv5.1', 'backbone.layers.3.bn')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert_head(model_key, model_weight, state_dict, converted_names):
    """Rename the torchvision ``fc`` classifier key to ``head.fc``."""
    new_key = model_key.replace('fc', 'head.fc')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert_block(model_key, model_weight, state_dict, converted_names):
    """Rename a ShuffleNetV2 block key (``stageN.M.branchX...``) to
    mmpretrain naming.

    stage2/3/4 map to ``backbone.layers.0/1/2``; the conv/bn indices inside
    each branch are regrouped into ConvModule-style ``.conv``/``.bn`` pairs.
    """
    parts = model_key.split('.')
    layer, branch = parts[0], parts[2]
    new_key = model_key.replace(layer,
                                f'backbone.layers.{int(layer[-1]) - 2}')
    if branch == 'branch1':
        # branch1 has two conv/bn pairs; unknown indices pass through.
        rules = [('branch1.0', 'branch1.0.conv'),
                 ('branch1.1', 'branch1.0.bn'),
                 ('branch1.2', 'branch1.1.conv'),
                 ('branch1.3', 'branch1.1.bn')]
        for old, new in rules:
            if old in new_key:
                new_key = new_key.replace(old, new)
                break
    elif branch == 'branch2':
        rules = [('branch2.0', 'branch2.0.conv'),
                 ('branch2.1', 'branch2.0.bn'),
                 ('branch2.3', 'branch2.1.conv'),
                 ('branch2.4', 'branch2.1.bn'),
                 ('branch2.5', 'branch2.2.conv'),
                 ('branch2.6', 'branch2.2.bn')]
        for old, new in rules:
            if old in new_key:
                new_key = new_key.replace(old, new)
                break
        else:
            raise ValueError(f'Unsupported conversion of key {model_key}')
    else:
        raise ValueError(f'Unsupported conversion of key {model_key}')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert(src, dst):
    """Convert keys in torchvision pretrained ShuffleNetV2 models to mmpretrain
    style."""
    # load the torchvision checkpoint on CPU so no GPU is required
    blobs = torch.load(src, map_location='cpu')
    state_dict = OrderedDict()
    converted_names = set()
    # dispatch each key to its converter; the stage keys start with 's'
    for key, weight in blobs.items():
        if 'conv1' in key:
            convert_conv1(key, weight, state_dict, converted_names)
        elif 'fc' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_block(key, weight, state_dict, converted_names)
        elif 'conv5' in key:
            convert_conv5(key, weight, state_dict, converted_names)
    # report any key no converter claimed
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # wrap in the mmpretrain checkpoint format and save
    torch.save({'state_dict': state_dict}, dst)
def main():
    """Parse command-line arguments and run the ShuffleNetV2 conversion."""
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src detectron model path')
    parser.add_argument('dst', help='save path')
    cli_args = parser.parse_args()
    convert(cli_args.src, cli_args.dst)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path
import torch
def convert_weights(weight):
    """Weight Converter.

    Converts the weights from timm to mmpretrain

    Args:
        weight (dict): weight dict from timm

    Returns:
        Converted weight dict for mmpretrain
    """
    # ordered substring substitutions applied to every key
    substitutions = (
        ('c.weight', 'conv2d.weight'),
        ('bn.weight', 'bn2d.weight'),
        ('bn.bias', 'bn2d.bias'),
        ('bn.running_mean', 'bn2d.running_mean'),
        ('bn.running_var', 'bn2d.running_var'),
        ('bn.num_batches_tracked', 'bn2d.num_batches_tracked'),
        ('layers', 'stages'),
        ('norm_head', 'norm3'),
    )
    state_dict = dict()
    for key, value in weight['model'].items():
        for old, new in substitutions:
            if old in key:
                key = key.replace(old, new)
        # classifier weights live under head.fc in mmpretrain; everything
        # else belongs to the backbone
        if key.startswith('head.'):
            state_dict['head.fc.' + key[len('head.'):]] = value
        else:
            state_dict['backbone.' + key] = value
    return {'meta': dict(), 'state_dict': state_dict}
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src detectron model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    # refuse destinations that are not explicit .pth file paths
    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)
    # load on CPU so the conversion works without a GPU
    checkpoint = torch.load(args.src, map_location='cpu')
    torch.save(convert_weights(checkpoint), args.dst)
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from pathlib import Path
import torch
def convert_resnet(src_dict, dst_dict):
    """convert resnet checkpoints from torchvision."""
    for key, value in src_dict.items():
        # torchvision keeps the classifier as ``fc.*``; everything else is
        # part of the backbone
        prefix = 'head.' if key.startswith('fc') else 'backbone.'
        dst_dict[prefix + key] = value
# Registry mapping a model name (the ``model`` CLI argument) to the
# function that converts its torchvision checkpoint keys.
CONVERT_F_DICT = {
    'resnet': convert_resnet,
}
def convert(src: str, dst: str, convert_f: callable):
    """Load ``src``, remap its keys with ``convert_f`` and save to ``dst``."""
    print('Converting...')
    # load on CPU so the conversion works without a GPU
    source_weights = torch.load(src, map_location='cpu')
    result = OrderedDict()
    # convert_f fills ``result`` in place with the renamed keys
    convert_f(source_weights, result)
    torch.save(result, dst)
    print('Done!')
def main():
    """CLI entry: validate arguments and dispatch to the right converter."""
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src detectron model path')
    parser.add_argument('dst', help='save path')
    parser.add_argument(
        'model', type=str, help='The algorithm needs to change the keys.')
    args = parser.parse_args()
    # the destination must be an explicit .pth file path
    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)
    # only models registered in CONVERT_F_DICT can be converted
    if args.model not in CONVERT_F_DICT:
        support_models = list(CONVERT_F_DICT.keys())
        print(f'The "{args.model}" has not been supported to convert now.')
        print(f'This tool only supports {", ".join(support_models)}.')
        print('If you have done the converting job, PR is welcome!')
        exit(1)
    convert(args.src, args.dst, CONVERT_F_DICT[args.model])


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import torch
from mmcv.runner import CheckpointLoader
def convert_twins(args, ckpt):
    """Remap official Twins checkpoint keys to the MMPretrain layout.

    ``args`` is accepted for interface compatibility but is not used.
    """
    new_ckpt = OrderedDict()
    for old_key, value in list(ckpt.items()):
        # classifier head: stored under head.fc, no backbone prefix
        if old_key.startswith('head'):
            new_ckpt[old_key.replace('head.', 'head.fc.')] = value
            continue
        if old_key.startswith('patch_embeds'):
            key = (old_key.replace('proj.', 'projection.')
                   if 'proj.' in old_key else old_key)
        elif old_key.startswith('blocks'):
            key = old_key.replace('blocks', 'stages')
            # MLP sublayers become FFN layers in MMPretrain
            if 'mlp.fc1' in key:
                key = key.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in key:
                key = key.replace('mlp.fc2', 'ffn.layers.1')
            key = key.replace('blocks.', 'layers.')
        elif old_key.startswith('pos_block'):
            key = old_key.replace('pos_block', 'position_encodings')
            if 'proj.0.' in key:
                key = key.replace('proj.0.', 'proj.')
        elif old_key.startswith('norm'):
            key = old_key.replace('norm', 'norm_after_stage3')
        else:
            key = old_key
        new_ckpt['backbone.' + key] = value
    return new_ckpt
def main():
    """Convert a timm Twins checkpoint given on the command line."""
    parser = argparse.ArgumentParser(
        description='Convert keys in timm pretrained vit models to '
        'MMPretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # timm checkpoints may nest the weights under a 'state_dict' entry
    state_dict = (checkpoint['state_dict']
                  if 'state_dict' in checkpoint else checkpoint)
    weight = convert_twins(args, state_dict)
    mmcv.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_van(ckpt):
    """Translate official VAN checkpoint keys into the MMPretrain layout."""
    new_ckpt = OrderedDict()
    for old_key, value in list(ckpt.items()):
        # classifier weights live under head.fc and get no backbone prefix
        if old_key.startswith('head'):
            new_ckpt[old_key.replace('head.', 'head.fc.')] = value
            continue
        if old_key.startswith('patch_embed'):
            key = (old_key.replace('proj.', 'projection.')
                   if 'proj.' in old_key else old_key)
        elif old_key.startswith('block'):
            key = old_key.replace('block', 'blocks')
            if 'attn.spatial_gating_unit' in key:
                # the LKA convolutions carry different names in MMPretrain
                key = key.replace('conv0', 'DW_conv')
                key = key.replace('conv_spatial', 'DW_D_conv')
            if 'dwconv.dwconv' in key:
                key = key.replace('dwconv.dwconv', 'dwconv')
        else:
            key = old_key
        # head keys never reach this point, so everything left is backbone
        if not key.startswith('head'):
            key = 'backbone.' + key
        new_ckpt[key] = value
    return new_ckpt
def main():
    """Convert an official VAN checkpoint given on the command line."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained van '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # official checkpoints may nest the weights under 'state_dict'
    state_dict = (checkpoint['state_dict']
                  if 'state_dict' in checkpoint else checkpoint)
    weight = convert_van(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
from collections import OrderedDict
import torch
def get_layer_maps(layer_num, with_bn):
    """Build index maps from torchvision VGG ``features`` indices to
    mmpretrain ones.

    Args:
        layer_num (int): VGG depth; one of 11, 13, 16 or 19.
        with_bn (bool): whether the torchvision model contains BN layers.

    Returns:
        dict: ``{'conv': {old: new}, 'bn': {old: new}}``; the ``bn`` map is
        empty when ``with_bn`` is False.

    Raises:
        ValueError: if ``layer_num`` is not a supported depth.
    """
    if with_bn:
        conv_idx_table = {
            11: [0, 4, 8, 11, 15, 18, 22, 25],
            13: [0, 3, 7, 10, 14, 17, 21, 24, 28, 31],
            16: [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40],
            19: [0, 3, 7, 10, 14, 17, 20, 23, 27, 30, 33, 36, 40, 43, 46, 49],
        }
        if layer_num not in conv_idx_table:
            raise ValueError(f'Invalid number of layers: {layer_num}')
        layer_idxs = conv_idx_table[layer_num]
        conv_map = {}
        bn_map = {}
        new_idx = layer_idxs[0]
        prev_idx = layer_idxs[0]
        for layer_idx in layer_idxs:
            # each gap between consecutive conv indices shrinks by half
            # because the standalone BN entries fold into ConvModules
            new_idx += int((layer_idx - prev_idx) / 2)
            conv_map[layer_idx] = new_idx
            bn_map[layer_idx + 1] = new_idx  # BN directly follows its conv
            prev_idx = layer_idx
        return {'conv': conv_map, 'bn': bn_map}
    idx_tables = {
        11: ([0, 3, 6, 8, 11, 13, 16, 18], [0, 2, 4, 5, 7, 8, 10, 11]),
        13: ([0, 2, 5, 7, 10, 12, 15, 17, 20, 22],
             [0, 1, 3, 4, 6, 7, 9, 10, 12, 13]),
        16: ([0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28],
             [0, 1, 3, 4, 6, 7, 8, 10, 11, 12, 14, 15, 16]),
        19: ([0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34],
             [0, 1, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19]),
    }
    if layer_num not in idx_tables:
        raise ValueError(f'Invalid number of layers: {layer_num}')
    old_idxs, new_idxs = idx_tables[layer_num]
    return {'conv': dict(zip(old_idxs, new_idxs)), 'bn': {}}
def convert(src, dst, layer_num, with_bn=False):
    """Convert keys in torchvision pretrained VGG models to mmpretrain
    style."""
    assert os.path.isfile(src), f'no checkpoint found at {src}'
    # load the torchvision checkpoint on CPU
    blobs = torch.load(src, map_location='cpu')
    layer_maps = get_layer_maps(layer_num, with_bn)
    state_dict = OrderedDict()
    for key, weight in blobs.items():
        if 'features' in key:
            _, idx_str, weight_type = key.split('.')
            idx = int(idx_str)
            # keep the untouched index when it has no remapping entry
            new_key = f'backbone.{key}'
            for layer_key, maps in layer_maps.items():
                if idx in maps:
                    new_key = (f'backbone.features.{maps[idx]}'
                               f'.{layer_key}.{weight_type}')
            state_dict[new_key] = weight
            print(f'Convert {key} to {new_key}')
        elif 'classifier' in key:
            new_key = f'backbone.{key}'
            state_dict[new_key] = weight
            print(f'Convert {key} to {new_key}')
        else:
            state_dict[key] = weight
    # wrap in the mmpretrain checkpoint format and save
    torch.save({'state_dict': state_dict}, dst)
def main():
    """Parse CLI arguments and convert a torchvision VGG checkpoint."""
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src torchvision model path')
    parser.add_argument('dst', help='save path')
    parser.add_argument(
        '--bn', action='store_true', help='whether original vgg has BN')
    parser.add_argument(
        '--layer-num',
        type=int,
        choices=[11, 13, 16, 19],
        default=11,
        help='number of VGG layers')
    cli = parser.parse_args()
    convert(cli.src, cli.dst, layer_num=cli.layer_num, with_bn=cli.bn)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import re
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_vig(ckpt):
    """Translate official (isotropic) ViG checkpoint keys to mmpretrain."""
    new_ckpt = OrderedDict()
    for key, value in ckpt.items():
        if 'pos_embed' in key:
            key = key.replace('pos_embed', 'backbone.pos_embed')
        elif 'stem' in key:
            key = key.replace('stem.convs', 'backbone.stem')
        elif 'backbone' in key:
            key = key.replace('backbone', 'backbone.blocks')
        elif 'prediction.0' in key:
            key = key.replace('prediction.0', 'head.fc1')
            # 1x1 conv weights become linear weights: drop trailing dims
            value = value.squeeze(-1).squeeze(-1)
        elif 'prediction.1' in key:
            key = key.replace('prediction.1', 'head.bn')
        elif 'prediction.4' in key:
            key = key.replace('prediction.4', 'head.fc2')
            value = value.squeeze(-1).squeeze(-1)
        new_ckpt[key] = value
    return new_ckpt
def convert_pvig(ckpt):
    """Translate official Pyramid ViG checkpoint keys to mmpretrain.

    Relies on checkpoint key order: a ``backbone.<i>.conv...0.weight`` key
    marks the downsample layer that opens a new stage.
    """
    new_ckpt = OrderedDict()
    stage_idx = 0
    # flat index (official numbering) of the current stage's first block
    first_block_idx = 0
    for key, value in ckpt.items():
        new_key = key
        new_value = value
        if 'pos_embed' in key:
            new_key = key.replace('pos_embed', 'backbone.pos_embed')
        elif 'stem' in key:
            new_key = key.replace('stem.convs', 'backbone.stem')
        elif re.match(r'^backbone\.\d+\.conv', key):
            if key.endswith('0.weight'):
                # first parameter of a new stage's downsample layer
                stage_idx += 1
                first_block_idx = int(key.split('.')[1])
            rest = key.split('.', maxsplit=3)[-1]
            new_key = f'backbone.stages.{stage_idx}.0.{rest}'
        elif 'backbone' in key:
            block_idx = int(key.split('.')[1]) - first_block_idx
            rest = key.split('.', maxsplit=2)[-1]
            new_key = f'backbone.stages.{stage_idx}.{block_idx}.{rest}'
        elif 'prediction.0' in key:
            new_key = key.replace('prediction.0', 'head.fc1')
            # 1x1 conv weights become linear weights: drop trailing dims
            new_value = value.squeeze(-1).squeeze(-1)
        elif 'prediction.1' in key:
            new_key = key.replace('prediction.1', 'head.bn')
        elif 'prediction.4' in key:
            new_key = key.replace('prediction.4', 'head.fc2')
            new_value = value.squeeze(-1).squeeze(-1)
        new_ckpt[new_key] = new_value
    return new_ckpt
def main():
    """Convert an official ViG/PViG checkpoint given on the command line."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained vig models to '
        'mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # official checkpoints may nest the weights under 'model'
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    # NOTE(review): this key presumably only exists in the pyramid variant —
    # used here to pick the converter; confirm against upstream checkpoints
    if 'backbone.2.conv.0.weight' in state_dict:
        weight = convert_pvig(state_dict)
    else:
        weight = convert_vig(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment