Unverified Commit dc991805 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Fix doc example (#16448)



* Fix doc

* Make fixup
Co-authored-by: default avatarNiels Rogge <nielsrogge@nielss-mbp.home>
parent febe42b5
......@@ -269,9 +269,10 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
```python
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
>>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
>>> model = {model_class}.from_pretrained(
... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
... )
>>> num_labels = len(model.config.id2label)
>>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
... torch.float
... )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment