Unverified Commit d61d7476 authored by amitportnoy's avatar amitportnoy Committed by GitHub
Browse files

Update trainer.mdx class_weights example (#23787)

class_weights tensor should follow model's device
parent 4d9b76a8
......@@ -61,7 +61,7 @@ class CustomTrainer(Trainer):
outputs = model(**inputs)
logits = outputs.get("logits")
# compute custom loss (suppose one has 3 labels with different weights)
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss
```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment