Commit 2a606f99 authored by Benjamin Minixhofer, committed by GitHub

Make data shuffling in `run_clm_flax.py` respect global seed (#13410)

* use jax and jnp instead of numpy in data_loader

* return batches as np.ndarray
parent 546a91ab
@@ -253,9 +253,9 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
     steps_per_epoch = len(dataset) // batch_size
     if shuffle:
-        batch_idx = np.random.permutation(len(dataset))
+        batch_idx = jax.random.permutation(rng, len(dataset))
     else:
-        batch_idx = np.arange(len(dataset))
+        batch_idx = jnp.arange(len(dataset))
     batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
...
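The effect of the change can be illustrated with a small standalone sketch. It is not the code from `run_clm_flax.py`: the helper name `batch_indices`, the example sizes, and the seed value are assumptions made for illustration, and the `np.asarray` conversion is only a guess at what "return batches as np.ndarray" might look like for the index array. The point is that `jax.random.permutation` draws its randomness from the explicit `rng` key, so the same key always gives the same shuffle, whereas `np.random.permutation` depends on NumPy's separate global state.

```python
import jax
import jax.numpy as jnp
import numpy as np


def batch_indices(rng, num_examples, batch_size, shuffle=False):
    """Hypothetical helper: return a (steps_per_epoch, batch_size) array of indices."""
    steps_per_epoch = num_examples // batch_size
    if shuffle:
        # The permutation is fully determined by the key passed in.
        batch_idx = jax.random.permutation(rng, num_examples)
    else:
        batch_idx = jnp.arange(num_examples)
    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    # Assumption: hand the result back as a NumPy array for downstream indexing.
    return np.asarray(batch_idx)


seed = 0  # stands in for the script's global seed (assumed value)
rng = jax.random.PRNGKey(seed)
rng, epoch_rng = jax.random.split(rng)  # derive a fresh key for this epoch

first = batch_indices(epoch_rng, num_examples=10, batch_size=3, shuffle=True)
second = batch_indices(epoch_rng, num_examples=10, batch_size=3, shuffle=True)
ordered = batch_indices(epoch_rng, num_examples=10, batch_size=3, shuffle=False)

# Same key, same shuffle: the run is reproducible from the seed alone.
print(np.array_equal(first, second))
```

Splitting the key once per epoch, as sketched here, would give a different but still reproducible order for each epoch.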