"git@developer.sourcefind.cn:modelzoo/donut_pytorch.git" did not exist on "fa5e69b19522f475e66ee6f2f918c8f5c6194722"
Unverified Commit 40618ec2 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix TF_MASKED_LM_SAMPLE (#16698)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 1471857f
......@@ -723,9 +723,10 @@ TF_MASKED_LM_SAMPLE = r"""
>>> logits = model(**inputs).logits
>>> # retrieve index of {mask}
>>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
>>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
>>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
>>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
>>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
>>> tokenizer.decode(predicted_token_id)
{expected_output}
```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment