Unverified Commit 91fb62d0 authored by Yeb Havinga's avatar Yeb Havinga Committed by GitHub

Speedup training by using numpy instead of jnp for batch shuffling (#15963)

Co-authored-by: Yeb Havinga <y.t.havinga@mgrid.net>
parent ea07064a
@@ -810,7 +810,7 @@ def main():
         # Generate an epoch by shuffling sampling indices from the train dataset
         num_train_samples = len(tokenized_datasets["train"])
-        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
+        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
         # Gather the indexes for creating the batch and do a training step
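The change replaces the JAX PRNG-based permutation with NumPy's `np.random.permutation`, which runs on the host and avoids building and dispatching a device operation (plus a device-to-host transfer of the index array) every epoch. A minimal sketch of the shuffled-batch generation, assuming a `generate_batch_splits` helper that drops the incomplete final batch and splits the indices into batch-sized chunks (this helper is a hypothetical reimplementation, not the script's exact code):

```python
import numpy as np

def generate_batch_splits(samples_idx, batch_size):
    # Hypothetical helper: drop the trailing incomplete batch,
    # then split the shuffled indices into equal batch-sized chunks.
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
    return np.split(samples_idx, sections_split)

# Illustrative sizes; the real script derives these from the dataset.
num_train_samples = 10
train_batch_size = 4

# Host-side shuffle with NumPy: no JAX PRNG key or device op needed.
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
```

The shuffled indices are then used to gather each batch from the tokenized dataset; only the batch tensors themselves need to reach the device, so the permutation is pure host-side bookkeeping.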