Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
40618ec2
Unverified
Commit
40618ec2
authored
Apr 11, 2022
by
Yih-Dar
Committed by
GitHub
Apr 11, 2022
Browse files
Fix TF_MASKED_LM_SAMPLE (#16698)
Co-authored-by:
ydshieh
<
ydshieh@users.noreply.github.com
>
parent
1471857f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
2 deletions
+3
-2
src/transformers/utils/doc.py
src/transformers/utils/doc.py
+3
-2
No files found.
src/transformers/utils/doc.py
View file @
40618ec2
...
@@ -723,9 +723,10 @@ TF_MASKED_LM_SAMPLE = r"""
...
@@ -723,9 +723,10 @@ TF_MASKED_LM_SAMPLE = r"""
>>> logits = model(**inputs).logits
>>> logits = model(**inputs).logits
>>> # retrieve index of {mask}
>>> # retrieve index of {mask}
>>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
>>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
>>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
>>> predicted_token_id = tf.math.argmax(
logits[0, mask_token_index]
, axis=-1)
>>> predicted_token_id = tf.math.argmax(
selected_logits
, axis=-1)
>>> tokenizer.decode(predicted_token_id)
>>> tokenizer.decode(predicted_token_id)
{expected_output}
{expected_output}
```
```
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment