Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
443fdaf2
Unverified
Commit
443fdaf2
authored
Jan 05, 2022
by
Patrick von Platen
Committed by
GitHub
Jan 05, 2022
Browse files
[SpeechEncoderDecoder] Fix from pretrained (#15043)
parent
ae929dcb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
...speech_encoder_decoder/modeling_speech_encoder_decoder.py
+4
-4
No files found.
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
View file @
443fdaf2
...
@@ -380,7 +380,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
...
@@ -380,7 +380,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
)
)
if
"config"
not
in
kwargs_encoder
:
if
"config"
not
in
kwargs_encoder
:
encoder_config
=
AutoConfig
.
from_pretrained
(
encoder_pretrained_model_name_or_path
,
**
kwargs_encoder
)
encoder_config
=
AutoConfig
.
from_pretrained
(
encoder_pretrained_model_name_or_path
)
if
encoder_config
.
is_decoder
is
True
or
encoder_config
.
add_cross_attention
is
True
:
if
encoder_config
.
is_decoder
is
True
or
encoder_config
.
add_cross_attention
is
True
:
logger
.
info
(
logger
.
info
(
f
"Initializing
{
encoder_pretrained_model_name_or_path
}
as a encoder model "
f
"Initializing
{
encoder_pretrained_model_name_or_path
}
as a encoder model "
...
@@ -391,7 +391,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
...
@@ -391,7 +391,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
kwargs_encoder
[
"config"
]
=
encoder_config
kwargs_encoder
[
"config"
]
=
encoder_config
encoder
=
AutoModel
.
from_pretrained
(
encoder_pretrained_model_name_or_path
,
*
model_args
)
encoder
=
AutoModel
.
from_pretrained
(
encoder_pretrained_model_name_or_path
,
*
model_args
,
**
kwargs_encoder
)
decoder
=
kwargs_decoder
.
pop
(
"model"
,
None
)
decoder
=
kwargs_decoder
.
pop
(
"model"
,
None
)
if
decoder
is
None
:
if
decoder
is
None
:
...
@@ -402,7 +402,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
...
@@ -402,7 +402,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
)
)
if
"config"
not
in
kwargs_decoder
:
if
"config"
not
in
kwargs_decoder
:
decoder_config
=
AutoConfig
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs_decoder
)
decoder_config
=
AutoConfig
.
from_pretrained
(
decoder_pretrained_model_name_or_path
)
if
decoder_config
.
is_decoder
is
False
or
decoder_config
.
add_cross_attention
is
False
:
if
decoder_config
.
is_decoder
is
False
or
decoder_config
.
add_cross_attention
is
False
:
logger
.
info
(
logger
.
info
(
f
"Initializing
{
decoder_pretrained_model_name_or_path
}
as a decoder model. "
f
"Initializing
{
decoder_pretrained_model_name_or_path
}
as a decoder model. "
...
@@ -424,7 +424,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
...
@@ -424,7 +424,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
)
decoder
=
AutoModelForCausalLM
.
from_pretrained
(
decoder_pretrained_model_name_or_path
)
decoder
=
AutoModelForCausalLM
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs_decoder
)
# instantiate config with corresponding kwargs
# instantiate config with corresponding kwargs
config
=
SpeechEncoderDecoderConfig
.
from_encoder_decoder_configs
(
encoder
.
config
,
decoder
.
config
,
**
kwargs
)
config
=
SpeechEncoderDecoderConfig
.
from_encoder_decoder_configs
(
encoder
.
config
,
decoder
.
config
,
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment