"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "0678a8603d2477d029b12b8db0a83e2084db053e"
Unverified Commit 5e686757 authored by Ahmed Elnaggar's avatar Ahmed Elnaggar Committed by GitHub
Browse files

Fix t5 shard on TPU Pods (#16527)



* Fix t5 shard on TPU Pods

The current script doesn't work properly on a TPU pod because the global batch is not divided correctly per host.
This pull request fixes this issue by dividing the global batch to each host before it is shared on each host.

* fix style
Co-authored-by: default avatarahmed-elnaggar <ahmed.elnaggar@allianz.com>
parent 2831826b
...@@ -746,6 +746,9 @@ def main(): ...@@ -746,6 +746,9 @@ def main():
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
num_of_hosts = jax.process_count()
current_host_idx = jax.process_index()
# Create learning rate schedule # Create learning rate schedule
warmup_fn = optax.linear_schedule( warmup_fn = optax.linear_schedule(
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
...@@ -861,8 +864,13 @@ def main(): ...@@ -861,8 +864,13 @@ def main():
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx] samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples) model_inputs = data_collator(samples)
local_host_model_inputs = {
key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx]
for key, value in model_inputs.data.items()
}
# Model forward # Model forward
model_inputs = shard(model_inputs.data) model_inputs = shard(local_host_model_inputs)
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs) state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
train_metrics.append(train_metric) train_metrics.append(train_metric)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment