"docs/source/vscode:/vscode.git/clone" did not exist on "4124a09f8b3349f338917ad3282ca952bd15ec3a"
Unverified Commit 91fb62d0 authored by Yeb Havinga's avatar Yeb Havinga Committed by GitHub
Browse files

Speedup training by using numpy instead of jnp for batch shuffling (#15963)



Speedup training by using numpy instead of jnp for batch shuffling
Co-authored-by: default avatarYeb Havinga <y.t.havinga@mgrid.net>
parent ea07064a
...@@ -810,7 +810,7 @@ def main(): ...@@ -810,7 +810,7 @@ def main():
# Generate an epoch by shuffling sampling indices from the train dataset # Generate an epoch by shuffling sampling indices from the train dataset
num_train_samples = len(tokenized_datasets["train"]) num_train_samples = len(tokenized_datasets["train"])
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples)) train_samples_idx = np.random.permutation(np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
# Gather the indexes for creating the batch and do a training step # Gather the indexes for creating the batch and do a training step
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment