chenpangpang / transformers

Commit c448c01f (unverified), authored May 03, 2021 by Patrick von Platen, committed via GitHub on May 03, 2021
Parent commit: 623281aa

[Wav2Vec2] Fix convert (#11562)

* push
* small change
* correct other typo
Changes: 1 changed file with 5 additions and 4 deletions (+5 -4)
src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py

@@ -178,9 +178,11 @@ def convert_wav2vec2_checkpoint(
     if dict_path:
         target_dict = Dictionary.load(dict_path)
-        config.bos_token_id = target_dict.bos_index
+        # important change bos & pad token id since CTC symbol is <pad> and
+        # not <s> as in fairseq
+        config.bos_token_id = target_dict.pad_index
+        config.pad_token_id = target_dict.bos_index
         config.eos_token_id = target_dict.eos_index
-        config.pad_token_id = target_dict.pad_index
         config.vocab_size = len(target_dict.symbols)
         vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")

         if not os.path.isdir(pytorch_dump_folder_path):
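The swap in the hunk above can be sketched in isolation. The snippet below uses a stand-in object instead of a real fairseq `Dictionary`; the index values (`<s>`=0, `<pad>`=1, `</s>`=2) follow fairseq's usual special-symbol order and are an assumption for illustration, not taken from the diff.

```python
# Stand-in for fairseq's Dictionary: only the attributes the
# conversion script reads (assumed index layout, see lead-in).
from types import SimpleNamespace

target_dict = SimpleNamespace(
    bos_index=0, pad_index=1, eos_index=2,
    symbols=["<s>", "<pad>", "</s>", "<unk>"],
)
config = SimpleNamespace()

# The fix swaps bos and pad: for CTC the blank symbol is <pad>,
# not <s> as in fairseq, so the HF config must point bos_token_id
# at fairseq's pad slot and pad_token_id at fairseq's bos slot.
config.bos_token_id = target_dict.pad_index   # was: target_dict.bos_index
config.pad_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index
config.vocab_size = len(target_dict.symbols)

print(config.bos_token_id, config.pad_token_id, config.vocab_size)  # 1 0 4
```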
@@ -214,9 +216,8 @@ def convert_wav2vec2_checkpoint(
     hf_wav2vec = Wav2Vec2Model(config)

     if is_finetuned:
         model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
-            [checkpoint_path], arg_overrides={"data": dict_path}
+            [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
         )
-
     else:
         model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
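The second fix changes what is passed as the `"data"` override: the parent directory of `dict_path` rather than the dictionary file itself, since fairseq resolves the task dictionary relative to a data directory. The expression `"/".join(dict_path.split("/")[:-1])` is a hand-rolled way to strip the last path component; the check below (with a made-up example path) confirms it matches `os.path.dirname` for POSIX-style paths.

```python
import os

# Made-up dictionary path, purely for illustration.
dict_path = "/checkpoints/wav2vec2/dict.ltr.txt"

# The expression from the diff: drop the final path component.
data_dir = "/".join(dict_path.split("/")[:-1])

print(data_dir)  # /checkpoints/wav2vec2
assert data_dir == os.path.dirname(dict_path)
```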