Make data shuffling in `run_clm_flax.py` respect global seed (#13410)
* use jax and jnp instead of numpy in data_loader * return batches as np.ndarray
Showing
Please register or sign in to comment
* use jax and jnp instead of numpy in data_loader * return batches as np.ndarray