Unverified Commit 6c9010ef authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

Update README.md

parent fde90187
...@@ -780,7 +780,8 @@ def cross_entropy(logits, labels): ...@@ -780,7 +780,8 @@ def cross_entropy(logits, labels):
# define a function which will run the forward pass return loss # define a function which will run the forward pass return loss
def compute_loss(params, input_ids, labels): def compute_loss(params, input_ids, labels):
logits = model(input_ids, params=params, train=True) logits = model(input_ids, params=params, train=True)
loss = cross_entropy(logits, onehot(labels)).mean() num_classes = logits.shape[-1]
loss = cross_entropy(logits, onehot(labels, num_classes)).mean()
return loss return loss
``` ```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment