Commit 05a95e16 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Add in a comment about legacy input masks for XLNet.

PiperOrigin-RevId: 339267061
parent 4437d7b4
......@@ -795,6 +795,8 @@ class PretrainingXLNetModel(tf.keras.Model):
masked_tokens = features["input_q"]
seg_ids = features["seg_id"]
if self._use_legacy_mask:
# Legacy input mask assumes `real` values are 0 and `padding`
# values are 1.
perm_mask = 1 - features["perm_mask"]
else:
perm_mask = features["perm_mask"]
......@@ -885,6 +887,8 @@ class ClassificationXLNetModel(tf.keras.Model):
input_ids = features["input_ids"]
segment_ids = features["segment_ids"]
if self._use_legacy_mask:
# Legacy input mask assumes `real` values are 0 and `padding`
# values are 1.
input_mask = 1 - features["input_mask"]
else:
input_mask = features["input_mask"]
......@@ -1130,6 +1134,8 @@ class QAXLNetModel(tf.keras.Model):
input_ids = features["input_ids"]
segment_ids = features["segment_ids"]
if self._use_legacy_mask:
# Legacy input mask assumes `real` values are 0 and `padding`
# values are 1.
input_mask = 1 - features["input_mask"]
else:
input_mask = features["input_mask"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment