Unverified Commit 7f3dab39 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: serializable hubert (#20966)

* serializable hubert
parent e5dcceb8
...@@ -221,13 +221,17 @@ def _compute_mask_indices( ...@@ -221,13 +221,17 @@ def _compute_mask_indices(
if mask_length < 1: if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.") raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length: tf.debugging.assert_less(
raise ValueError( mask_length,
sequence_length,
message=(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and" f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
f" `sequence_length`: {sequence_length}`" f" `sequence_length`: {sequence_length}`"
),
) )
# compute number of masked spans in batch # compute number of masked spans in batch
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,)) num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks) num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32) num_masked_spans = tf.cast(num_masked_spans, tf.int32)
......
...@@ -262,13 +262,17 @@ def _compute_mask_indices( ...@@ -262,13 +262,17 @@ def _compute_mask_indices(
if mask_length < 1: if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.") raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length: tf.debugging.assert_less(
raise ValueError( mask_length,
sequence_length,
message=(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and" f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
f" `sequence_length`: {sequence_length}`" f" `sequence_length`: {sequence_length}`"
),
) )
# compute number of masked spans in batch # compute number of masked spans in batch
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,)) num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks) num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32) num_masked_spans = tf.cast(num_masked_spans, tf.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment