Unverified Commit 40618ec2 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix TF_MASKED_LM_SAMPLE (#16698)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 1471857f
...@@ -723,9 +723,10 @@ TF_MASKED_LM_SAMPLE = r""" ...@@ -723,9 +723,10 @@ TF_MASKED_LM_SAMPLE = r"""
>>> logits = model(**inputs).logits >>> logits = model(**inputs).logits
>>> # retrieve index of {mask} >>> # retrieve index of {mask}
>>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1] >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
>>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
>>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1) >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
>>> tokenizer.decode(predicted_token_id) >>> tokenizer.decode(predicted_token_id)
{expected_output} {expected_output}
``` ```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment