OpenDAS / mmpretrain

Commit 1ac2e802
authored Jun 24, 2025 by limm

    add tools code

parent: b6df0d33
Pipeline #2803: canceled with stages
Changes: 71 files in this commit; this page shows 20 changed files with 1727 additions and 0 deletions (+1727 -0).
tools/model_converters/merge_lora_weight.py              +90   -0
tools/model_converters/mixmim_to_mmpretrain.py           +99   -0
tools/model_converters/mlpmixer_to_mmpretrain.py         +58   -0
tools/model_converters/mobilenetv2_to_mmpretrain.py      +135  -0
tools/model_converters/ofa.py                            +111  -0
tools/model_converters/openai-clip_to_mmpretrain-clip.py +77   -0
tools/model_converters/otter2mmpre.py                    +66   -0
tools/model_converters/publish_model.py                  +108  -0
tools/model_converters/ram2mmpretrain.py                 +117  -0
tools/model_converters/reparameterize_model.py           +57   -0
tools/model_converters/replknet_to_mmpretrain.py         +58   -0
tools/model_converters/repvgg_to_mmpretrain.py           +60   -0
tools/model_converters/revvit_to_mmpretrain.py           +99   -0
tools/model_converters/shufflenetv2_to_mmpretrain.py     +113  -0
tools/model_converters/tinyvit_to_mmpretrain.py          +61   -0
tools/model_converters/torchvision_to_mmpretrain.py      +63   -0
tools/model_converters/twins2mmpretrain.py               +73   -0
tools/model_converters/van2mmpretrain.py                 +66   -0
tools/model_converters/vgg_to_mmpretrain.py              +118  -0
tools/model_converters/vig_to_mmpretrain.py              +98   -0
tools/model_converters/merge_lora_weight.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path

import torch
from mmengine.config import Config

from mmpretrain.registry import MODELS


@torch.no_grad()
def merge_lora_weight(cfg, lora_weight):
    """Merge base weight and lora weight.

    Args:
        cfg (dict): config for LoRAModel.
        lora_weight (dict): weight dict from LoRAModel.

    Returns:
        Merged weight.
    """
    temp = dict()
    mapping = dict()
    for name, param in lora_weight['state_dict'].items():
        # e.g. backbone.module.layers.11.attn.qkv.lora_down.weight
        if '.lora_' in name:
            lora_split = name.split('.')
            prefix = '.'.join(lora_split[:-2])
            if prefix not in mapping:
                mapping[prefix] = dict()
            lora_type = lora_split[-2]
            mapping[prefix][lora_type] = param
        else:
            temp[name] = param

    model = MODELS.build(cfg['model'])
    for name, param in model.named_parameters():
        if name in temp or '.lora_' in name:
            continue
        else:
            name_split = name.split('.')
            prefix = '.'.join(name_split[:-2])
            if prefix in mapping:
                name_split.pop(-2)
                if name_split[-1] == 'weight':
                    scaling = get_scaling(model, prefix)
                    lora_down = mapping[prefix]['lora_down']
                    lora_up = mapping[prefix]['lora_up']
                    param += lora_up @ lora_down * scaling
            name_split.pop(1)
            name = '.'.join(name_split)
            temp[name] = param
    result = dict()
    result['state_dict'] = temp
    result['meta'] = lora_weight['meta']
    return result


def get_scaling(model, prefix):
    """Get the scaling of target layer.

    Args:
        model (LoRAModel): the LoRAModel.
        prefix (str): the prefix of the layer.

    Returns:
        the scale of the LoRALinear.
    """
    prefix_split = prefix.split('.')
    for i in prefix_split:
        model = getattr(model, i)
    return model.scaling


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Merge LoRA weight')
    parser.add_argument('cfg', help='cfg path')
    parser.add_argument('src', help='src lora model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)

    cfg = Config.fromfile(args.cfg)
    lora_model = torch.load(args.src, map_location='cpu')
    merged_model = merge_lora_weight(cfg, lora_model)
    torch.save(merged_model, args.dst)
```
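The merge step above relies on the identity that adding `lora_up @ lora_down * scaling` into a frozen weight reproduces the LoRA forward path. A minimal standalone sketch of that identity (all tensors and shapes below are made up for illustration, independent of mmpretrain):

```python
import torch

torch.manual_seed(0)
W = torch.randn(8, 8)       # frozen base weight
down = torch.randn(4, 8)    # lora_down: rank-4 projection
up = torch.randn(8, 4)      # lora_up (non-zero here just for the check)
scaling = 0.5
x = torch.randn(8)

# LoRA forward: base path plus scaled low-rank path
lora_out = W @ x + scaling * (up @ (down @ x))
# Merged weight: same result from a single matmul
merged_out = (W + up @ down * scaling) @ x
assert torch.allclose(lora_out, merged_out, atol=1e-6)
```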
tools/model_converters/mixmim_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def correct_unfold_reduction_order(x: torch.Tensor):
    out_channel, in_channel = x.shape
    x = x.reshape(out_channel, 4, in_channel // 4)
    x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
    return x


def correct_unfold_norm_order(x):
    in_channel = x.shape[0]
    x = x.reshape(4, in_channel // 4)
    x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
    return x


def convert_mixmim(ckpt):
    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('patch_embed'):
            new_k = k.replace('proj', 'projection')
        elif k.startswith('layers'):
            if 'norm1' in k:
                new_k = k.replace('norm1', 'ln1')
            elif 'norm2' in k:
                new_k = k.replace('norm2', 'ln2')
            elif 'mlp.fc1' in k:
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in k:
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
            else:
                new_k = k
        elif k.startswith('norm') or k.startswith('absolute_pos_embed'):
            new_k = k
        elif k.startswith('head'):
            new_k = k.replace('head.', 'head.fc.')
        else:
            raise ValueError

        # print(new_k)

        if not new_k.startswith('head'):
            new_k = 'backbone.' + new_k

        if 'downsample' in new_k:
            print('Convert {} in PatchMerging from timm to mmcv format!'.format(
                new_k))
            if 'reduction' in new_k:
                new_v = correct_unfold_reduction_order(new_v)
            elif 'norm' in new_k:
                new_v = correct_unfold_norm_order(new_v)

        new_ckpt[new_k] = new_v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained mixmim '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    weight = convert_mixmim(state_dict)
    # weight = convert_official_mixmim(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
```
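To see concretely what the reordering helpers do, here is the index shuffle from `correct_unfold_norm_order` applied to a toy 8-element vector (the channel count is illustrative): timm's PatchMerging lays out the four sub-grid channel groups differently than mmcv, so the weights must be permuted to match.

```python
import torch

x = torch.arange(8.)  # four contiguous groups: [0 1] [2 3] [4 5] [6 7]
y = x.reshape(4, 2)[[0, 2, 1, 3], :].transpose(0, 1).reshape(8)
print(y)  # tensor([0., 4., 2., 6., 1., 5., 3., 7.])
```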
tools/model_converters/mlpmixer_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path

import torch


def convert_weights(weight):
    """Weight Converter.

    Converts the weights from timm to mmpretrain

    Args:
        weight (dict): weight dict from timm

    Returns: converted weight dict for mmpretrain
    """
    result = dict()
    result['meta'] = dict()
    temp = dict()
    mapping = {
        'stem': 'patch_embed',
        'proj': 'projection',
        'mlp_tokens.fc1': 'token_mix.layers.0.0',
        'mlp_tokens.fc2': 'token_mix.layers.1',
        'mlp_channels.fc1': 'channel_mix.layers.0.0',
        'mlp_channels.fc2': 'channel_mix.layers.1',
        'norm1': 'ln1',
        'norm2': 'ln2',
        'norm.': 'ln1.',
        'blocks': 'layers'
    }
    for k, v in weight.items():
        for mk, mv in mapping.items():
            if mk in k:
                k = k.replace(mk, mv)
        if k.startswith('head.'):
            temp['head.fc.' + k[5:]] = v
        else:
            temp['backbone.' + k] = v
    result['state_dict'] = temp
    return result


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src timm model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)
    original_model = torch.load(args.src, map_location='cpu')
    converted_model = convert_weights(original_model)
    torch.save(converted_model, args.dst)
```
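The whole conversion is driven by the substring `mapping` table; a quick trace of one hypothetical timm key through the rename loop:

```python
# Every matching fragment is rewritten in dict order, then the backbone
# prefix is attached (key chosen for illustration).
k = 'blocks.0.mlp_tokens.fc1.weight'
for mk, mv in {
        'mlp_tokens.fc1': 'token_mix.layers.0.0',
        'blocks': 'layers'
}.items():
    if mk in k:
        k = k.replace(mk, mv)
print('backbone.' + k)  # backbone.layers.0.token_mix.layers.0.0.weight
```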
tools/model_converters/mobilenetv2_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict

import torch


def convert_conv1(model_key, model_weight, state_dict, converted_names):
    if model_key.find('features.0.0') >= 0:
        new_key = model_key.replace('features.0.0', 'backbone.conv1.conv')
    else:
        new_key = model_key.replace('features.0.1', 'backbone.conv1.bn')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {new_key}')


def convert_conv5(model_key, model_weight, state_dict, converted_names):
    if model_key.find('features.18.0') >= 0:
        new_key = model_key.replace('features.18.0', 'backbone.conv2.conv')
    else:
        new_key = model_key.replace('features.18.1', 'backbone.conv2.bn')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {new_key}')


def convert_head(model_key, model_weight, state_dict, converted_names):
    new_key = model_key.replace('classifier.1', 'head.fc')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {new_key}')


def convert_block(model_key, model_weight, state_dict, converted_names):
    split_keys = model_key.split('.')
    layer_id = int(split_keys[1])
    new_layer_id = 0
    sub_id = 0
    if layer_id == 1:
        new_layer_id = 1
        sub_id = 0
    elif layer_id in range(2, 4):
        new_layer_id = 2
        sub_id = layer_id - 2
    elif layer_id in range(4, 7):
        new_layer_id = 3
        sub_id = layer_id - 4
    elif layer_id in range(7, 11):
        new_layer_id = 4
        sub_id = layer_id - 7
    elif layer_id in range(11, 14):
        new_layer_id = 5
        sub_id = layer_id - 11
    elif layer_id in range(14, 17):
        new_layer_id = 6
        sub_id = layer_id - 14
    elif layer_id == 17:
        new_layer_id = 7
        sub_id = 0

    new_key = model_key.replace(f'features.{layer_id}',
                                f'backbone.layer{new_layer_id}.{sub_id}')
    if new_layer_id == 1:
        if new_key.find('conv.0.0') >= 0:
            new_key = new_key.replace('conv.0.0', 'conv.0.conv')
        elif new_key.find('conv.0.1') >= 0:
            new_key = new_key.replace('conv.0.1', 'conv.0.bn')
        elif new_key.find('conv.1') >= 0:
            new_key = new_key.replace('conv.1', 'conv.1.conv')
        elif new_key.find('conv.2') >= 0:
            new_key = new_key.replace('conv.2', 'conv.1.bn')
        else:
            raise ValueError(f'Unsupported conversion of key {model_key}')
    else:
        if new_key.find('conv.0.0') >= 0:
            new_key = new_key.replace('conv.0.0', 'conv.0.conv')
        elif new_key.find('conv.0.1') >= 0:
            new_key = new_key.replace('conv.0.1', 'conv.0.bn')
        elif new_key.find('conv.1.0') >= 0:
            new_key = new_key.replace('conv.1.0', 'conv.1.conv')
        elif new_key.find('conv.1.1') >= 0:
            new_key = new_key.replace('conv.1.1', 'conv.1.bn')
        elif new_key.find('conv.2') >= 0:
            new_key = new_key.replace('conv.2', 'conv.2.conv')
        elif new_key.find('conv.3') >= 0:
            new_key = new_key.replace('conv.3', 'conv.2.bn')
        else:
            raise ValueError(f'Unsupported conversion of key {model_key}')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)


def convert(src, dst):
    """Convert keys in torchvision pretrained MobileNetV2 models to
    mmpretrain style."""
    # load pytorch model
    blobs = torch.load(src, map_location='cpu')
    # convert to pytorch style
    state_dict = OrderedDict()
    converted_names = set()

    for key, weight in blobs.items():
        if 'features.0' in key:
            convert_conv1(key, weight, state_dict, converted_names)
        elif 'classifier' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif 'features.18' in key:
            convert_conv5(key, weight, state_dict, converted_names)
        else:
            convert_block(key, weight, state_dict, converted_names)

    # check if all layers are converted
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # save checkpoint
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst)


def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src torchvision model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    convert(args.src, args.dst)


if __name__ == '__main__':
    main()
```
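A usage sketch for `convert_block` above, run inside this script with a made-up key and dummy weight, showing how torchvision's flat `features.N` index becomes mmpretrain's `backbone.layerX.Y` stage/block pair:

```python
import torch

state_dict, converted = {}, set()
# features.7 falls in range(7, 11), so it maps to layer4, sub-block 0.
convert_block('features.7.conv.0.0.weight', torch.zeros(1), state_dict,
              converted)
# prints: Convert features.7.conv.0.0.weight
#             to backbone.layer4.0.conv.0.conv.weight
```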
tools/model_converters/ofa.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import re
from collections import OrderedDict, namedtuple
from pathlib import Path

import torch

prog_description = """\
Convert OFA official models to MMPretrain format.
"""

MapItem = namedtuple(
    'MapItem', 'pattern repl key_action value_action', defaults=[None] * 4)


def convert_by_mapdict(src_dict: dict, map_dict: Path):
    dst_dict = OrderedDict()
    convert_map_dict = dict()

    for k, v in src_dict.items():
        ori_k = k
        for item in map_dict:
            pattern = item.pattern
            assert pattern is not None
            match = next(re.finditer(pattern, k), None)
            if match is None:
                continue
            match_group = match.groups()
            repl = item.repl
            key_action = item.key_action
            if key_action is not None:
                assert callable(key_action)
                match_group = key_action(*match_group)
                if isinstance(match_group, str):
                    match_group = (match_group, )
            start, end = match.span(0)
            if repl is not None:
                k = k[:start] + repl.format(*match_group) + k[end:]
            else:
                for i, sub in enumerate(match_group):
                    start, end = match.span(i + 1)
                    k = k[:start] + str(sub) + k[end:]
            value_action = item.value_action
            if value_action is not None:
                assert callable(value_action)
                v = value_action(v)

        if v is not None:
            dst_dict[k] = v
            convert_map_dict[k] = ori_k
    return dst_dict, convert_map_dict


map_dict = [
    # Encoder modules
    MapItem(r'\.type_embedding\.', '.embed_type.'),
    MapItem(r'\.layernorm_embedding\.', '.embedding_ln.'),
    MapItem(r'\.patch_layernorm_embedding\.', '.image_embedding_ln.'),
    MapItem(r'encoder.layer_norm\.', 'encoder.final_ln.'),
    # Encoder layers
    MapItem(r'\.attn_ln\.', '.attn_mid_ln.'),
    MapItem(r'\.ffn_layernorm\.', '.ffn_mid_ln.'),
    MapItem(r'\.final_layer_norm', '.ffn_ln'),
    MapItem(r'encoder.*(\.self_attn\.)', key_action=lambda _: '.attn.'),
    MapItem(
        r'encoder.*(\.self_attn_layer_norm\.)',
        key_action=lambda _: '.attn_ln.'),
    # Decoder modules
    MapItem(r'\.code_layernorm_embedding\.', '.code_embedding_ln.'),
    MapItem(r'decoder.layer_norm\.', 'decoder.final_ln.'),
    # Decoder layers
    MapItem(r'\.self_attn_ln', '.self_attn_mid_ln'),
    MapItem(r'\.cross_attn_ln', '.cross_attn_mid_ln'),
    MapItem(r'\.encoder_attn_layer_norm', '.cross_attn_ln'),
    MapItem(r'\.encoder_attn', '.cross_attn'),
    MapItem(
        r'decoder.*(\.self_attn_layer_norm\.)',
        key_action=lambda _: '.self_attn_ln.'),
    # Remove version key
    MapItem(r'version', '', value_action=lambda _: None),
    # Add model prefix
    MapItem(r'^', 'model.'),
]


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('src', type=str, help='The official checkpoint path.')
    parser.add_argument('dst', type=str, help='The save path.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    src = torch.load(args.src)
    if 'extra_state' in src and 'ema' in src['extra_state']:
        print('Use EMA weights.')
        src = src['extra_state']['ema']
    else:
        src = src['model']
    dst, _ = convert_by_mapdict(src, map_dict)
    torch.save(dst, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
```
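For a sense of how the `MapItem` rules compose, here is a toy run of `convert_by_mapdict` (inside this script) on a fabricated two-key state dict; the `version` entry is nulled by its `value_action` and dropped, while the other key is renamed and prefixed:

```python
import torch

toy = {
    'encoder.layer_norm.weight': torch.ones(3),
    'version': torch.tensor(2.),
}
dst, mapping = convert_by_mapdict(toy, map_dict)
print(list(dst))  # ['model.encoder.final_ln.weight']
print(mapping)    # {'model.encoder.final_ln.weight': 'encoder.layer_norm.weight'}
```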
tools/model_converters/openai-clip_to_mmpretrain-clip.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_clip(ckpt):
    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('visual.conv1'):
            new_k = k.replace('conv1', 'patch_embed.projection')
        elif k.startswith('visual.positional_embedding'):
            new_k = k.replace('positional_embedding', 'pos_embed')
            new_v = v.unsqueeze(dim=0)
        elif k.startswith('visual.class_embedding'):
            new_k = k.replace('class_embedding', 'cls_token')
            new_v = v.unsqueeze(dim=0).unsqueeze(dim=0)
        elif k.startswith('visual.ln_pre'):
            new_k = k.replace('ln_pre', 'pre_norm')
        elif k.startswith('visual.transformer.resblocks'):
            new_k = k.replace('transformer.resblocks', 'layers')
            if 'ln_1' in k:
                new_k = new_k.replace('ln_1', 'ln1')
            elif 'ln_2' in k:
                new_k = new_k.replace('ln_2', 'ln2')
            elif 'mlp.c_fc' in k:
                new_k = new_k.replace('mlp.c_fc', 'ffn.layers.0.0')
            elif 'mlp.c_proj' in k:
                new_k = new_k.replace('mlp.c_proj', 'ffn.layers.1')
            elif 'attn.in_proj_weight' in k:
                new_k = new_k.replace('in_proj_weight', 'qkv.weight')
            elif 'attn.in_proj_bias' in k:
                new_k = new_k.replace('in_proj_bias', 'qkv.bias')
            elif 'attn.out_proj' in k:
                new_k = new_k.replace('out_proj', 'proj')
        elif k.startswith('visual.ln_post'):
            new_k = k.replace('ln_post', 'ln1')
        elif k.startswith('visual.proj'):
            new_k = k.replace('visual.proj', 'visual_proj.proj')
        else:
            new_k = k
        new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained clip '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert_clip(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
```
tools/model_converters/otter2mmpre.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import re
from collections import OrderedDict
from itertools import chain
from pathlib import Path

import torch

prog_description = """\
Convert Official Otter HF models to MMPreTrain format.
"""


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument(
        'name_or_dir', type=str, help='The Otter HF model name or directory.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if not Path(args.name_or_dir).is_dir():
        from huggingface_hub import snapshot_download
        ckpt_dir = Path(
            snapshot_download(args.name_or_dir, allow_patterns='*.bin'))
        name = args.name_or_dir.replace('/', '_')
    else:
        ckpt_dir = Path(args.name_or_dir)
        name = ckpt_dir.name

    state_dict = OrderedDict()
    for k, v in chain.from_iterable(
            torch.load(ckpt).items() for ckpt in ckpt_dir.glob('*.bin')):
        adapter_patterns = [
            r'^perceiver',
            r'lang_encoder.*embed_tokens',
            r'lang_encoder.*gated_cross_attn_layer',
            r'lang_encoder.*rotary_emb',
        ]
        if not any(
                re.match(pattern, k) for pattern in adapter_patterns):
            # Drop encoder parameters to decrease the size.
            continue

        # The keys are different between Open-Flamingo and Otter
        if 'gated_cross_attn_layer.feed_forward' in k:
            k = k.replace('feed_forward', 'ff')
        if 'perceiver.layers' in k:
            prefix_match = re.match(r'perceiver.layers.\d+.', k)
            prefix = k[:prefix_match.end()]
            suffix = k[prefix_match.end():]
            if 'feed_forward' in k:
                k = prefix + '1.' + suffix.replace('feed_forward.', '')
            else:
                k = prefix + '0.' + suffix
        state_dict[k] = v

    if len(state_dict) == 0:
        raise RuntimeError('No checkpoint found in the specified directory.')
    torch.save(state_dict, name + '.pth')


if __name__ == '__main__':
    main()
```
tools/model_converters/publish_model.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import datetime
import hashlib
import shutil
import warnings
from collections import OrderedDict
from pathlib import Path

import torch

import mmpretrain


def parse_args():
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename')
    parser.add_argument('out_file', help='output checkpoint filename')
    parser.add_argument(
        '--no-ema',
        action='store_true',
        help='Use keys in `ema_state_dict` (no-ema keys).')
    parser.add_argument(
        '--dataset-type',
        type=str,
        help='The type of the dataset. If the checkpoint is converted '
        'from other repository, this option is used to fill the dataset '
        'meta information to the published checkpoint, like "ImageNet", '
        '"CIFAR10" and others.')
    args = parser.parse_args()
    return args


def process_checkpoint(in_file, out_file, args):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove unnecessary fields for smaller file size
    for key in ['optimizer', 'param_schedulers', 'hook_msgs', 'message_hub']:
        checkpoint.pop(key, None)

    # For checkpoint converted from the official weight
    if 'state_dict' not in checkpoint:
        checkpoint = dict(state_dict=checkpoint)

    meta = checkpoint.get('meta', {})
    meta.setdefault('mmpretrain_version', mmpretrain.__version__)

    # handle dataset meta information
    if args.dataset_type is not None:
        from mmpretrain.registry import DATASETS
        dataset_class = DATASETS.get(args.dataset_type)
        dataset_meta = getattr(dataset_class, 'METAINFO', {})
    else:
        dataset_meta = {}
    meta.setdefault('dataset_meta', dataset_meta)

    if len(meta['dataset_meta']) == 0:
        warnings.warn('Missing dataset meta information.')

    checkpoint['meta'] = meta

    ema_state_dict = OrderedDict()
    if 'ema_state_dict' in checkpoint:
        for k, v in checkpoint['ema_state_dict'].items():
            # The ema state dict has some extra fields
            if k.startswith('module.'):
                origin_k = k[len('module.'):]
                assert origin_k in checkpoint['state_dict']
                ema_state_dict[origin_k] = v
        del checkpoint['ema_state_dict']
        print('The input checkpoint has EMA weights, ', end='')
        if args.no_ema:
            # The values stored in `ema_state_dict` are the original values.
            print('and drop the EMA weights.')
            assert ema_state_dict.keys() <= checkpoint['state_dict'].keys()
            checkpoint['state_dict'].update(ema_state_dict)
        else:
            print('and use the EMA weights.')

    temp_out_file = Path(out_file).with_name('temp_' + Path(out_file).name)
    torch.save(checkpoint, temp_out_file)

    with open(temp_out_file, 'rb') as f:
        sha = hashlib.sha256(f.read()).hexdigest()[:8]
    if out_file.endswith('.pth'):
        out_file_name = out_file[:-4]
    else:
        out_file_name = out_file

    current_date = datetime.datetime.now().strftime('%Y%m%d')
    final_file = out_file_name + f'_{current_date}-{sha[:8]}.pth'
    shutil.move(temp_out_file, final_file)

    print(f'Successfully generated the publish-ckpt as {final_file}.')


def main():
    args = parse_args()
    out_dir = Path(args.out_file).parent
    if not out_dir.exists():
        raise ValueError(f'Directory {out_dir} does not exist, '
                         'please generate it manually.')
    process_checkpoint(args.in_file, args.out_file, args)


if __name__ == '__main__':
    main()
```
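The published filename is derived from the current date plus the first 8 hex characters of the checkpoint file's SHA-256, as in this minimal sketch (the payload, name, and hash shown are illustrative):

```python
import datetime
import hashlib

payload = b'fake checkpoint bytes'
sha = hashlib.sha256(payload).hexdigest()[:8]
date = datetime.datetime.now().strftime('%Y%m%d')
print(f'model_{date}-{sha}.pth')  # e.g. model_20250624-<8 hex chars>.pth
```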
tools/model_converters/ram2mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
from copy import deepcopy

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_swin(ckpt):
    new_ckpt = OrderedDict()
    convert_mapping = dict()

    def correct_unfold_reduction_order(x):
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(
            out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    for k, v in ckpt.items():
        if 'attn_mask' in k:
            continue
        if k.startswith('head'):
            continue
        elif k.startswith('layers'):
            new_v = v
            if 'attn.' in k:
                new_k = k.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in k:
                if 'mlp.fc1.' in k:
                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in k:
                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_k = k.replace('mlp.', 'ffn.')
            elif 'downsample' in k:
                new_k = k
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)
            else:
                new_k = k
            new_k = new_k.replace('layers', 'stages', 1)
        elif k.startswith('patch_embed'):
            new_v = v
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        elif k.startswith('norm'):
            new_v = v
            new_k = k.replace('norm', 'norm3')
        else:
            new_v = v
            new_k = k

        new_ckpt[new_k] = new_v
        convert_mapping[k] = new_k

    return new_ckpt, convert_mapping


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained RAM models to '
        'MMPretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    visual_ckpt = OrderedDict()
    for key in state_dict:
        if key.startswith('visual_encoder.'):
            new_key = key.replace('visual_encoder.', '')
            visual_ckpt[new_key] = state_dict[key]

    new_visual_ckpt, convert_mapping = convert_swin(visual_ckpt)
    new_ckpt = deepcopy(state_dict)
    for key in state_dict:
        if key.startswith('visual_encoder.'):
            if 'attn_mask' in key:
                del new_ckpt[key]
                continue
            del new_ckpt[key]
            old_key = key.replace('visual_encoder.', '')
            new_ckpt[key.replace(old_key,
                                 convert_mapping[old_key])] = deepcopy(
                new_visual_ckpt[key.replace(
                    old_key,
                    convert_mapping[old_key]).replace('visual_encoder.', '')])

    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(new_ckpt, args.dst)


if __name__ == '__main__':
    main()
```
tools/model_converters/reparameterize_model.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path

import torch

from mmpretrain.apis import init_model
from mmpretrain.models.classifiers import ImageClassifier


def convert_classifier_to_deploy(model, checkpoint, save_path):
    print('Converting...')

    assert hasattr(model, 'backbone') and \
        hasattr(model.backbone, 'switch_to_deploy'), \
        '`model.backbone` must have a "switch_to_deploy" method,' \
        f' but {model.backbone.__class__} does not.'

    model.backbone.switch_to_deploy()
    checkpoint['state_dict'] = model.state_dict()
    torch.save(checkpoint, save_path)

    print('Done! Save at path "{}"'.format(save_path))


def main():
    parser = argparse.ArgumentParser(
        description='Convert the parameters of the repvgg block '
        'from training mode to deployment mode.')
    parser.add_argument(
        'config_path',
        help='The path to the configuration file of the network '
        'containing the repvgg block.')
    parser.add_argument(
        'checkpoint_path',
        help='The path to the checkpoint file corresponding to the model.')
    parser.add_argument(
        'save_path',
        help='The path where the converted checkpoint file is stored.')
    args = parser.parse_args()

    save_path = Path(args.save_path)
    if save_path.suffix != '.pth' and save_path.suffix != '.tar':
        print('The path should contain the name of the pth format file.')
        exit()
    save_path.parent.mkdir(parents=True, exist_ok=True)

    model = init_model(
        args.config_path, checkpoint=args.checkpoint_path, device='cpu')
    assert isinstance(model, ImageClassifier), \
        '`model` must be a `mmpretrain.classifiers.ImageClassifier` instance.'

    checkpoint = torch.load(args.checkpoint_path)
    convert_classifier_to_deploy(model, checkpoint, args.save_path)


if __name__ == '__main__':
    main()
```
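`switch_to_deploy` re-parameterizes a multi-branch training-time block into a single inference-time conv; the basic primitive behind such re-parameterization is folding a BatchNorm into the convolution that precedes it. A self-contained sketch of that fold (not mmpretrain's actual implementation):

```python
import torch
import torch.nn as nn


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold BN(conv(x)) into a single conv with bias."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, bias=True)
    std = (bn.running_var + bn.eps).sqrt()
    t = (bn.weight / std).reshape(-1, 1, 1, 1)
    fused.weight.data = conv.weight * t
    bias = conv.bias if conv.bias is not None else torch.zeros(
        conv.out_channels)
    fused.bias.data = bn.bias + (bias - bn.running_mean) * bn.weight / std
    return fused


conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()  # eval mode: use running statistics
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5)
```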
tools/model_converters/replknet_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from pathlib import Path

import torch


def convert(src, dst):
    print('Converting...')
    blobs = torch.load(src, map_location='cpu')
    converted_state_dict = OrderedDict()

    for key in blobs:
        splited_key = key.split('.')
        print(splited_key)
        splited_key = [
            'backbone.stem' if i[:4] == 'stem' else i for i in splited_key
        ]
        splited_key = [
            'backbone.stages' if i[:6] == 'stages' else i for i in splited_key
        ]
        splited_key = [
            'backbone.transitions' if i[:11] == 'transitions' else i
            for i in splited_key
        ]
        splited_key = [
            'backbone.stages.3.norm' if i[:4] == 'norm' else i
            for i in splited_key
        ]
        splited_key = [
            'head.fc' if i[:4] == 'head' else i for i in splited_key
        ]
        new_key = '.'.join(splited_key)
        converted_state_dict[new_key] = blobs[key]

    torch.save(converted_state_dict, dst)
    print('Done!')


def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)
    convert(args.src, args.dst)


if __name__ == '__main__':
    main()
```
tools/model_converters/repvgg_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from pathlib import Path

import torch


def convert(src, dst):
    print('Converting...')
    blobs = torch.load(src, map_location='cpu')
    converted_state_dict = OrderedDict()

    for key in blobs:
        splited_key = key.split('.')
        splited_key = ['norm' if i == 'bn' else i for i in splited_key]
        splited_key = [
            'branch_norm' if i == 'rbr_identity' else i for i in splited_key
        ]
        splited_key = [
            'branch_1x1' if i == 'rbr_1x1' else i for i in splited_key
        ]
        splited_key = [
            'branch_3x3' if i == 'rbr_dense' else i for i in splited_key
        ]
        splited_key = [
            'backbone.stem' if i[:6] == 'stage0' else i for i in splited_key
        ]
        splited_key = [
            'backbone.stage_' + i[5] if i[:5] == 'stage' else i
            for i in splited_key
        ]
        splited_key = ['se_layer' if i == 'se' else i for i in splited_key]
        splited_key = ['conv1.conv' if i == 'down' else i for i in splited_key]
        splited_key = ['conv2.conv' if i == 'up' else i for i in splited_key]
        splited_key = ['head.fc' if i == 'linear' else i for i in splited_key]
        new_key = '.'.join(splited_key)
        converted_state_dict[new_key] = blobs[key]

    torch.save(converted_state_dict, dst)
    print('Done!')


def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)
    convert(args.src, args.dst)


if __name__ == '__main__':
    main()
```
tools/model_converters/revvit_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_revvit(ckpt):
    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('head.projection'):
            new_k = k.replace('head.projection', 'head.fc')
            new_ckpt[new_k] = new_v
            continue
        elif k.startswith('patch_embed'):
            if 'proj.' in k:
                new_k = k.replace('proj.', 'projection.')
            else:
                new_k = k
        elif k.startswith('rev_backbone'):
            new_k = k.replace('rev_backbone.', '')
            if 'F.norm' in k:
                new_k = new_k.replace('F.norm', 'ln1')
            elif 'G.norm' in k:
                new_k = new_k.replace('G.norm', 'ln2')
            elif 'F.attn' in k:
                new_k = new_k.replace('F.attn', 'attn')
            elif 'G.mlp.fc1' in k:
                new_k = new_k.replace('G.mlp.fc1', 'ffn.layers.0.0')
            elif 'G.mlp.fc2' in k:
                new_k = new_k.replace('G.mlp.fc2', 'ffn.layers.1')
        elif k.startswith('norm'):
            new_k = k.replace('norm', 'ln1')
        else:
            new_k = k

        if not new_k.startswith('head'):
            new_k = 'backbone.' + new_k
        new_ckpt[new_k] = new_v

    tmp_weight_dir = []
    tmp_bias_dir = []
    final_ckpt = OrderedDict()
    for k, v in list(new_ckpt.items()):
        if 'attn.q.weight' in k:
            tmp_weight_dir.append(v)
        elif 'attn.k.weight' in k:
            tmp_weight_dir.append(v)
        elif 'attn.v.weight' in k:
            tmp_weight_dir.append(v)
            new_k = k.replace('attn.v.weight', 'attn.qkv.weight')
            final_ckpt[new_k] = torch.cat(tmp_weight_dir, dim=0)
            tmp_weight_dir = []
        elif 'attn.q.bias' in k:
            tmp_bias_dir.append(v)
        elif 'attn.k.bias' in k:
            tmp_bias_dir.append(v)
        elif 'attn.v.bias' in k:
            tmp_bias_dir.append(v)
            new_k = k.replace('attn.v.bias', 'attn.qkv.bias')
            final_ckpt[new_k] = torch.cat(tmp_bias_dir, dim=0)
            tmp_bias_dir = []
        else:
            final_ckpt[k] = v

    return final_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained revvit'
        ' models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'model_state' in checkpoint:
        state_dict = checkpoint['model_state']
    else:
        state_dict = checkpoint

    weight = convert_revvit(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
```
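The second pass in `convert_revvit` concatenates the separate q/k/v projection weights into one `qkv` matrix along the output dimension; a tiny check (with illustrative shapes) that the concatenated projection computes the same three outputs:

```python
import torch

torch.manual_seed(0)
q_w, k_w, v_w = (torch.randn(4, 4) for _ in range(3))
x = torch.randn(4)
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)  # (12, 4), as in convert_revvit
out = qkv_w @ x
assert torch.allclose(out[:4], q_w @ x)    # first block is the q projection
assert torch.allclose(out[8:], v_w @ x)    # last block is the v projection
```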
tools/model_converters/shufflenetv2_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict

import torch


def convert_conv1(model_key, model_weight, state_dict, converted_names):
    if model_key.find('conv1.0') >= 0:
        new_key = model_key.replace('conv1.0', 'backbone.conv1.conv')
    else:
        new_key = model_key.replace('conv1.1', 'backbone.conv1.bn')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {new_key}')


def convert_conv5(model_key, model_weight, state_dict, converted_names):
    if model_key.find('conv5.0') >= 0:
        new_key = model_key.replace('conv5.0', 'backbone.layers.3.conv')
    else:
        new_key = model_key.replace('conv5.1', 'backbone.layers.3.bn')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {new_key}')


def convert_head(model_key, model_weight, state_dict, converted_names):
    new_key = model_key.replace('fc', 'head.fc')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {new_key}')


def convert_block(model_key, model_weight, state_dict, converted_names):
    split_keys = model_key.split('.')
    layer, block, branch = split_keys[:3]
    layer_id = int(layer[-1]) - 2
    new_key = model_key.replace(layer, f'backbone.layers.{layer_id}')

    if branch == 'branch1':
        if new_key.find('branch1.0') >= 0:
            new_key = new_key.replace('branch1.0', 'branch1.0.conv')
        elif new_key.find('branch1.1') >= 0:
            new_key = new_key.replace('branch1.1', 'branch1.0.bn')
        elif new_key.find('branch1.2') >= 0:
            new_key = new_key.replace('branch1.2', 'branch1.1.conv')
        elif new_key.find('branch1.3') >= 0:
            new_key = new_key.replace('branch1.3', 'branch1.1.bn')
    elif branch == 'branch2':
        if new_key.find('branch2.0') >= 0:
            new_key = new_key.replace('branch2.0', 'branch2.0.conv')
        elif new_key.find('branch2.1') >= 0:
            new_key = new_key.replace('branch2.1', 'branch2.0.bn')
        elif new_key.find('branch2.3') >= 0:
            new_key = new_key.replace('branch2.3', 'branch2.1.conv')
        elif new_key.find('branch2.4') >= 0:
            new_key = new_key.replace('branch2.4', 'branch2.1.bn')
        elif new_key.find('branch2.5') >= 0:
            new_key = new_key.replace('branch2.5', 'branch2.2.conv')
        elif new_key.find('branch2.6') >= 0:
            new_key = new_key.replace('branch2.6', 'branch2.2.bn')
        else:
            raise ValueError(f'Unsupported conversion of key {model_key}')
    else:
        raise ValueError(f'Unsupported conversion of key {model_key}')
    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)


def convert(src, dst):
    """Convert keys in torchvision pretrained ShuffleNetV2 models to
    mmpretrain style."""
    # load pytorch model
    blobs = torch.load(src, map_location='cpu')
    # convert to pytorch style
    state_dict = OrderedDict()
    converted_names = set()

    for key, weight in blobs.items():
        if 'conv1' in key:
            convert_conv1(key, weight, state_dict, converted_names)
        elif 'fc' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_block(key, weight, state_dict, converted_names)
        elif 'conv5' in key:
            convert_conv5(key, weight, state_dict, converted_names)

    # check if all layers are converted
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # save checkpoint
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst)


def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src torchvision model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    convert(args.src, args.dst)


if __name__ == '__main__':
    main()
```
tools/model_converters/tinyvit_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path

import torch


def convert_weights(weight):
    """Weight Converter.

    Converts the weights from timm to mmpretrain

    Args:
        weight (dict): weight dict from timm

    Returns:
        Converted weight dict for mmpretrain
    """
    result = dict()
    result['meta'] = dict()
    temp = dict()
    mapping = {
        'c.weight': 'conv2d.weight',
        'bn.weight': 'bn2d.weight',
        'bn.bias': 'bn2d.bias',
        'bn.running_mean': 'bn2d.running_mean',
        'bn.running_var': 'bn2d.running_var',
        'bn.num_batches_tracked': 'bn2d.num_batches_tracked',
        'layers': 'stages',
        'norm_head': 'norm3',
    }
    weight = weight['model']
    for k, v in weight.items():
        # keyword mapping
        for mk, mv in mapping.items():
            if mk in k:
                k = k.replace(mk, mv)
        if k.startswith('head.'):
            temp['head.fc.' + k[5:]] = v
        else:
            temp['backbone.' + k] = v
    result['state_dict'] = temp
    return result


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src timm model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)
    original_model = torch.load(args.src, map_location='cpu')
    converted_model = convert_weights(original_model)
    torch.save(converted_model, args.dst)
```
tools/model_converters/torchvision_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from pathlib import Path

import torch


def convert_resnet(src_dict, dst_dict):
    """convert resnet checkpoints from torchvision."""
    for key, value in src_dict.items():
        if not key.startswith('fc'):
            dst_dict['backbone.' + key] = value
        else:
            dst_dict['head.' + key] = value


# model name to convert function
CONVERT_F_DICT = {
    'resnet': convert_resnet,
}


def convert(src: str, dst: str, convert_f: callable):
    print('Converting...')
    blobs = torch.load(src, map_location='cpu')
    converted_state_dict = OrderedDict()
    # convert key in weight
    convert_f(blobs, converted_state_dict)
    torch.save(converted_state_dict, dst)
    print('Done!')


def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src torchvision model path')
    parser.add_argument('dst', help='save path')
    parser.add_argument(
        'model', type=str, help='The model type whose keys need converting.')
    args = parser.parse_args()

    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)

    # this tool only supports models in CONVERT_F_DICT
    support_models = list(CONVERT_F_DICT.keys())
    if args.model not in CONVERT_F_DICT:
        print(f'Converting "{args.model}" is not supported yet.')
        print(f'This tool only supports {", ".join(support_models)}.')
        print('If you have done the converting job, PR is welcome!')
        exit(1)

    convert_f = CONVERT_F_DICT[args.model]
    convert(args.src, args.dst, convert_f)


if __name__ == '__main__':
    main()
```
tools/model_converters/twins2mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmcv
import torch
from mmcv.runner import CheckpointLoader


def convert_twins(args, ckpt):
    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('head'):
            new_k = k.replace('head.', 'head.fc.')
            new_ckpt[new_k] = new_v
            continue
        elif k.startswith('patch_embeds'):
            if 'proj.' in k:
                new_k = k.replace('proj.', 'projection.')
            else:
                new_k = k
        elif k.startswith('blocks'):
            k = k.replace('blocks', 'stages')
            # Union
            if 'mlp.fc1' in k:
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in k:
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
            else:
                new_k = k
            new_k = new_k.replace('blocks.', 'layers.')
        elif k.startswith('pos_block'):
            new_k = k.replace('pos_block', 'position_encodings')
            if 'proj.0.' in new_k:
                new_k = new_k.replace('proj.0.', 'proj.')
        elif k.startswith('norm'):
            new_k = k.replace('norm', 'norm_after_stage3')
        else:
            new_k = k
        new_k = 'backbone.' + new_k
        new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in timm pretrained twins models to '
        'MMPretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'state_dict' in checkpoint:
        # timm checkpoint
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert_twins(args, state_dict)
    mmcv.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
```
tools/model_converters/van2mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_van(ckpt):
    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('head'):
            new_k = k.replace('head.', 'head.fc.')
            new_ckpt[new_k] = new_v
            continue
        elif k.startswith('patch_embed'):
            if 'proj.' in k:
                new_k = k.replace('proj.', 'projection.')
            else:
                new_k = k
        elif k.startswith('block'):
            new_k = k.replace('block', 'blocks')
            if 'attn.spatial_gating_unit' in new_k:
                new_k = new_k.replace('conv0', 'DW_conv')
                new_k = new_k.replace('conv_spatial', 'DW_D_conv')
            if 'dwconv.dwconv' in new_k:
                new_k = new_k.replace('dwconv.dwconv', 'dwconv')
        else:
            new_k = k
        if not new_k.startswith('head'):
            new_k = 'backbone.' + new_k
        new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained van '
        'models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert_van(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
```
tools/model_converters/vgg_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
from collections import OrderedDict

import torch


def get_layer_maps(layer_num, with_bn):
    layer_maps = {'conv': {}, 'bn': {}}
    if with_bn:
        if layer_num == 11:
            layer_idxs = [0, 4, 8, 11, 15, 18, 22, 25]
        elif layer_num == 13:
            layer_idxs = [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]
        elif layer_num == 16:
            layer_idxs = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]
        elif layer_num == 19:
            layer_idxs = [
                0, 3, 7, 10, 14, 17, 20, 23, 27, 30, 33, 36, 40, 43, 46, 49
            ]
        else:
            raise ValueError(f'Invalid number of layers: {layer_num}')
        for i, layer_idx in enumerate(layer_idxs):
            if i == 0:
                new_layer_idx = layer_idx
            else:
                new_layer_idx += int((layer_idx - layer_idxs[i - 1]) / 2)
            layer_maps['conv'][layer_idx] = new_layer_idx
            layer_maps['bn'][layer_idx + 1] = new_layer_idx
    else:
        if layer_num == 11:
            layer_idxs = [0, 3, 6, 8, 11, 13, 16, 18]
            new_layer_idxs = [0, 2, 4, 5, 7, 8, 10, 11]
        elif layer_num == 13:
            layer_idxs = [0, 2, 5, 7, 10, 12, 15, 17, 20, 22]
            new_layer_idxs = [0, 1, 3, 4, 6, 7, 9, 10, 12, 13]
        elif layer_num == 16:
            layer_idxs = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
            new_layer_idxs = [0, 1, 3, 4, 6, 7, 8, 10, 11, 12, 14, 15, 16]
        elif layer_num == 19:
            layer_idxs = [
                0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34
            ]
            new_layer_idxs = [
                0, 1, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19
            ]
        else:
            raise ValueError(f'Invalid number of layers: {layer_num}')

        layer_maps['conv'] = {
            layer_idx: new_layer_idx
            for layer_idx, new_layer_idx in zip(layer_idxs, new_layer_idxs)
        }

    return layer_maps


def convert(src, dst, layer_num, with_bn=False):
    """Convert keys in torchvision pretrained VGG models to mmpretrain
    style."""
    # load pytorch model
    assert os.path.isfile(src), f'no checkpoint found at {src}'
    blobs = torch.load(src, map_location='cpu')

    # convert to pytorch style
    state_dict = OrderedDict()
    layer_maps = get_layer_maps(layer_num, with_bn)

    prefix = 'backbone'
    delimiter = '.'
    for key, weight in blobs.items():
        if 'features' in key:
            module, layer_idx, weight_type = key.split(delimiter)
            new_key = delimiter.join([prefix, key])
            layer_idx = int(layer_idx)
            for layer_key, maps in layer_maps.items():
                if layer_idx in maps:
                    new_layer_idx = maps[layer_idx]
                    new_key = delimiter.join([
                        prefix, 'features',
                        str(new_layer_idx), layer_key, weight_type
                    ])
            state_dict[new_key] = weight
            print(f'Convert {key} to {new_key}')
        elif 'classifier' in key:
            new_key = delimiter.join([prefix, key])
            state_dict[new_key] = weight
            print(f'Convert {key} to {new_key}')
        else:
            state_dict[key] = weight

    # save checkpoint
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst)


def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src torchvision model path')
    parser.add_argument('dst', help='save path')
    parser.add_argument(
        '--bn', action='store_true', help='whether original vgg has BN')
    parser.add_argument(
        '--layer-num',
        type=int,
        choices=[11, 13, 16, 19],
        default=11,
        help='number of VGG layers')
    args = parser.parse_args()
    convert(args.src, args.dst, layer_num=args.layer_num, with_bn=args.bn)


if __name__ == '__main__':
    main()
```
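For reference, the index maps produced by `get_layer_maps` for VGG-11 with BN look like this (values computed by the function above, run inside this script): torchvision keeps conv and BN at separate indices in the flat `features` list, while mmpretrain places each conv/BN pair under a single index with `conv`/`bn` sub-names.

```python
maps = get_layer_maps(11, with_bn=True)
print(maps['conv'])  # {0: 0, 4: 2, 8: 4, 11: 5, 15: 7, 18: 8, 22: 10, 25: 11}
print(maps['bn'])    # {1: 0, 5: 2, 9: 4, 12: 5, 16: 7, 19: 8, 23: 10, 26: 11}
```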
tools/model_converters/vig_to_mmpretrain.py  0 → 100644

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import re
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_vig(ckpt):
    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        new_key = k
        new_value = v
        if 'pos_embed' in new_key:
            new_key = new_key.replace('pos_embed', 'backbone.pos_embed')
        elif 'stem' in new_key:
            new_key = new_key.replace('stem.convs', 'backbone.stem')
        elif 'backbone' in new_key:
            new_key = new_key.replace('backbone', 'backbone.blocks')
        elif 'prediction.0' in new_key:
            new_key = new_key.replace('prediction.0', 'head.fc1')
            new_value = v.squeeze(-1).squeeze(-1)
        elif 'prediction.1' in new_key:
            new_key = new_key.replace('prediction.1', 'head.bn')
        elif 'prediction.4' in new_key:
            new_key = new_key.replace('prediction.4', 'head.fc2')
            new_value = v.squeeze(-1).squeeze(-1)
        new_ckpt[new_key] = new_value

    return new_ckpt


def convert_pvig(ckpt):
    new_ckpt = OrderedDict()
    stage_idx = 0
    stage_blocks = 0

    for k, v in ckpt.items():
        new_key: str = k
        new_value = v
        if 'pos_embed' in new_key:
            new_key = new_key.replace('pos_embed', 'backbone.pos_embed')
        elif 'stem' in new_key:
            new_key = new_key.replace('stem.convs', 'backbone.stem')
        elif re.match(r'^backbone\.\d+\.conv', new_key) is not None:
            if new_key.endswith('0.weight'):
                stage_idx += 1
                stage_blocks = int(new_key.split('.')[1])
            other = new_key.split('.', maxsplit=3)[-1]
            new_key = f'backbone.stages.{stage_idx}.0.' + other
        elif 'backbone' in new_key:
            block_idx = int(new_key.split('.')[1]) - stage_blocks
            other = new_key.split('.', maxsplit=2)[-1]
            new_key = f'backbone.stages.{stage_idx}.{block_idx}.' + other
        elif 'prediction.0' in new_key:
            new_key = new_key.replace('prediction.0', 'head.fc1')
            new_value = v.squeeze(-1).squeeze(-1)
        elif 'prediction.1' in new_key:
            new_key = new_key.replace('prediction.1', 'head.bn')
        elif 'prediction.4' in new_key:
            new_key = new_key.replace('prediction.4', 'head.fc2')
            new_value = v.squeeze(-1).squeeze(-1)
        new_ckpt[new_key] = new_value

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained vig models to '
        'mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    if 'backbone.2.conv.0.weight' in state_dict:
        weight = convert_pvig(state_dict)
    else:
        weight = convert_vig(state_dict)

    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
```