chenpangpang / transformers · Commit ee4fb326 (Unverified)
Authored Nov 14, 2023 by Yoach Lacombe; committed by GitHub on Nov 14, 2023
Parent commit: e107ae36

Fix M4T weights tying (#27395)

fix seamless m4t weights tying
Showing 1 changed file with 23 additions and 0 deletions (+23, -0)

src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -2785,6 +2785,12 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel):
         self.text_decoder.embed_tokens = value
         self.shared = value
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.lm_head, self.shared)
+
     @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING)
     def forward(
         self,
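Each added override goes through the base `PreTrainedModel._tie_or_clone_weights` helper, which makes the first module's `weight` point at the second's (here the `shared` embedding table). Below is a minimal, self-contained sketch of that effect, not the transformers implementation; the function name `tie_weights_sketch` and the toy sizes are illustrative only:

import torch
import torch.nn as nn

# Minimal sketch (assumption: it mirrors only the Parameter-sharing effect of
# _tie_or_clone_weights, not its TorchScript/clone branch).
def tie_weights_sketch(shared: nn.Embedding, embed_tokens: nn.Embedding, lm_head: nn.Linear) -> None:
    # Point every tied module's weight at the shared embedding's Parameter.
    embed_tokens.weight = shared.weight
    lm_head.weight = shared.weight

shared = nn.Embedding(32, 8)        # stand-in for self.shared
embed_tokens = nn.Embedding(32, 8)  # stand-in for text_decoder.embed_tokens
lm_head = nn.Linear(8, 32, bias=False)

tie_weights_sketch(shared, embed_tokens, lm_head)

assert embed_tokens.weight is shared.weight  # same Parameter object
assert lm_head.weight is shared.weight
# Any update to the shared table is now visible through lm_head as well.
with torch.no_grad():
    shared.weight.add_(1.0)
assert torch.equal(lm_head.weight, shared.weight)

Without such an override, the encoder embeddings, decoder embeddings, and `lm_head` would keep separate Parameters, which is the tying issue the commit title refers to. The remaining hunks apply the same pattern to the other SeamlessM4T model classes.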
@@ -3077,6 +3083,11 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
     def set_input_embeddings(self, value):
         self.text_decoder.embed_tokens = value
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.lm_head, self.shared)
+
     @add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -3384,6 +3395,12 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel):
         self.text_decoder.embed_tokens = value
         self.shared = value
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.lm_head, self.shared)
+
     @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -3740,6 +3757,11 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel):
     def set_input_embeddings(self, value):
         self.text_decoder.embed_tokens = value
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.lm_head, self.shared)
+
     @add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -4135,6 +4157,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel):
         if self.config.tie_word_embeddings:
             self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
             self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.lm_head, self.shared)
 
     @add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING)
     def forward(
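A quick way to check the effect of this change is to load one of the affected classes and compare tensor storages (a hypothetical verification snippet, not part of the commit; `facebook/hf-seamless-m4t-medium` is just an example checkpoint id, and the check assumes `from_pretrained` runs the model's usual weight-tying step):

from transformers import SeamlessM4TForTextToText

# Hypothetical check: substitute whichever SeamlessM4T checkpoint you actually use.
model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium")

if getattr(model.config, "tie_word_embeddings", True):
    # With the new _tie_weights override, these modules should share one storage.
    assert model.lm_head.weight.data_ptr() == model.shared.weight.data_ptr()
    assert model.text_decoder.embed_tokens.weight.data_ptr() == model.shared.weight.data_ptr()
    assert model.text_encoder.embed_tokens.weight.data_ptr() == model.shared.weight.data_ptr()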