Unverified Commit 1ab147d6 authored by Nicholas Vadivelu, committed by GitHub

Remove redundant `nn.log_softmax` in `run_flax_glue.py` (#11920)

* Remove redundant `nn.log_softmax` in `run_flax_glue.py`

`optax.softmax_cross_entropy` expects unnormalized logits and already applies `nn.log_softmax` internally, so I believe the extra call is not needed here. Since `nn.log_softmax` is idempotent, the redundant call should not have made any mathematical difference (a quick numerical check follows below).

* Remove unused 'flax.linen' import
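
The claim above can be checked numerically. The sketch below is editorial and not part of the patch; the toy `logits` and `labels` values are made up for illustration, while `optax.softmax_cross_entropy`, `jax.nn.log_softmax`, and `jax.nn.one_hot` are real library functions.

```python
import jax
import jax.numpy as jnp
import optax

# Toy inputs, invented only to exercise the claim.
logits = jnp.array([[2.0, -1.0, 0.5], [0.1, 0.3, -2.0]])
labels = jax.nn.one_hot(jnp.array([0, 1]), num_classes=3)

# optax.softmax_cross_entropy applies log_softmax to its first argument
# internally, so passing raw logits ...
loss_from_logits = optax.softmax_cross_entropy(logits, labels).mean()

# ... and passing already-normalized log-probabilities yields the same loss:
# log_softmax(x) = x - logsumexp(x) only shifts each row by a constant, and
# softmax is shift-invariant, so log_softmax(log_softmax(x)) == log_softmax(x).
loss_from_log_probs = optax.softmax_cross_entropy(jax.nn.log_softmax(logits), labels).mean()

assert jnp.allclose(loss_from_logits, loss_from_log_probs)
```

In other words, the removed `nn.log_softmax` call changed nothing numerically; dropping it just avoids redundant work and the now-unused `flax.linen` import.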
parent fb60c309
@@ -29,7 +29,6 @@ import jax
 import jax.numpy as jnp
 import optax
 import transformers
-from flax import linen as nn
 from flax import struct, traverse_util
 from flax.jax_utils import replicate, unreplicate
 from flax.metrics import tensorboard
@@ -202,7 +201,6 @@ def create_train_state(
     else:  # Classification.

         def cross_entropy_loss(logits, labels):
-            logits = nn.log_softmax(logits)
             xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
             return jnp.mean(xentropy)