Unverified Commit 8618bf15 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #1736 from huggingface/fix-tf-xlnet

Fix TFXLNet
parents 2fa8737c dfb61caf
...@@ -552,7 +552,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -552,7 +552,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \ assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \
"or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one." "or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one."
if input_mask is None and attention_mask is not None: if input_mask is None and attention_mask is not None:
input_mask = 1.0 - attention_mask input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float)
if input_mask is not None and perm_mask is not None: if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask data_mask = input_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None: elif input_mask is not None and perm_mask is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment