[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)
* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* * fix double tree_util
Showing
Please register or sign in to comment
* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* * fix double tree_util