chenpangpang / transformers · Commits

Commit 7f07c356 (unverified), authored Dec 08, 2023 by Yoach Lacombe, committed via GitHub on Dec 08, 2023.

Fix CLAP converting script (#27153)

* update converting script
* make style

Parent: b31905d1
Showing 1 changed file with 26 additions and 16 deletions.

src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py (+26, −16)
```diff
@@ -16,8 +16,7 @@
 import argparse
 import re
 
-import torch
-from CLAP import create_model
+from laion_clap import CLAP_Module
 
 from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
```
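The import change above moves the script's dependency from a local checkout of the standalone CLAP repository (`from CLAP import create_model`) to the pip-installable `laion_clap` package, and drops the direct `torch` import since device selection no longer happens in the script. A minimal pre-flight check along these lines could confirm the new dependency is available; the install command is an assumption (pip resolves `laion_clap` and `laion-clap` to the same project) and the snippet is not part of the conversion script itself:

```python
# Illustrative dependency check only; not part of the conversion script.
try:
    from laion_clap import CLAP_Module  # class used by the updated init_clap
except ImportError as err:
    raise SystemExit(
        "laion_clap is required for the CLAP conversion script; try `pip install laion_clap`."
    ) from err

print("laion_clap is importable:", CLAP_Module.__name__)
```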
```diff
@@ -38,17 +37,25 @@ KEYS_TO_MODIFY_MAPPING = {
 processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
 
 
-def init_clap(checkpoint_path, enable_fusion=False):
-    model, model_cfg = create_model(
-        "HTSAT-tiny",
-        "roberta",
-        checkpoint_path,
-        precision="fp32",
-        device="cuda:0" if torch.cuda.is_available() else "cpu",
+def init_clap(checkpoint_path, model_type, enable_fusion=False):
+    model = CLAP_Module(
+        amodel=model_type,
         enable_fusion=enable_fusion,
         fusion_type="aff_2d" if enable_fusion else None,
     )
-    return model, model_cfg
+    model.load_ckpt(checkpoint_path)
+    return model
+
+
+def get_config_from_original(clap_model):
+    audio_config = {
+        "patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
+        "depths": clap_model.model.audio_branch.depths,
+        "hidden_size": clap_model.model.audio_projection[0].in_features,
+    }
+    text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
+
+    return ClapConfig(audio_config=audio_config, text_config=text_config)
 
 
 def rename_state_dict(state_dict):
```
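Taken together, the rewritten `init_clap` and the new `get_config_from_original` mean the script no longer assumes a fixed architecture: `CLAP_Module` builds whichever audio model `model_type` names, `load_ckpt` restores the LAION weights, and the `ClapConfig` is read off the loaded model's layer shapes rather than hard-coded. A small sketch of how the two helpers would be combined; the checkpoint path and model type below are placeholders, not values from this commit:

```python
# Hypothetical combination of the two helpers shown in the hunk above.
# "/path/to/clap_checkpoint.pt" and "HTSAT-tiny" are placeholder values.
clap_model = init_clap("/path/to/clap_checkpoint.pt", model_type="HTSAT-tiny", enable_fusion=False)

# Derive a transformers ClapConfig from the original model's dimensions
# instead of relying on library defaults.
config = get_config_from_original(clap_model)
print(config.audio_config.hidden_size, config.text_config.hidden_size)
```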
```diff
@@ -94,14 +101,14 @@ def rename_state_dict(state_dict):
     return model_state_dict
 
 
-def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, enable_fusion=False):
-    clap_model, clap_model_cfg = init_clap(checkpoint_path, enable_fusion=enable_fusion)
+def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
+    clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)
 
     clap_model.eval()
-    state_dict = clap_model.state_dict()
+    state_dict = clap_model.model.state_dict()
     state_dict = rename_state_dict(state_dict)
 
-    transformers_config = ClapConfig()
+    transformers_config = get_config_from_original(clap_model)
     transformers_config.audio_config.enable_fusion = enable_fusion
     model = ClapModel(transformers_config)
```
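One detail worth noting in this hunk: since `init_clap` now returns the `CLAP_Module` wrapper rather than the bare network, the weights to convert live one level down on `clap_model.model`, which is why `state_dict()` is taken from the inner model. Calling `state_dict()` on the wrapper would, per normal `nn.Module` behaviour, give every key an extra `model.` prefix that the renaming logic was not written for. A short illustration, with a placeholder checkpoint path:

```python
# Illustrates why the diff reads the state dict from clap_model.model.
clap_model = init_clap("/path/to/clap_checkpoint.pt", model_type="HTSAT-tiny")  # placeholder path
clap_model.eval()

wrapped_keys = list(clap_model.state_dict())      # e.g. "model.audio_branch...." (extra "model." prefix)
inner_keys = list(clap_model.model.state_dict())  # e.g. "audio_branch....", the form rename_state_dict was written against
print(wrapped_keys[0], inner_keys[0])
```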
```diff
@@ -118,6 +125,9 @@ if __name__ == "__main__":
     parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
     parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
     parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
+    parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Whether to enable fusion or not")
     args = parser.parse_args()
 
-    convert_clap_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.enable_fusion)
+    convert_clap_checkpoint(
+        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
+    )
```
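With the extra `--model_type` flag, both the CLI and the positional order of `convert_clap_checkpoint` change: `model_type` now sits between `config_path` and `enable_fusion`. A sketch of an equivalent direct call with placeholder paths; the `--pytorch_dump_folder_path` argument is referenced by the call but its `add_argument` line sits outside the hunks shown here:

```python
# Roughly what the CLI entry point does after parsing, e.g. for an invocation like:
#   python convert_clap_original_pytorch_to_hf.py \
#       --checkpoint_path /path/to/clap_checkpoint.pt \
#       --pytorch_dump_folder_path ./clap-hf \
#       --model_type HTSAT-tiny \
#       --enable_fusion
# All paths are placeholders.
convert_clap_checkpoint(
    checkpoint_path="/path/to/clap_checkpoint.pt",
    pytorch_dump_folder_path="./clap-hf",
    config_path=None,
    model_type="HTSAT-tiny",
    enable_fusion=True,
)
```

As the previous hunk shows, the `ClapConfig` is now derived from the loaded checkpoint via `get_config_from_original` rather than instantiated with library defaults.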