Unverified Commit 4475f1dc authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Fix BigBird (#13380)

* finish

* finish
parent ecd53971
......@@ -2029,6 +2029,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
if token_type_ids is None:
token_type_ids = (~logits_mask).astype("i4")
logits_mask = jnp.expand_dims(logits_mask, axis=2)
logits_mask = jax.ops.index_update(logits_mask, jax.ops.index[:, 0], False)
# init input tensors if not passed
if token_type_ids is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment