Unverified Commit 904e7ba7 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

Fix the JAX/Example in README.md. (#603)



Fix JAX/Exmaples in README.md
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
parent 3c04c417
......@@ -102,6 +102,7 @@ Flax
.. code-block:: python
import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
......@@ -130,7 +131,7 @@ Flax
# Initialize models.
variables = model.init(init_rng, inp)
other_variables, params = variables.pop('params')
other_variables, params = flax.core.pop(variables, 'params')
# Construct the forward and backward function
fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment