Unverified Commit 32525428 authored by Yih-Dar, committed by GitHub

Fix doctest CI (#21166)



* fix
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 8ad06b7c
@@ -256,7 +256,7 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
 >>> with torch.no_grad():
 ...     logits = model(**inputs).logits
->>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze() > 0.5]
+>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]
 >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
 >>> num_labels = len(model.config.id2label)
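In this first hunk, `dim=0` restricts the squeeze to the batch dimension, so the thresholded mask stays one-dimensional over the label axis. A minimal sketch, using hypothetical logits and `num_labels == 1` (the case where a bare `.squeeze()` would also drop the class dimension):

```python
import torch

# Hypothetical single-example batch with num_labels == 1, shape (1, 1).
logits = torch.tensor([[0.7]])

# .squeeze() would remove every size-1 dimension (here both of them);
# .squeeze(dim=0) removes only the batch dimension, so the mask keeps
# shape (num_labels,) and can index torch.arange(0, logits.shape[-1]).
mask = torch.sigmoid(logits).squeeze(dim=0) > 0.5
predicted_class_ids = torch.arange(0, logits.shape[-1])[mask]
print(predicted_class_ids)  # tensor([0])
```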
@@ -264,7 +264,9 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
 ...     "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
 ... )
->>> labels = torch.nn.functional.one_hot(torch.tensor(predicted_class_ids), num_classes=num_labels).to(torch.float)
+>>> labels = torch.sum(
+...     torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
+... ).to(torch.float)
 >>> loss = model(**inputs, labels=labels).loss
 ```
 """
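The second hunk builds the multi-label target as a single multi-hot row rather than passing the predicted indices straight through `one_hot`. A minimal sketch with hypothetical values for `predicted_class_ids` and `num_labels`:

```python
import torch

# Hypothetical prediction: classes 1 and 3 out of num_labels = 4.
predicted_class_ids = torch.tensor([1, 3])
num_labels = 4

# one_hot over the added batch axis gives shape (1, 2, num_labels);
# summing over dim=1 merges the per-class rows into one multi-hot row
# of shape (1, num_labels), matching the (batch, num_labels) layout
# expected for problem_type="multi_label_classification".
labels = torch.sum(
    torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
).to(torch.float)

print(labels)  # tensor([[0., 1., 0., 1.]])
```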