Ensure values passed to jax.numpy functions are arrays rather than lists.
Why? This will soon be a requirement in JAX; see https://github.com/google/jax/issues/7737 PiperOrigin-RevId: 394105499 Change-Id: I6f76e3f47b7906616c1f988be58d29c828414060
Showing
Please register or sign in to comment