Unverified Commit 7f3dab39 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: serializable hubert (#20966)

* serializable hubert
parent e5dcceb8
......@@ -221,13 +221,17 @@ def _compute_mask_indices(
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
tf.debugging.assert_less(
mask_length,
sequence_length,
message=(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
f" `sequence_length`: {sequence_length}`"
),
)
# compute number of masked spans in batch
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32)
......
......@@ -262,13 +262,17 @@ def _compute_mask_indices(
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
tf.debugging.assert_less(
mask_length,
sequence_length,
message=(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
f" `sequence_length`: {sequence_length}`"
),
)
# compute number of masked spans in batch
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment