Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
32525428
Unverified
Commit
32525428
authored
Jan 18, 2023
by
Yih-Dar
Committed by
GitHub
Jan 18, 2023
Browse files
Fix doctest CI (#21166)
* fix Co-authored-by:
ydshieh
<
ydshieh@users.noreply.github.com
>
parent
8ad06b7c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
src/transformers/utils/doc.py
src/transformers/utils/doc.py
+4
-2
No files found.
src/transformers/utils/doc.py
View file @
32525428
...
...
@@ -256,7 +256,7 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze() > 0.5]
>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(
dim=0
) > 0.5]
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
...
...
@@ -264,7 +264,9 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
... )
>>> labels = torch.nn.functional.one_hot(torch.tensor(predicted_class_ids), num_classes=num_labels).to(torch.float)
>>> labels = torch.sum(
... torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
... ).to(torch.float)
>>> loss = model(**inputs, labels=labels).loss
```
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment