chenpangpang / transformers / Commits / 443fdaf2

Unverified commit, authored Jan 05, 2022 by Patrick von Platen, committed by GitHub on Jan 05, 2022.
[SpeechEncoderDecoder] Fix from pretrained (#15043)
parent ae929dcb
Showing 1 changed file with 4 additions and 4 deletions.
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
@@ -380,7 +380,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
                 )
 
             if "config" not in kwargs_encoder:
-                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path, **kwargs_encoder)
+                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
                 if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                     logger.info(
                         f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
@@ -391,7 +391,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
 
                 kwargs_encoder["config"] = encoder_config
 
-            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args)
+            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
 
         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
@@ -402,7 +402,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
                 )
 
             if "config" not in kwargs_decoder:
-                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
                 if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                     logger.info(
                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
@@ -424,7 +424,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
                     "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                 )
 
-            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path)
+            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
 
         # instantiate config with corresponding kwargs
         config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
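For orientation, a minimal usage sketch (not part of the commit). Earlier in this same classmethod, keyword arguments prefixed with encoder_ and decoder_ are split off into kwargs_encoder and kwargs_decoder; with this change those dictionaries are forwarded to the model-loading calls instead of the AutoConfig.from_pretrained calls. The checkpoint names and the revision kwarg below are illustrative placeholders only.

from transformers import SpeechEncoderDecoderModel

# Illustrative checkpoints only; any speech encoder / causal-LM decoder pair is handled the same way.
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-base-960h",  # encoder_pretrained_model_name_or_path
    "bert-base-uncased",            # decoder_pretrained_model_name_or_path
    encoder_revision="main",        # split into kwargs_encoder = {"revision": "main"}
    decoder_revision="main",        # split into kwargs_decoder = {"revision": "main"}
)

# With this commit, the prefixed kwargs reach the weight-loading calls:
#     AutoModel.from_pretrained(..., *model_args, **kwargs_encoder)
#     AutoModelForCausalLM.from_pretrained(..., **kwargs_decoder)
# whereas previously they were consumed only by the AutoConfig.from_pretrained calls shown in the diff above.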