Unverified commit 10b76987, authored by Patrick von Platen and committed by GitHub
Browse files

[FlaxT5 Example] fix flax t5 example pretraining (#15835)

parent 01485cee
......@@ -368,7 +368,9 @@ class FlaxDataCollatorForT5MLM:
batch_size = input_ids.shape[0]
input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
# input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
# masked tokens coming after sentinel tokens and should be removed
input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
input_ids = np.concatenate(
[input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment