Unverified Commit 1e8140ca authored by Duong A. Nguyen, committed by GitHub

Fix RESOURCE_EXHAUSTED error when dealing with large datasets in Flax example scripts (#18069)

* Fix RESOURCE_EXHAUSTED error for large datasets on Flax example scripts

* using np.permutation for creating batch_idx

* train_samples_idx -> training_samples_idx

* fix type hints
parent ac98a88f
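The fix follows one pattern across all of the touched scripts: build the per-epoch index permutation with host-side NumPy instead of jax.numpy, so an array the size of the whole dataset is never allocated in accelerator memory. A minimal sketch of that before/after pattern (the dataset size below is a stand-in, not a value from the scripts):

```python
import numpy as np

num_train_samples = 1_000_000  # stand-in for len(tokenized_datasets["train"])

# Before: jax.random.permutation(input_rng, jnp.arange(num_train_samples))
# placed the full index array in device memory, which could raise
# RESOURCE_EXHAUSTED on TPU when the dataset is very large.

# After: keep the shuffled indices on the host with plain NumPy.
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
```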
@@ -326,7 +326,7 @@ class FlaxDataCollatorForLanguageModeling:
         return inputs, labels


-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
     num_samples = len(samples_idx)
     samples_to_remove = num_samples % batch_size
@@ -755,7 +755,8 @@ def main():
         # Generate an epoch by shuffling sampling indices from the train dataset
         num_train_samples = len(tokenized_datasets["train"])
-        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
+        # Avoid using jax.numpy here in case of TPU training
+        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

         # Gather the indexes for creating the batch and do a training step
@@ -787,7 +788,8 @@ def main():
             if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                 # ======================== Evaluating ==============================
                 num_eval_samples = len(tokenized_datasets["validation"])
-                eval_samples_idx = jnp.arange(num_eval_samples)
+                # Avoid using jax.numpy here in case of TPU training
+                eval_samples_idx = np.arange(num_eval_samples)
                 eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

                 eval_metrics = []
@@ -825,7 +827,8 @@ def main():
     # Eval after training
     if training_args.do_eval:
         num_eval_samples = len(tokenized_datasets["validation"])
-        eval_samples_idx = jnp.arange(num_eval_samples)
+        # Avoid using jax.numpy here in case of TPU training
+        eval_samples_idx = np.arange(num_eval_samples)
         eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

         eval_metrics = []
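For context, generate_batch_splits only slices the shuffled index array into fixed-size batches; the sketch below reconstructs it from the two body lines visible in the hunks above, and the remainder-dropping and reshape at the end are assumptions rather than the scripts' exact code:

```python
import numpy as np

def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size
    if samples_to_remove != 0:
        # Drop the tail so every batch has exactly batch_size indices.
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
    return samples_idx.reshape((sections_split, batch_size))
```

With the whole index pipeline kept in NumPy, the np.ndarray type hints in the new signature match what is actually passed in and returned.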

@@ -459,7 +459,7 @@ class FlaxDataCollatorForT5MLM:
         return is_noise[:orig_length]


-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
     num_samples = len(samples_idx)
     samples_to_remove = num_samples % batch_size
@@ -871,6 +871,7 @@ def main():
         # Generate an epoch by shuffling sampling indices from the train dataset
         num_train_samples = len(tokenized_datasets["train"])
+        # Avoid using jax.numpy here in case of TPU training
         train_samples_idx = np.random.permutation(np.arange(num_train_samples))
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
@@ -908,7 +909,8 @@ def main():
             if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                 # ======================== Evaluating ==============================
                 num_eval_samples = len(tokenized_datasets["validation"])
-                eval_samples_idx = jnp.arange(num_eval_samples)
+                # Avoid using jax.numpy here in case of TPU training
+                eval_samples_idx = np.arange(num_eval_samples)
                 eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

                 eval_metrics = []
@@ -944,7 +946,8 @@ def main():
     # Eval after training
     if training_args.do_eval:
         num_eval_samples = len(tokenized_datasets["validation"])
-        eval_samples_idx = jnp.arange(num_eval_samples)
+        # Avoid using jax.numpy here in case of TPU training
+        eval_samples_idx = np.arange(num_eval_samples)
         eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

         eval_metrics = []

@@ -264,7 +264,7 @@ class FlaxDataCollatorForLanguageModeling:
         return inputs, labels


-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
     num_samples = len(samples_idx)
     samples_to_remove = num_samples % batch_size
@@ -592,7 +592,8 @@ if __name__ == "__main__":
        # ======================== Evaluating ==============================
        if step % training_args.eval_steps == 0 and step > 0:
-            eval_samples_idx = jnp.arange(data_args.num_eval_samples)
+            # Avoid using jax.numpy here in case of TPU training
+            eval_samples_idx = np.arange(data_args.num_eval_samples)
            eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

            for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):

@@ -237,7 +237,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
        summary_writer.scalar(f"eval_{metric_name}", value, step)


-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
     num_samples = len(samples_idx)
     samples_to_remove = num_samples % batch_size
@@ -541,7 +541,8 @@ def main():
         # Generate an epoch by shuffling sampling indices from the train dataset
         num_train_samples = len(vectorized_datasets["train"])
-        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
+        # Avoid using jax.numpy here in case of TPU training
+        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

         # Gather the indexes for creating the batch and do a training step
@@ -574,7 +575,8 @@ def main():
             # ======================== Evaluating ==============================
             num_eval_samples = len(vectorized_datasets["validation"])
-            eval_samples_idx = jnp.arange(num_eval_samples)
+            # Avoid using jax.numpy here in case of TPU training
+            eval_samples_idx = np.arange(num_eval_samples)
             eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

             eval_metrics = []

@@ -433,7 +433,7 @@ def eval_step(params, batch):
     return compute_metrics(logits, targets, token_mask)


-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
     nb_samples = len(samples_idx)
     samples_to_remove = nb_samples % batch_size
@@ -639,7 +639,8 @@ if __name__ == "__main__":
        # Generate an epoch by shuffling sampling indices from the train dataset
        nb_training_samples = len(tokenized_datasets["train"])
-        training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
+        # Avoid using jax.numpy here in case of TPU training
+        training_samples_idx = np.random.permutation(np.arange(nb_training_samples))
        training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)

        # Gather the indexes for creating the batch and do a training step
@@ -658,7 +659,8 @@ if __name__ == "__main__":
            # ======================== Evaluating ==============================
            nb_eval_samples = len(tokenized_datasets["validation"])
-            eval_samples_idx = jnp.arange(nb_eval_samples)
+            # Avoid using jax.numpy here in case of TPU training
+            eval_samples_idx = np.arange(nb_eval_samples)
            eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

            eval_metrics = []
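Downstream of these changes, the batch indices stay on the host and are only used to gather samples from the tokenized dataset; only the collated batch ever reaches the devices. The loop below is a rough, self-contained sketch of that flow, with dataset, data_collator and train_step as hypothetical placeholders rather than the scripts' actual objects:

```python
import numpy as np

def run_epoch(dataset, data_collator, train_step, batch_size):
    # Host-side shuffle: no jax.numpy, so nothing is allocated on the accelerator.
    samples_idx = np.random.permutation(np.arange(len(dataset)))
    # Trim the tail and split into fixed-size batches of indices.
    samples_idx = samples_idx[: len(samples_idx) - len(samples_idx) % batch_size]
    for batch_idx in samples_idx.reshape((-1, batch_size)):
        samples = [dataset[int(i)] for i in batch_idx]
        model_inputs = data_collator(samples)  # NumPy arrays, still on the host
        train_step(model_inputs)  # only the collated batch crosses to the devices
```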