"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "410b61ad7e8f69113a86d0003190e3c392c7c39a"
Unverified Commit 40618ec2 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix TF_MASKED_LM_SAMPLE (#16698)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 1471857f
......@@ -723,9 +723,10 @@ TF_MASKED_LM_SAMPLE = r"""
>>> logits = model(**inputs).logits
>>> # retrieve index of {mask}
>>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
>>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
>>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
>>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
>>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
>>> tokenizer.decode(predicted_token_id)
{expected_output}
```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment