Commit 1ac2e802 authored by limm's avatar limm
Browse files

add tools code

parent b6df0d33
Pipeline #2803 canceled with stages
#!/usr/bin/env bash
# Prepare the CUB-200-2011 dataset.
# Usage: bash <script> DOWNLOAD_DIR DATA_ROOT
set -x
DOWNLOAD_DIR=$1
DATA_ROOT=$2
# unzip all of data
# Expansions are quoted so paths containing spaces do not word-split.
cat "$DOWNLOAD_DIR"/CUB-200-2011/raw/*.tar.gz | tar -xvz -C "$DOWNLOAD_DIR"
# move data into DATA_ROOT
mv -f "$DOWNLOAD_DIR"/CUB-200-2011/CUB-200-2011/* "$DATA_ROOT"/
# remove useless data file
rm -R "$DOWNLOAD_DIR"/CUB-200-2011/
#!/usr/bin/env bash
# Prepare the ImageNet-1K dataset and its meta annotation files.
# Usage: bash <script> DOWNLOAD_DIR DATA_ROOT
set -x
DOWNLOAD_DIR=$1
DATA_ROOT=$2
# unzip all of data
# Expansions are quoted so paths containing spaces do not word-split.
cat "$DOWNLOAD_DIR"/ImageNet-1K/raw/*.tar.gz.* | tar -xvz -C "$DOWNLOAD_DIR"
# move images into data/imagenet
mv "$DOWNLOAD_DIR"/ImageNet-1K/{train,val,test} "$DATA_ROOT"
# download the mate ann_files file
wget -P "$DATA_ROOT" https://download.openmmlab.com/mmclassification/datasets/imagenet/meta/caffe_ilsvrc12.tar.gz
# unzip mate ann_files file and put it into 'meta' folder
mkdir "$DATA_ROOT"/meta
tar -xzvf "$DATA_ROOT"/caffe_ilsvrc12.tar.gz -C "$DATA_ROOT"/meta
# remove useless data files
rm -R "$DOWNLOAD_DIR"/ImageNet-1K
#!/usr/bin/env bash
# Launch a distributed test run via torch.distributed.launch.
#
# Usage: bash dist_test.sh CONFIG CHECKPOINT GPUS [extra test.py args...]
# Environment overrides: NNODES, NODE_RANK, PORT, MASTER_ADDR.
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
NNODES=${NNODES:-1}        # total number of nodes
NODE_RANK=${NODE_RANK:-0}  # rank of this node
PORT=${PORT:-29500}        # rendezvous port on the master node
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# Put the repository root on PYTHONPATH so test.py can import the package,
# then forward all remaining CLI arguments (from the 4th on) to test.py.
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/test.py \
    $CONFIG \
    $CHECKPOINT \
    --launcher pytorch \
    ${@:4}
#!/usr/bin/env bash
# Launch a distributed training run via torch.distributed.launch.
#
# Usage: bash dist_train.sh CONFIG GPUS [extra train.py args...]
# Environment overrides: NNODES, NODE_RANK, PORT, MASTER_ADDR.
CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}        # total number of nodes
NODE_RANK=${NODE_RANK:-0}  # rank of this node
PORT=${PORT:-29500}        # rendezvous port on the master node
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# Put the repository root on PYTHONPATH so train.py can import the package,
# then forward all remaining CLI arguments (from the 3rd on) to train.py.
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/train.py \
    $CONFIG \
    --launcher pytorch ${@:3}
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.dist import sync_random_seed
from mmengine.fileio import dump, load
from mmengine.hooks import Hook
from mmengine.runner import Runner, find_latest_checkpoint
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
# Name of the JSON file (written into the root work_dir) that records
# cross-validation progress: current fold, last checkpoint and split seed.
EXP_INFO_FILE = 'kfold_exp.json'

prog_description = """K-Fold cross-validation.
To start a 5-fold cross-validation experiment:
python tools/kfold-cross-valid.py $CONFIG --num-splits 5
To resume a 5-fold cross-validation from an interrupted experiment:
python tools/kfold-cross-valid.py $CONFIG --num-splits 5 --resume
"""  # noqa: E501
def parse_args():
    """Parse command-line arguments for the k-fold cross-validation tool.

    Also propagates ``--local_rank`` into the ``LOCAL_RANK`` environment
    variable when it is not already set, as expected by the launcher.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=prog_description)
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--num-splits',
        type=int,
        help='The number of all folds.',
        required=True)
    parser.add_argument(
        '--fold',
        type=int,
        help='The fold used to do validation. '
        'If specify, only do an experiment of the specified fold.')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--resume',
        action='store_true',
        help='Resume the previous experiment.')
    parser.add_argument(
        '--amp',
        action='store_true',
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    parser.add_argument(
        '--auto-scale-lr',
        action='store_true',
        help='whether to auto scale the learning rate according to the '
        'actual batch size and the original batch size.')
    parser.add_argument(
        '--no-pin-memory',
        action='store_true',
        help='whether to disable the pin_memory option in dataloaders.')
    parser.add_argument(
        '--no-persistent-workers',
        action='store_true',
        help='whether to disable the persistent_workers option in dataloaders.'
    )
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # torch.distributed.launch passes --local_rank; mirror it into the
    # environment so downstream mmengine code can pick it up.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def merge_args(cfg, args):
    """Merge CLI arguments to config.

    Mutates ``cfg`` in place (work_dir, AMP, dataloader defaults, LR scaling
    and generic ``--cfg-options`` overrides) and returns it.
    """
    # Dropping the whole val loop when --no-validate is given.
    if args.no_validate:
        cfg.val_cfg = None
        cfg.val_dataloader = None
        cfg.val_evaluator = None

    cfg.launcher = args.launcher

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    # enable automatic-mixed-precision training by swapping the optimizer
    # wrapper type; refuse if the config already uses a custom wrapper.
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
        assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
            '`--amp` is not supported custom optimizer wrapper type ' \
            f'`{optim_wrapper}.'
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')

    # enable auto scale learning rate
    if args.auto_scale_lr:
        cfg.auto_scale_lr.enable = True

    # Default dataloader options; user-provided values take precedence
    # because they are merged on top (see set_default_dataloader_cfg).
    default_dataloader_cfg = ConfigDict(
        pin_memory=True,
        persistent_workers=True,
        collate_fn=dict(type='default_collate'),
    )
    # persistent_workers is only available since torch 1.8.
    if digit_version(TORCH_VERSION) < digit_version('1.8.0'):
        default_dataloader_cfg.persistent_workers = False

    def set_default_dataloader_cfg(cfg, field):
        # Apply defaults to one dataloader section, if it exists.
        if cfg.get(field, None) is None:
            return
        dataloader_cfg = copy.deepcopy(default_dataloader_cfg)
        dataloader_cfg.update(cfg[field])
        cfg[field] = dataloader_cfg
        # CLI flags win over both defaults and config values.
        if args.no_pin_memory:
            cfg[field]['pin_memory'] = False
        if args.no_persistent_workers:
            cfg[field]['persistent_workers'] = False

    set_default_dataloader_cfg(cfg, 'train_dataloader')
    set_default_dataloader_cfg(cfg, 'val_dataloader')
    set_default_dataloader_cfg(cfg, 'test_dataloader')

    # Generic key=value overrides are merged last, so they beat everything.
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    return cfg
def train_single_fold(cfg, num_splits, fold, resume_ckpt=None):
    """Train one fold of the k-fold cross-validation.

    Wraps the train/val/test dataset configs in ``KFoldDataset``, moves the
    work_dir into a per-fold sub-directory, and runs a full training with a
    hook that records progress to EXP_INFO_FILE after every epoch so the
    experiment can be resumed.

    Args:
        cfg: the (already merged) experiment config; mutated in place.
        num_splits (int): total number of folds.
        fold (int): index of the fold to run.
        resume_ckpt (str | None): checkpoint to resume from, if any.
    """
    root_dir = cfg.work_dir
    # Each fold trains in its own sub-directory of the root work_dir.
    cfg.work_dir = osp.join(root_dir, f'fold{fold}')
    if resume_ckpt is not None:
        cfg.resume = True
        cfg.load_from = resume_ckpt

    dataset = cfg.train_dataloader.dataset

    # wrap the dataset cfg
    def wrap_dataset(dataset, test_mode):
        return dict(
            type='KFoldDataset',
            dataset=dataset,
            fold=fold,
            num_splits=num_splits,
            seed=cfg.kfold_split_seed,
            test_mode=test_mode,
        )

    train_dataset = copy.deepcopy(dataset)
    cfg.train_dataloader.dataset = wrap_dataset(train_dataset, False)

    # Val/test reuse the *train* dataset (the held-out split of it) but keep
    # their own pipelines, hence the explicit `pipeline` requirement below.
    if cfg.val_dataloader is not None:
        if 'pipeline' not in cfg.val_dataloader.dataset:
            raise ValueError(
                'Cannot find `pipeline` in the validation dataset. '
                "If you are using dataset wrapper, please don't use this "
                'tool to act kfold cross validation. '
                'Please write config files manually.')
        val_dataset = copy.deepcopy(dataset)
        val_dataset['pipeline'] = cfg.val_dataloader.dataset.pipeline
        cfg.val_dataloader.dataset = wrap_dataset(val_dataset, True)

    if cfg.test_dataloader is not None:
        if 'pipeline' not in cfg.test_dataloader.dataset:
            raise ValueError(
                'Cannot find `pipeline` in the test dataset. '
                "If you are using dataset wrapper, please don't use this "
                'tool to act kfold cross validation. '
                'Please write config files manually.')
        test_dataset = copy.deepcopy(dataset)
        test_dataset['pipeline'] = cfg.test_dataloader.dataset.pipeline
        cfg.test_dataloader.dataset = wrap_dataset(test_dataset, True)

    # build the runner from config
    runner = Runner.from_cfg(cfg)
    runner.logger.info(
        f'----------- Cross-validation: [{fold+1}/{num_splits}] ----------- ')
    runner.logger.info(f'Train dataset: \n{runner.train_dataloader.dataset}')

    class SaveInfoHook(Hook):
        # Persist resume information after every epoch so an interrupted
        # experiment can pick up at the current fold.
        def after_train_epoch(self, runner):
            last_ckpt = find_latest_checkpoint(cfg.work_dir)
            exp_info = dict(
                fold=fold,
                last_ckpt=last_ckpt,
                kfold_split_seed=cfg.kfold_split_seed,
            )
            dump(exp_info, osp.join(root_dir, EXP_INFO_FILE))

    runner.register_hook(SaveInfoHook(), 'LOWEST')

    # start training
    runner.train()
def main():
    """Entry point: run (or resume) the k-fold cross-validation loop."""
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)

    # merge cli arguments to config
    cfg = merge_args(cfg, args)

    # Set the unified random seed for splitting; all folds must share it so
    # the splits are disjoint and reproducible.
    cfg.kfold_split_seed = args.seed or sync_random_seed()

    # resume from the previous experiment
    if args.resume:
        experiment_info = load(osp.join(cfg.work_dir, EXP_INFO_FILE))
        resume_fold = experiment_info['fold']
        cfg.kfold_split_seed = experiment_info['kfold_split_seed']
        resume_ckpt = experiment_info.get('last_ckpt', None)
    else:
        resume_fold = 0
        resume_ckpt = None

    # --fold restricts the run to a single fold; otherwise run the rest.
    if args.fold is not None:
        folds = [args.fold]
    else:
        folds = range(resume_fold, args.num_splits)

    for fold in folds:
        cfg_ = copy.deepcopy(cfg)
        train_single_fold(cfg_, args.num_splits, fold, resume_ckpt)
        # Only the first (resumed) fold starts from a checkpoint.
        resume_ckpt = None


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import rich.console
from mmengine import Config, DictAction
console = rich.console.Console()
def parse_args():
    """Parse command-line arguments for the print-config tool.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Print the whole config')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args
def main():
    """Load the config, apply ``--cfg-options`` overrides and print it."""
    cli_args = parse_args()
    config = Config.fromfile(cli_args.config)
    if cli_args.cfg_options is not None:
        config.merge_from_dict(cli_args.cfg_options)
    # markup=False: emit the text verbatim, without rich markup parsing.
    console.print(config.pretty_text, markup=False)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import multiprocessing
import os
from pathlib import Path
from mmengine import (Config, DictAction, track_parallel_progress,
track_progress)
from mmpretrain.datasets import build_dataset
from mmpretrain.registry import TRANSFORMS
file_lock = multiprocessing.Lock()
def parse_args():
    """Parse command-line arguments for the dataset-verification tool.

    Returns:
        argparse.Namespace: the parsed (and sanity-checked) arguments.
    """
    parser = argparse.ArgumentParser(description='Verify Dataset')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--out-path',
        type=str,
        default='brokenfiles.log',
        help='output path of all the broken files. If the specified path '
        'already exists, delete the previous file ')
    parser.add_argument(
        '--phase',
        default='train',
        type=str,
        choices=['train', 'test', 'val'],
        help='phase of dataset to visualize, accept "train" "test" and "val".')
    parser.add_argument(
        '--num-process', type=int, default=1, help='number of process to use')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    # Basic sanity checks on the CLI values before any work starts.
    assert args.out_path is not None
    assert args.num_process > 0
    return args
class DatasetValidator():
    """Check every sample of a dataset and log files that fail to load.

    Only the ``LoadImageFromFile`` transform of the configured pipeline is
    kept and applied, so a failure means the image file itself is unreadable.
    Broken file paths are appended to ``log_file_path``.
    """

    def __init__(self, dataset_cfg, log_file_path):
        super(DatasetValidator, self).__init__()
        # keep only LoadImageFromFile pipeline
        from mmpretrain.datasets import get_transform_idx
        load_idx = get_transform_idx(dataset_cfg.pipeline, 'LoadImageFromFile')
        assert load_idx >= 0, \
            'This tool is only for datasets needs to load image from files.'
        self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[load_idx])
        # Build the dataset with an empty pipeline so indexing returns the
        # raw data info (incl. img_path) without touching the image.
        dataset_cfg.pipeline = []
        dataset = build_dataset(dataset_cfg)

        self.dataset = dataset
        self.log_file_path = log_file_path

    def valid_idx(self, idx):
        """Try to load sample ``idx``; log its path if loading fails."""
        item = self.dataset[idx]
        try:
            item = self.pipeline(item)
        except Exception:
            with open(self.log_file_path, 'a') as f:
                # add file lock to prevent multi-process writing errors
                filepath = str(Path(item['img_path']))
                file_lock.acquire()
                f.write(filepath + '\n')
                file_lock.release()
            print(f'{filepath} cannot be read correctly, please check it.')

    def __len__(self):
        # Number of samples to validate.
        return len(self.dataset)
def print_info(log_file_path):
    """Report the verification result recorded in the log file.

    Prints how many broken files were listed; when the log is empty, reports
    success and deletes the (useless) empty log file.
    """
    print()
    with open(log_file_path) as log_file:
        broken_list = log_file.read().strip()
    if broken_list:
        num_file = len(broken_list.split('\n'))
        print(f'{num_file} broken files found, name list save in file:'
              f'{log_file_path}')
    else:
        print('There is no broken file found.')
        os.remove(log_file_path)
    print()
def main():
    """Entry point: verify every image of the selected dataset phase."""
    # parse cfg and args
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # touch output file to save broken files list.
    output_path = Path(args.out_path)
    if not output_path.parent.exists():
        raise Exception("Path '--out-path' parent directory not found.")
    if output_path.exists():
        # A previous run's log would pollute the result; start fresh.
        os.remove(output_path)
    output_path.touch()

    # Pick the dataset config of the requested phase.
    if args.phase == 'train':
        dataset_cfg = cfg.train_dataloader.dataset
    elif args.phase == 'val':
        dataset_cfg = cfg.val_dataloader.dataset
    elif args.phase == 'test':
        dataset_cfg = cfg.test_dataloader.dataset
    else:
        raise ValueError("'--phase' only support 'train', 'val' and 'test'.")

    # do validate
    validator = DatasetValidator(dataset_cfg, output_path)

    if args.num_process > 1:
        # The default chunksize calculation method of Pool.map
        chunksize, extra = divmod(len(validator), args.num_process * 8)
        if extra:
            chunksize += 1
        track_parallel_progress(
            validator.valid_idx,
            list(range(len(validator))),
            args.num_process,
            chunksize=chunksize,
            keep_order=False)
    else:
        track_progress(validator.valid_idx, list(range(len(validator))))

    print_info(output_path)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_clip(ckpt):
    """Translate CLIP ViT checkpoint keys into mmpretrain naming.

    Args:
        ckpt (dict): source state dict.

    Returns:
        OrderedDict: state dict with mmpretrain-style keys.
    """
    # Renames applied inside transformer blocks; only the first match is
    # used, mirroring the original elif chain.
    _block_renames = (
        ('norm1', 'ln1'),
        ('norm2', 'ln2'),
        ('mlp.fc1', 'ffn.layers.0.0'),
        ('mlp.fc2', 'ffn.layers.1'),
    )
    converted = OrderedDict()
    for old_key, value in ckpt.items():
        # Classifier keys move under `head.layers.head` and never get the
        # `backbone.` prefix.
        if old_key.startswith('head'):
            converted[old_key.replace('head.', 'head.layers.head.')] = value
            continue
        if old_key.startswith('patch_embed'):
            new_key = (old_key.replace('proj.', 'projection.')
                       if 'proj.' in old_key else old_key)
        elif old_key.startswith('norm_pre'):
            new_key = old_key.replace('norm_pre', 'pre_norm')
        elif old_key.startswith('blocks'):
            new_key = old_key.replace('blocks.', 'layers.')
            for src, dst in _block_renames:
                if src in old_key:
                    new_key = new_key.replace(src, dst)
                    break
        elif old_key.startswith('norm'):
            new_key = old_key.replace('norm', 'ln1')
        else:
            new_key = old_key
        converted['backbone.' + new_key] = value
    return converted
def main():
    """CLI entry point: load a CLIP checkpoint, convert its keys and save."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained clip '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # Some checkpoints nest the weights under a 'state_dict' entry.
    state_dict = (
        checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint)
    converted = convert_clip(state_dict)

    # Ensure the output directory exists before saving.
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(converted, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_convnext(ckpt):
    """Translate ConvNeXt checkpoint keys into mmpretrain naming.

    Args:
        ckpt (dict): source state dict.

    Returns:
        OrderedDict: state dict with mmpretrain-style keys.
    """
    converted = OrderedDict()
    for old_key, value in ckpt.items():
        # Classifier keys move under `head.fc` and keep no `backbone.`
        # prefix.
        if old_key.startswith('head'):
            converted[old_key.replace('head.', 'head.fc.')] = value
            continue
        new_key = old_key
        if old_key.startswith('stages'):
            if 'dwconv' in old_key:
                new_key = old_key.replace('dwconv', 'depthwise_conv')
            elif 'pwconv' in old_key:
                new_key = old_key.replace('pwconv', 'pointwise_conv')
        elif old_key.startswith('norm'):
            new_key = old_key.replace('norm', 'norm3')
        converted['backbone.' + new_key] = value
    return converted
def main():
    """CLI entry point: load a ConvNeXt checkpoint, convert and save."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained convnext '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # timm-style checkpoints nest the weights under a 'model' entry.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    converted = convert_convnext(state_dict)

    # Ensure the output directory exists before saving.
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(dict(state_dict=converted), args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_davit(ckpt):
    """Translate DaViT checkpoint keys into mmpretrain naming.

    Args:
        ckpt (dict): source state dict.

    Returns:
        OrderedDict: state dict with mmpretrain-style keys.
    """
    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('patch_embeds.0'):
            # The first patch embed is the stem.
            new_k = k.replace('patch_embeds.0', 'patch_embed')
            new_k = new_k.replace('proj', 'projection')
        elif k.startswith('patch_embeds'):
            # Later patch embeds become the downsample layer of the
            # preceding stage.
            if k.startswith('patch_embeds.1'):
                new_k = k.replace('patch_embeds.1', 'stages.0.downsample')
            elif k.startswith('patch_embeds.2'):
                new_k = k.replace('patch_embeds.2', 'stages.1.downsample')
            elif k.startswith('patch_embeds.3'):
                new_k = k.replace('patch_embeds.3', 'stages.2.downsample')
            new_k = new_k.replace('proj', 'projection')
        elif k.startswith('main_blocks'):
            new_k = k.replace('main_blocks', 'stages')
            # Each dual block holds a spatial (index 0) and a channel
            # (index 1) sub-block; up to 4 stages x 9 blocks are scanned.
            for num_stages in range(4):
                for num_blocks in range(9):
                    if f'{num_stages}.{num_blocks}.0' in k:
                        new_k = new_k.replace(
                            f'{num_stages}.{num_blocks}.0',
                            f'{num_stages}.blocks.{num_blocks}.spatial_block')
                    elif f'{num_stages}.{num_blocks}.1' in k:
                        new_k = new_k.replace(
                            f'{num_stages}.{num_blocks}.1',
                            f'{num_stages}.blocks.{num_blocks}.channel_block')
            if 'cpe.0' in k:
                new_k = new_k.replace('cpe.0', 'cpe1')
            elif 'cpe.1' in k:
                new_k = new_k.replace('cpe.1', 'cpe2')
            if 'mlp' in k:
                new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
                new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
            # Only the spatial attention is wrapped in a WindowMSA module.
            if 'spatial_block.attn' in new_k:
                new_k = new_k.replace('spatial_block.attn',
                                      'spatial_block.attn.w_msa')
        elif k.startswith('norms'):
            new_k = k.replace('norms', 'norm3')
        elif k.startswith('head'):
            new_k = k.replace('head', 'head.fc')
        else:
            new_k = k

        # All non-head keys live under the backbone.
        if not new_k.startswith('head'):
            new_k = 'backbone.' + new_k
        new_ckpt[new_k] = new_v
    return new_ckpt
def main():
    """CLI entry point: load a DaViT checkpoint, convert and save."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained davit '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # Some checkpoints nest the weights under a 'state_dict' entry.
    state_dict = (
        checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint)
    converted = convert_davit(state_dict)

    # Ensure the output directory exists before saving.
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(converted, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_deit3(ckpt):
    """Translate DeiT-III checkpoint keys into mmpretrain naming.

    Args:
        ckpt (dict): source state dict.

    Returns:
        OrderedDict: state dict with mmpretrain-style keys.
    """
    # Renames applied inside transformer blocks; only the first match is
    # used, mirroring the original elif chain.
    _block_renames = (
        ('norm1', 'ln1'),
        ('norm2', 'ln2'),
        ('mlp.fc1', 'ffn.layers.0.0'),
        ('mlp.fc2', 'ffn.layers.1'),
        ('gamma_1', 'attn.gamma1.weight'),
        ('gamma_2', 'ffn.gamma2.weight'),
    )
    converted = OrderedDict()
    for old_key, value in ckpt.items():
        # Classifier keys move under `head.layers.head` and keep no
        # `backbone.` prefix.
        if old_key.startswith('head'):
            converted[old_key.replace('head.', 'head.layers.head.')] = value
            continue
        if old_key.startswith('patch_embed'):
            new_key = (old_key.replace('proj.', 'projection.')
                       if 'proj.' in old_key else old_key)
        elif old_key.startswith('blocks'):
            new_key = old_key.replace('blocks.', 'layers.')
            for src, dst in _block_renames:
                if src in old_key:
                    new_key = new_key.replace(src, dst)
                    break
        elif old_key.startswith('norm'):
            new_key = old_key.replace('norm', 'ln1')
        else:
            new_key = old_key
        converted['backbone.' + new_key] = value
    return converted
def main():
    """CLI entry point: load a DeiT-III checkpoint, convert and save."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained deit3 '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # Official checkpoints nest the weights under a 'model' entry.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    converted = convert_deit3(state_dict)

    # Ensure the output directory exists before saving.
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(converted, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path
import torch
def convert_weights(weight):
    """Weight Converter.

    Converts the weights from timm to mmpretrain

    Args:
        weight (dict): weight dict from timm

    Returns:
        Converted weight dict for mmpretrain
    """
    result = dict()
    result['meta'] = dict()
    temp = dict()
    # Substring renames, applied anywhere inside a key.
    mapping = {
        'dwconv': 'depthwise_conv',
        'pwconv1': 'pointwise_conv1',
        'pwconv2': 'pointwise_conv2',
        'xca': 'csa',
        'convs': 'conv_modules',
        'token_projection': 'proj',
        'pos_embd': 'pos_embed',
        'temperature': 'scale',
    }
    # Whole-key renames, applied only on exact match.
    strict_mapping = {
        'norm.weight': 'norm3.weight',
        'norm.bias': 'norm3.bias',
    }
    # Prefer the EMA weights when present, otherwise fall back to the plain
    # 'state_dict'. BUGFIX: the previous version had an `else: raise
    # NotImplementedError` on this try, which fired exactly when 'model_ema'
    # DID exist, so every EMA checkpoint crashed instead of being converted.
    try:
        weight = weight['model_ema']
    except KeyError:
        weight = weight['state_dict']  # for model learned with usi
    for k, v in weight.items():
        # keyword mapping
        for mk, mv in mapping.items():
            if mk in k:
                k = k.replace(mk, mv)
        # strict mapping
        for mk, mv in strict_mapping.items():
            if mk == k:
                k = mv
        # Head weights go under `head.fc`, everything else under `backbone`.
        if k.startswith('head.'):
            temp['head.fc.' + k[5:]] = v
        else:
            temp['backbone.' + k] = v
    result['state_dict'] = temp
    return result
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src detectron model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    dst = Path(args.dst)
    # Refuse destinations that are not .pth checkpoint files.
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)
    original_model = torch.load(args.src, map_location='cpu')
    converted_model = convert_weights(original_model)
    torch.save(converted_model, args.dst)
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import numpy as np
import torch
from mmengine.model import Sequential
from tensorflow.python.training import py_checkpoint_reader
from mmpretrain.models.backbones.efficientnet import EfficientNet
def tf2pth(v):
    """Reorder a TensorFlow tensor into PyTorch layout.

    4-D conv kernels go from HWIO to OIHW and 2-D matrices are transposed;
    any other rank is returned unchanged. The result is made contiguous.
    """
    if v.ndim == 2:
        return np.ascontiguousarray(v.transpose())
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    return v
def read_ckpt(ckpt):
    """Load every variable of a TensorFlow checkpoint as torch tensors.

    Args:
        ckpt (str): path to the TF checkpoint.

    Returns:
        dict: variable name -> torch.Tensor (already in PyTorch layout).
    """
    reader = py_checkpoint_reader.NewCheckpointReader(ckpt)
    weights = {
        n: torch.as_tensor(tf2pth(reader.get_tensor(n)))
        for (n, _) in reader.get_variable_to_shape_map().items()
    }
    return weights
def map_key(weight, l2_flag):
    """Map TensorFlow EfficientNet variable names to mmpretrain keys.

    Args:
        weight (dict): variable name -> tensor, as returned by ``read_ckpt``.
        l2_flag (bool): if True, use the ``ExponentialMovingAverage`` copies
            of the variables (the l2 arch should use it).

    Returns:
        dict: mmpretrain-style state dict.
    """
    m = dict()
    has_expand_conv = set()
    is_MBConv = set()
    max_idx = 0
    name = None
    # First pass: infer the architecture name and collect per-block facts
    # (which blocks are MBConv / have an expand conv) and the block count.
    for k, v in weight.items():
        seg = k.split('/')
        if len(seg) == 1:
            continue
        if 'edgetpu' in seg[0]:
            # e.g. 'efficientnet-edgetpu-S' -> 'es' (slice past the prefix)
            name = 'e' + seg[0][21:].lower()
        else:
            # e.g. 'efficientnet-b0' -> 'b0'
            name = seg[0][13:]
        if seg[2] == 'tpu_batch_normalization_2':
            # A third BN appears only in blocks that have an expand conv.
            has_expand_conv.add(seg[1])
        if seg[1].startswith('blocks_'):
            idx = int(seg[1][7:]) + 1  # +1 because layer 0 is the stem
            max_idx = max(max_idx, idx)
        if 'depthwise' in k:
            is_MBConv.add(seg[1])
    # Build the target model to recover its flattened layer layout; idx2key
    # maps a flat block index to the 'layer.sub' key used by mmpretrain.
    model = EfficientNet(name)
    idx2key = []
    for idx, module in enumerate(model.layers):
        if isinstance(module, Sequential):
            for j in range(len(module)):
                idx2key.append('{}.{}'.format(idx, j))
        else:
            idx2key.append('{}'.format(idx))
    # Second pass: translate every TF variable to its mmpretrain key.
    for k, v in weight.items():
        if l2_flag:
            k = k.replace('/ExponentialMovingAverage', '')
        # Skip optimizer statistics (EMA copies when not requested, RMSProp).
        if 'Exponential' in k or 'RMS' in k:
            continue
        seg = k.split('/')
        if len(seg) == 1:
            continue
        if seg[2] == 'depthwise_conv2d':
            # Depthwise kernels need their first two axes swapped
            # (presumably to match torch's depthwise layout — TODO confirm).
            v = v.transpose(1, 0)
        if seg[1] == 'stem':
            prefix = 'backbone.layers.{}'.format(idx2key[0])
            mapping = {
                'conv2d/kernel': 'conv.weight',
                'tpu_batch_normalization/beta': 'bn.bias',
                'tpu_batch_normalization/gamma': 'bn.weight',
                'tpu_batch_normalization/moving_mean': 'bn.running_mean',
                'tpu_batch_normalization/moving_variance': 'bn.running_var',
            }
            suffix = mapping['/'.join(seg[2:])]
            m[prefix + '.' + suffix] = v
        elif seg[1].startswith('blocks_'):
            idx = int(seg[1][7:]) + 1
            prefix = '.'.join(['backbone', 'layers', idx2key[idx]])
            if seg[1] not in is_MBConv:
                # Plain (edge) block: two conv+BN pairs.
                mapping = {
                    'conv2d/kernel':
                    'conv1.conv.weight',
                    'tpu_batch_normalization/gamma':
                    'conv1.bn.weight',
                    'tpu_batch_normalization/beta':
                    'conv1.bn.bias',
                    'tpu_batch_normalization/moving_mean':
                    'conv1.bn.running_mean',
                    'tpu_batch_normalization/moving_variance':
                    'conv1.bn.running_var',
                    'conv2d_1/kernel':
                    'conv2.conv.weight',
                    'tpu_batch_normalization_1/gamma':
                    'conv2.bn.weight',
                    'tpu_batch_normalization_1/beta':
                    'conv2.bn.bias',
                    'tpu_batch_normalization_1/moving_mean':
                    'conv2.bn.running_mean',
                    'tpu_batch_normalization_1/moving_variance':
                    'conv2.bn.running_var',
                }
            else:
                # MBConv block: SE module keys are shared by both variants.
                base_mapping = {
                    'depthwise_conv2d/depthwise_kernel':
                    'depthwise_conv.conv.weight',
                    'se/conv2d/kernel': 'se.conv1.conv.weight',
                    'se/conv2d/bias': 'se.conv1.conv.bias',
                    'se/conv2d_1/kernel': 'se.conv2.conv.weight',
                    'se/conv2d_1/bias': 'se.conv2.conv.bias'
                }
                if seg[1] not in has_expand_conv:
                    # MBConv without expansion: depthwise then linear conv.
                    mapping = {
                        'conv2d/kernel':
                        'linear_conv.conv.weight',
                        'tpu_batch_normalization/beta':
                        'depthwise_conv.bn.bias',
                        'tpu_batch_normalization/gamma':
                        'depthwise_conv.bn.weight',
                        'tpu_batch_normalization/moving_mean':
                        'depthwise_conv.bn.running_mean',
                        'tpu_batch_normalization/moving_variance':
                        'depthwise_conv.bn.running_var',
                        'tpu_batch_normalization_1/beta':
                        'linear_conv.bn.bias',
                        'tpu_batch_normalization_1/gamma':
                        'linear_conv.bn.weight',
                        'tpu_batch_normalization_1/moving_mean':
                        'linear_conv.bn.running_mean',
                        'tpu_batch_normalization_1/moving_variance':
                        'linear_conv.bn.running_var',
                    }
                else:
                    # MBConv with expansion: expand, depthwise, linear convs.
                    mapping = {
                        'depthwise_conv2d/depthwise_kernel':
                        'depthwise_conv.conv.weight',
                        'conv2d/kernel':
                        'expand_conv.conv.weight',
                        'conv2d_1/kernel':
                        'linear_conv.conv.weight',
                        'tpu_batch_normalization/beta':
                        'expand_conv.bn.bias',
                        'tpu_batch_normalization/gamma':
                        'expand_conv.bn.weight',
                        'tpu_batch_normalization/moving_mean':
                        'expand_conv.bn.running_mean',
                        'tpu_batch_normalization/moving_variance':
                        'expand_conv.bn.running_var',
                        'tpu_batch_normalization_1/beta':
                        'depthwise_conv.bn.bias',
                        'tpu_batch_normalization_1/gamma':
                        'depthwise_conv.bn.weight',
                        'tpu_batch_normalization_1/moving_mean':
                        'depthwise_conv.bn.running_mean',
                        'tpu_batch_normalization_1/moving_variance':
                        'depthwise_conv.bn.running_var',
                        'tpu_batch_normalization_2/beta':
                        'linear_conv.bn.bias',
                        'tpu_batch_normalization_2/gamma':
                        'linear_conv.bn.weight',
                        'tpu_batch_normalization_2/moving_mean':
                        'linear_conv.bn.running_mean',
                        'tpu_batch_normalization_2/moving_variance':
                        'linear_conv.bn.running_var',
                    }
                mapping.update(base_mapping)
            suffix = mapping['/'.join(seg[2:])]
            m[prefix + '.' + suffix] = v
        elif seg[1] == 'head':
            # The head conv comes right after the last block.
            seq_key = idx2key[max_idx + 1]
            mapping = {
                'conv2d/kernel':
                'backbone.layers.{}.conv.weight'.format(seq_key),
                'tpu_batch_normalization/beta':
                'backbone.layers.{}.bn.bias'.format(seq_key),
                'tpu_batch_normalization/gamma':
                'backbone.layers.{}.bn.weight'.format(seq_key),
                'tpu_batch_normalization/moving_mean':
                'backbone.layers.{}.bn.running_mean'.format(seq_key),
                'tpu_batch_normalization/moving_variance':
                'backbone.layers.{}.bn.running_var'.format(seq_key),
                'dense/kernel':
                'head.fc.weight',
                'dense/bias':
                'head.fc.bias'
            }
            key = mapping['/'.join(seg[2:])]
            if name.startswith('e') and 'fc' in key:
                # Edge-TPU classifiers: drop the first row/element
                # (presumably an extra background class — TODO confirm).
                v = v[1:]
            m[key] = v
    return m
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('infile', type=str, help='Path to the ckpt.')
    parser.add_argument('outfile', type=str, help='Output file.')
    parser.add_argument(
        '--l2',
        action='store_true',
        help='If true convert ExponentialMovingAverage weights. '
        'l2 arch should use it.')
    args = parser.parse_args()
    assert args.outfile
    # Ensure the output directory exists before saving.
    outdir = os.path.dirname(os.path.abspath(args.outfile))
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    weights = read_ckpt(args.infile)
    weights = map_key(weights, args.l2)
    torch.save(weights, args.outfile)
# Copyright (c) OpenMMLab. All rights reserved.
"""convert the weights of efficientnetv2 in
timm(https://github.com/rwightman/pytorch-image-models) to mmpretrain
format."""
import argparse
import os.path as osp
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_from_efficientnetv2_timm(param):
    """Convert an EfficientNetV2 timm state dict to mmpretrain naming.

    Args:
        param (dict): timm state dict (key -> tensor).

    Returns:
        dict: state dict with mmpretrain-style keys.

    Note:
        Block indices are shifted by +1 because mmpretrain's backbone layer 0
        is the stem, and the head conv index is derived from the last block
        stage (7 for s/base/b1/b2/b3, 8 for m/l/xl).
    """
    # main change_key
    param_lst = list(param.keys())
    # param_lst[-9] is assumed to be a 'blocks.<stage>...' key whose stage
    # digit sits at index 7; stage + 2 is the head conv layer index.
    op = str(int(param_lst[-9][7]) + 2)
    new_key = dict()
    for name in param_lst:
        data = param[name]
        if 'blocks' not in name:
            if 'conv_stem' in name:
                name = name.replace('conv_stem', 'backbone.layers.0.conv')
            if 'bn1' in name:
                name = name.replace('bn1', 'backbone.layers.0.bn')
            if 'conv_head' in name:
                # if efficientnet-v2_s/base/b1/b2/b3, op = 7,
                # if for m/l/xl , op = 8
                name = name.replace('conv_head', f'backbone.layers.{op}.conv')
            if 'bn2' in name:
                name = name.replace('bn2', f'backbone.layers.{op}.bn')
            if 'classifier' in name:
                name = name.replace('classifier', 'head.fc')
        else:
            operator = int(name[7])
            if operator == 0:
                # Stage 0: plain ConvBNAct blocks.
                # (The previous version also ran a no-op
                # `name.replace('conv', 'conv')` here; removed.)
                name = name[:7] + str(operator + 1) + name[8:]
                name = name.replace('blocks', 'backbone.layers')
                if 'bn1' in name:
                    name = name.replace('bn1', 'bn')
            elif operator < 3:
                # Stages 1-2: fused MBConv (expand conv + project conv).
                name = name[:7] + str(operator + 1) + name[8:]
                name = name.replace('blocks', 'backbone.layers')
                if 'conv_exp' in name:
                    name = name.replace('conv_exp', 'conv1.conv')
                if 'conv_pwl' in name:
                    name = name.replace('conv_pwl', 'conv2.conv')
                if 'bn1' in name:
                    name = name.replace('bn1', 'conv1.bn')
                if 'bn2' in name:
                    name = name.replace('bn2', 'conv2.bn')
            else:
                # Later stages: full MBConv with SE.
                name = name[:7] + str(operator + 1) + name[8:]
                name = name.replace('blocks', 'backbone.layers')
                # conv_pwl must be handled before conv_pw (prefix overlap).
                if 'conv_pwl' in name:
                    name = name.replace('conv_pwl', 'linear_conv.conv')
                if 'conv_pw' in name:
                    name = name.replace('conv_pw', 'expand_conv.conv')
                if 'conv_dw' in name:
                    name = name.replace('conv_dw', 'depthwise_conv.conv')
                if 'bn1' in name:
                    name = name.replace('bn1', 'expand_conv.bn')
                if 'bn2' in name:
                    name = name.replace('bn2', 'depthwise_conv.bn')
                if 'bn3' in name:
                    name = name.replace('bn3', 'linear_conv.bn')
                if 'se.conv_reduce' in name:
                    name = name.replace('se.conv_reduce', 'se.conv1.conv')
                if 'se.conv_expand' in name:
                    name = name.replace('se.conv_expand', 'se.conv2.conv')
        new_key[name] = data
    return new_key
def main():
    """CLI entry point: load an EfficientNetV2 checkpoint, convert, save."""
    parser = argparse.ArgumentParser(
        description='Convert pretrained efficientnetv2 '
        'models in timm to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # Some checkpoints nest the weights under a 'state_dict' entry.
    state_dict = (
        checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint)
    converted = convert_from_efficientnetv2_timm(state_dict)

    # Ensure the output directory exists before saving.
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(converted, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_eva02(ckpt):
    """Convert an official EVA-02 state dict to mmpretrain key style.

    Besides renaming keys, this merges parameters that mmpretrain stores
    as single tensors: the separate q/k/v projection weights become one
    ``qkv`` weight (with a zero k bias, which the source does not have),
    and the two MLP input linears ``w1``/``w2`` become one ``w12`` linear.

    Args:
        ckpt (dict): The original checkpoint state dict.

    Returns:
        OrderedDict: The converted state dict.
    """
    new_ckpt = OrderedDict()
    # Per-layer staging dicts for parameters that are concatenated after
    # the renaming pass.
    qkv_proj = {}
    qkv_bias = {}
    w12_weight = {}
    w12_bias = {}
    # Pretraining-only parameters, dropped from the output.
    banned = {
        'mask_token',
        'lm_head.weight',
        'lm_head.bias',
        'norm.weight',
        'norm.bias',
    }
    for k, v in list(ckpt.items()):
        if k in banned:
            continue
        if k.startswith('head'):
            new_k = k.replace('head.', 'head.fc.')
            new_ckpt[new_k] = v
        else:
            if k.startswith('patch_embed'):
                new_k = k.replace('proj.', 'projection.')
            elif k.startswith('fc_norm') or k.startswith('norm'):
                # Bug fix: the original applied the second ``replace`` to the
                # untouched ``k``, discarding the first result, so a key with
                # a plain 'norm.' prefix (other than the banned ones) was
                # never renamed. Replace 'fc_norm.' first, then 'norm.'.
                new_k = k.replace('fc_norm.', 'ln2.')
                new_k = new_k.replace('norm.', 'ln2.')
            elif k.startswith('blocks'):
                new_k = k.replace('blocks.', 'layers.')
                if 'mlp' in new_k:
                    if 'w1.' in new_k or 'w2.' in new_k:
                        # For base and large version, mlp is implemented with
                        # 2 linears, where w1 and w2 are required to integrate
                        # into w12.
                        s = new_k.split('.')  # e.g. layers.0.mlp.w1.weight
                        idx = s[1]
                        if 'weight' in new_k:
                            # w1.weight or w2.weight
                            if idx not in w12_weight:
                                w12_weight[idx] = {}
                            w12_weight[idx][s[-2]] = v
                        else:
                            # w1.bias or w2.bias
                            if idx not in w12_bias:
                                w12_bias[idx] = {}
                            w12_bias[idx][s[-2]] = v
                        continue
                    if 'ffn_ln' in new_k:
                        new_k = new_k.replace('ffn_ln.', 'norm.')
                elif 'attn' in new_k:
                    if 'q_proj.weight' in new_k or \
                            'k_proj.weight' in new_k or \
                            'v_proj.weight' in new_k:
                        # For base and large version, qkv projection is
                        # implemented with three linear layers; stage them
                        # for concatenation below.
                        s = new_k.split('.')
                        idx = s[1]
                        if idx not in qkv_proj:
                            qkv_proj[idx] = {}
                        qkv_proj[idx][s[-2]] = v
                        continue
                    if 'q_bias' in new_k or 'v_bias' in new_k:
                        # k_bias is 0
                        s = new_k.split('.')
                        idx = s[1]
                        if idx not in qkv_bias:
                            qkv_bias[idx] = {}
                        qkv_bias[idx][s[-1]] = v
                        continue
            else:
                new_k = k
            new_k = 'backbone.' + new_k
            new_ckpt[new_k] = v
    # Merge the staged q/k/v projection weights per layer.
    for idx in qkv_proj:
        q_proj = qkv_proj[idx]['q_proj']
        k_proj = qkv_proj[idx]['k_proj']
        v_proj = qkv_proj[idx]['v_proj']
        weight = torch.cat((q_proj, k_proj, v_proj))
        new_k = f'backbone.layers.{idx}.attn.qkv.weight'
        new_ckpt[new_k] = weight
    # The source has no k bias, so substitute zeros in the merged bias.
    for idx in qkv_bias:
        q_bias = qkv_bias[idx]['q_bias']
        k_bias = torch.zeros_like(q_bias)
        v_bias = qkv_bias[idx]['v_bias']
        weight = torch.cat((q_bias, k_bias, v_bias))
        new_k = f'backbone.layers.{idx}.attn.qkv.bias'
        new_ckpt[new_k] = weight
    # Merge the two MLP input projections into a single w12 linear.
    for idx in w12_weight:
        w1 = w12_weight[idx]['w1']
        w2 = w12_weight[idx]['w2']
        weight = torch.cat((w1, w2))
        new_k = f'backbone.layers.{idx}.mlp.w12.weight'
        new_ckpt[new_k] = weight
    for idx in w12_bias:
        w1 = w12_bias[idx]['w1']
        w2 = w12_bias[idx]['w2']
        weight = torch.cat((w1, w2))
        new_k = f'backbone.layers.{idx}.mlp.w12.bias'
        new_ckpt[new_k] = weight
    return new_ckpt
def main():
    """CLI entry: convert an EVA-02 checkpoint and save it to ``dst``."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained eva02 '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # Some checkpoints nest the weights under a 'module' key.
    state_dict = checkpoint.get('module', checkpoint)

    weight = convert_eva02(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_eva(ckpt):
    """Translate official EVA checkpoint keys to mmpretrain naming.

    Decoder and mask-token parameters are dropped; everything except the
    classification head is prefixed with ``backbone.``.
    """
    new_ckpt = OrderedDict()
    # Substring renames applied inside transformer blocks; first hit wins.
    block_rules = (
        ('norm1', 'ln1'),
        ('norm2', 'ln2'),
        ('mlp.fc1', 'ffn.layers.0.0'),
        ('mlp.fc2', 'ffn.layers.1'),
    )
    for key, value in list(ckpt.items()):
        # Skip MIM decoder weights and the mask token.
        if 'decoder' in key or 'mask_token' in key:
            continue
        if key.startswith('head'):
            new_ckpt[key.replace('head.', 'head.fc.')] = value
            continue
        if key.startswith('patch_embed'):
            new_key = key.replace('proj.', 'projection.') if 'proj.' in key else key
        elif key.startswith('blocks'):
            new_key = key.replace('blocks.', 'layers.')
            for old, new in block_rules:
                if old in key:
                    new_key = new_key.replace(old, new)
                    break
        elif 'fc_norm' in key:
            new_key = key.replace('fc_norm', 'ln2')
        elif key.startswith('norm'):
            # for mim pretrain
            new_key = key.replace('norm', 'ln2')
        else:
            new_key = key
        new_ckpt['backbone.' + new_key] = value
    return new_ckpt
def main():
    """CLI entry: convert an EVA checkpoint and save it to ``dst``."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained eva '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # Some checkpoints nest the weights under a 'model' key.
    state_dict = checkpoint.get('model', checkpoint)

    weight = convert_eva(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_glip(ckpt):
    """Convert a GLIP Swin backbone state dict to mmpretrain naming.

    Only the vision backbone is kept; the language backbone and FPN are
    discarded.
    """

    def _fix_reduction_weight(w):
        # Re-order patch-merging reduction weight columns from the source
        # unfold order into the one mmpretrain's Swin expects.
        out_ch, in_ch = w.shape
        w = w.reshape(out_ch, 4, in_ch // 4)
        return w[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_ch, in_ch)

    def _fix_norm_param(w):
        # Same re-ordering for the 1-D patch-merging norm parameters.
        in_ch = w.shape[0]
        w = w.reshape(4, in_ch // 4)
        return w[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_ch)

    new_ckpt = OrderedDict()
    for key, value in list(ckpt.items()):
        # Keep only vision-backbone weights.
        if 'language_backbone' in key or 'backbone' not in key or 'fpn' in key:
            continue
        new_value = value
        new_key = key.replace('body.', '').replace('module.', '')
        if new_key.startswith('backbone.layers'):
            new_key = new_key.replace('backbone.layers', 'backbone.stages')
            if 'mlp' in new_key:
                new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
                new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
            elif 'attn' in new_key:
                new_key = new_key.replace('attn', 'attn.w_msa')
        elif 'patch_embed' in key:
            new_key = new_key.replace('proj', 'projection')
        elif 'downsample' in new_key:
            # NOTE(review): downsample keys nested under 'backbone.layers'
            # take the branch above and never reach this correction —
            # preserved as-is; verify against real GLIP checkpoints.
            if 'reduction.' in key:
                new_value = _fix_reduction_weight(new_value)
            elif 'norm.' in key:
                new_value = _fix_norm_param(new_value)
        new_ckpt[new_key] = new_value
    return new_ckpt
def main():
    """CLI entry: convert a GLIP checkpoint and save it to ``dst``."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained glip models to mmcls style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # Some checkpoints nest the weights under a 'model' key.
    state_dict = checkpoint.get('model', checkpoint)

    weight = convert_glip(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_hornet(ckpt):
    """Rename HorNet checkpoint keys to the mmpretrain layout.

    Head keys keep their prefix (``head.`` -> ``head.fc.``); all other
    keys are renamed where needed and prefixed with ``backbone.``.
    """
    new_ckpt = OrderedDict()
    for key, value in list(ckpt.items()):
        if key.startswith('head'):
            new_ckpt[key.replace('head.', 'head.fc.')] = value
            continue
        if key.startswith('norm'):
            # Final norm is called 'norm3' in mmpretrain's HorNet.
            new_key = key.replace('norm.', 'norm3.')
        elif 'gnconv.pws' in key:
            new_key = key.replace('gnconv.pws', 'gnconv.projs')
        elif 'gamma1' in key:
            # Scale parameters are wrapped in modules with a 'weight' attr.
            new_key = key.replace('gamma1', 'gamma1.weight')
        elif 'gamma2' in key:
            new_key = key.replace('gamma2', 'gamma2.weight')
        else:
            new_key = key
        new_ckpt['backbone.' + new_key] = value
    return new_ckpt
def main():
    """CLI entry: convert a HorNet checkpoint and save it to ``dst``."""
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained hornet '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # Some checkpoints nest the weights under a 'model' key.
    state_dict = checkpoint.get('model', checkpoint)

    weight = convert_hornet(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
def convert_levit(args, ckpt):
    """Convert a timm LeViT state dict to mmpretrain key names.

    Args:
        args: Parsed CLI namespace. Unused inside this function; kept to
            match the existing call in ``main``.
        ckpt (dict): The source state dict.

    Returns:
        OrderedDict: State dict with mmpretrain-style keys.
    """
    new_ckpt = OrderedDict()
    # Running counters used to regroup timm's flat ``blocks.N`` indices into
    # per-stage ``stages.S.B`` indices. ``change`` marks that the previous
    # block key belonged to an MLP ('.m.') sub-block, i.e. a stage boundary
    # may follow.
    stage = 0
    block = 0
    change = True
    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('head_dist'):
            # Distillation head: 'head_dist.l.*' -> 'head.head_dist.linear.*'.
            new_k = k.replace('head_dist.', 'head.head_dist.')
            new_k = new_k.replace('.l.', '.linear.')
            new_ckpt[new_k] = new_v
            continue
        elif k.startswith('head'):
            # Classification head: 'head.l.*' -> 'head.head.linear.*'.
            new_k = k.replace('head.', 'head.head.')
            new_k = new_k.replace('.l.', '.linear.')
            new_ckpt[new_k] = new_v
            continue
        elif k.startswith('patch_embed'):
            new_k = k.replace('patch_embed.',
                              'patch_embed.patch_embed.').replace(
                                  '.c.', '.conv.')
        elif k.startswith('blocks'):
            strs = k.split('.')
            new_k = k
            if '.m.' in k:
                # MLP sub-block: its indexed linears get explicit names.
                new_k = new_k.replace('.m.0', '.m.linear1')
                new_k = new_k.replace('.m.2', '.m.linear2')
                new_k = new_k.replace('.m.', '.block.')
                change = True
            elif change:
                # First non-MLP key after an MLP run starts a new stage;
                # remember the absolute block index to rebase from zero.
                # NOTE(review): this relies on checkpoint keys arriving in
                # the order timm emits them — verify for new checkpoints.
                stage += 1
                block = int(strs[1])
                change = False
            # Rebase the flat block index into (stage, block-within-stage).
            new_k = new_k.replace(
                'blocks.%s.' % (strs[1]),
                'stages.%d.%d.' % (stage, int(strs[1]) - block))
            # Inside blocks, '.c.' denotes a linear layer (LinearNorm).
            new_k = new_k.replace('.c.', '.linear.')
        else:
            new_k = k
        new_k = 'backbone.' + new_k
        new_ckpt[new_k] = new_v
    return new_ckpt
def main():
    """CLI entry: convert a timm LeViT checkpoint to MMPretrain style.

    Reads the checkpoint from ``src``, converts the parameter names and
    writes the result to ``dst`` (parent directories are created).
    """
    parser = argparse.ArgumentParser(
        description='Convert keys in timm pretrained vit models to '
        'MMPretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = torch.load(args.src, map_location='cpu')
    # Bug fix: only unwrap the 'model' wrapper when it is present — the
    # original indexed it unconditionally and raised KeyError for
    # checkpoints that are already a bare state dict.
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    if 'state_dict' in checkpoint:
        # timm checkpoint
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert_levit(args, state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)
    # Consistent with the sibling converter scripts.
    print('Done!!')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from itertools import chain
from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from transformers.modeling_utils import load_state_dict
# Help text shown by ``argparse`` for this converter script.
prog_description = """\
Convert Llava weights and original weights.
"""
def parse_args():
    """Parse and return the command-line arguments of the converter."""
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('src', type=str, help='The original checkpoint dir')
    parser.add_argument('dst', type=str, help='The saved checkpoint path')
    parser.add_argument('--delta', type=str, help='The delta checkpoint dir')
    return parser.parse_args()
def load_checkpoint(path: Path):
    """Load a checkpoint from a single file or from every shard in a dir.

    A regular file is loaded directly with ``torch.load``; a directory is
    scanned recursively for ``.bin`` / ``.pth`` / ``.safetensors`` shards,
    which are merged into a single state dict.
    """
    path = Path(path)
    if path.is_file():
        return torch.load(path)
    merged = OrderedDict()
    shards = chain(
        path.rglob('*.bin'), path.rglob('*.pth'), path.rglob('*.safetensors'))
    for shard in shards:
        merged.update(load_state_dict(str(shard)))
    return merged
def main():
    """Merge a base LLaVA checkpoint with an optional delta checkpoint.

    Loads ``src`` (local path or HuggingFace repo id), optionally adds a
    ``--delta`` checkpoint on top of it, prefixes every key with
    ``model.lang_encoder.`` and saves the result to ``dst``.
    """
    args = parse_args()
    if Path(args.src).exists():
        src_path = args.src
    else:
        # Treat ``src`` as a HuggingFace repo id and fetch the shards.
        src_path = snapshot_download(
            args.src, allow_patterns='pytorch_model*.bin')
    src_state_dict = load_checkpoint(src_path)
    if args.delta is None:
        delta_state_dict = {}
    elif Path(args.delta).exists():
        delta_state_dict = load_checkpoint(args.delta)
    else:
        delta_path = snapshot_download(
            args.delta, allow_patterns='pytorch_model*.bin')
        delta_state_dict = load_checkpoint(delta_path)
    new_state_dict = OrderedDict()
    for k, v in src_state_dict.items():
        if k in delta_state_dict:
            # ``pop`` so only delta-exclusive keys remain for the loop below.
            delta_v = delta_state_dict.pop(k)
            if k in ['model.embed_tokens.weight', 'lm_head.weight']:
                # Add the base weights into the top-left sub-matrix of the
                # delta tensor; assumes the delta tensor is at least as
                # large as the base one (e.g. an extended vocabulary) —
                # TODO confirm against the delta checkpoints used.
                h, w = v.shape[:2]
                delta_v[:h, :w] += v
                v = delta_v
            else:
                # NOTE: in-place add mutates the source tensor.
                v += delta_v
        if 'rotary_emb.inv_freq' not in k:
            # Rotary inv_freq buffers are dropped; presumably rebuilt by
            # the target model at load time — verify.
            new_state_dict['model.lang_encoder.' + k] = v
    for k, v in delta_state_dict.items():
        # Keys present only in the delta checkpoint.
        new_state_dict['model.lang_encoder.' + k] = v
    torch.save(new_state_dict, args.dst)
    print('Done!!')
if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment