"examples/flax/vscode:/vscode.git/clone" did not exist on "f497f564bb76697edab09184a252fc1b1a326d1e"
Unverified Commit 32525428 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix doctest CI (#21166)



* fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 8ad06b7c
......@@ -256,7 +256,7 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze() > 0.5]
>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
......@@ -264,7 +264,9 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
... )
>>> labels = torch.nn.functional.one_hot(torch.tensor(predicted_class_ids), num_classes=num_labels).to(torch.float)
>>> labels = torch.sum(
... torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
... ).to(torch.float)
>>> loss = model(**inputs, labels=labels).loss
```
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment