Unverified Commit 6c9010ef authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

Update README.md

parent fde90187
......@@ -780,7 +780,8 @@ def cross_entropy(logits, labels):
# define a function which will run the forward pass return loss
def compute_loss(params, input_ids, labels):
logits = model(input_ids, params=params, train=True)
loss = cross_entropy(logits, onehot(labels)).mean()
num_classes = logits.shape[-1]
loss = cross_entropy(logits, onehot(labels, num_classes)).mean()
return loss
```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment