Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
934e21cd
Unverified
Commit
934e21cd
authored
May 11, 2022
by
Suraj Patil
Committed by
GitHub
May 11, 2022
Browse files
add shift_tokens_right in FlaxMT5 (#17188)
parent
47412c7d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
0 deletions
+15
-0
src/transformers/models/mt5/modeling_flax_mt5.py
src/transformers/models/mt5/modeling_flax_mt5.py
+15
-0
No files found.
src/transformers/models/mt5/modeling_flax_mt5.py
View file @
934e21cd
...
...
@@ -14,6 +14,8 @@
# limitations under the License.
""" Flax mT5 model."""
import
numpy
as
np
from
...utils
import
logging
from
..t5.modeling_flax_t5
import
FlaxT5ForConditionalGeneration
,
FlaxT5Model
from
.configuration_mt5
import
MT5Config
...
...
@@ -25,6 +27,19 @@ _CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC
=
"T5Tokenizer"
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def
shift_tokens_right
(
input_ids
:
np
.
array
,
pad_token_id
:
int
,
decoder_start_token_id
:
int
)
->
np
.
ndarray
:
"""
Shift input ids one token to the right.
"""
shifted_input_ids
=
np
.
zeros_like
(
input_ids
)
shifted_input_ids
[:,
1
:]
=
input_ids
[:,
:
-
1
]
shifted_input_ids
[:,
0
]
=
decoder_start_token_id
shifted_input_ids
=
np
.
where
(
shifted_input_ids
==
-
100
,
pad_token_id
,
shifted_input_ids
)
return
shifted_input_ids
class
FlaxMT5Model
(
FlaxT5Model
):
r
"""
This class overrides [`FlaxT5Model`]. Please check the superclass for the appropriate documentation alongside usage
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment