Unverified Commit 8244c5ad authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Correct shift labels for seq2seq models in Flax (#12720)

* fix_torch_device_generate_test

* remove @

* push

* fix marian

* fix

* up
parent 1a3deae8
...@@ -19,6 +19,8 @@ import random ...@@ -19,6 +19,8 @@ import random
from functools import partial from functools import partial
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import numpy as np
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -212,15 +214,15 @@ BART_DECODE_INPUTS_DOCSTRING = r""" ...@@ -212,15 +214,15 @@ BART_DECODE_INPUTS_DOCSTRING = r"""
""" """
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
""" """
Shift input ids one token to the right. Shift input ids one token to the right.
""" """
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) shifted_input_ids = np.zeros_like(input_ids)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id) shifted_input_ids[:, 1:] = input_ids[:, :-1]
# replace possible -100 values in labels by `pad_token_id` shifted_input_ids[:, 0] = decoder_start_token_id
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids return shifted_input_ids
......
...@@ -221,11 +221,11 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_ ...@@ -221,11 +221,11 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
""" """
Shift input ids one token to the right. Shift input ids one token to the right.
""" """
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) shifted_input_ids = np.zeros_like(input_ids)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id) shifted_input_ids[:, 1:] = input_ids[:, :-1]
# replace possible -100 values in labels by `pad_token_id` shifted_input_ids[:, 0] = decoder_start_token_id
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids return shifted_input_ids
......
...@@ -19,6 +19,8 @@ import random ...@@ -19,6 +19,8 @@ import random
from functools import partial from functools import partial
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import numpy as np
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -217,20 +219,19 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray ...@@ -217,20 +219,19 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray
Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
have a single `decoder_start_token_id` in contrast to other Bart-like models. have a single `decoder_start_token_id` in contrast to other Bart-like models.
""" """
prev_output_tokens = jnp.array(input_ids).clone() prev_output_tokens = np.array(input_ids).copy()
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids) prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids)
index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1) index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
decoder_start_tokens = jnp.array( decoder_start_tokens = np.array(
[prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)] [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32
).squeeze() ).squeeze()
# for loop basically does jax-compatible version of prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
for i in range(prev_output_tokens.shape[1], 0, -1): prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy()
prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., i), prev_output_tokens[:, i - 1]) prev_output_tokens[:, 0] = decoder_start_tokens
prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., 0), decoder_start_tokens)
return prev_output_tokens return prev_output_tokens
......
...@@ -47,15 +47,16 @@ _CONFIG_FOR_DOC = "T5Config" ...@@ -47,15 +47,16 @@ _CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer" _TOKENIZER_FOR_DOC = "T5Tokenizer"
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
""" """
Shift input ids one token to the right. Shift input ids one token to the right.
""" """
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) shifted_input_ids = np.zeros_like(input_ids)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id) shifted_input_ids[:, 1:] = input_ids[:, :-1]
# replace possible -100 values in labels by `pad_token_id` shifted_input_ids[:, 0] = decoder_start_token_id
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids return shifted_input_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment