"examples/vscode:/vscode.git/clone" did not exist on "157c9011d87e52632113024c1dc5125426971556"
Unverified Commit 934e21cd authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

add shift_tokens_right in FlaxMT5 (#17188)

parent 47412c7d
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
""" Flax mT5 model.""" """ Flax mT5 model."""
import numpy as np
from ...utils import logging from ...utils import logging
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from .configuration_mt5 import MT5Config from .configuration_mt5 import MT5Config
...@@ -25,6 +27,19 @@ _CONFIG_FOR_DOC = "T5Config" ...@@ -25,6 +27,19 @@ _CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer" _TOKENIZER_FOR_DOC = "T5Tokenizer"
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
"""
Shift input ids one token to the right.
"""
shifted_input_ids = np.zeros_like(input_ids)
shifted_input_ids[:, 1:] = input_ids[:, :-1]
shifted_input_ids[:, 0] = decoder_start_token_id
shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids
class FlaxMT5Model(FlaxT5Model): class FlaxMT5Model(FlaxT5Model):
r""" r"""
This class overrides [`FlaxT5Model`]. Please check the superclass for the appropriate documentation alongside usage This class overrides [`FlaxT5Model`]. Please check the superclass for the appropriate documentation alongside usage
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment