Unverified Commit 7d6285a9 authored by Patrick von Platen, committed by GitHub

[Wav2Vec2] Flax - Adapt wav2vec2 script (#12520)

* fix_torch_device_generate_test

* remove @

* adapt flax pretrain script
parent 4605b2b8
@@ -64,6 +64,12 @@ class ModelArguments:
     gumbel_temperature_decay: Optional[float] = field(
         default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
     )
+    dtype: Optional[str] = field(
+        default="float32",
+        metadata={
+            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+        },
+    )
 
 
 @flax.struct.dataclass
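
Note: the new dtype field is consumed further down in this diff, where getattr(jnp, model_args.dtype) turns the string into a JAX dtype. A minimal sketch of that lookup (the helper name is illustrative, not part of the script):

    import jax.numpy as jnp

    def resolve_dtype(name: str):
        # Mirror getattr(jnp, model_args.dtype): map the CLI string to a jnp dtype.
        if name not in ("float32", "float16", "bfloat16"):
            raise ValueError(f"unsupported dtype: {name}")
        return getattr(jnp, name)

    assert resolve_dtype("bfloat16") is jnp.bfloat16

Passing bfloat16 is what makes the pretraining script TPU-friendly; float32 stays the safe default.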
@@ -197,7 +203,7 @@ def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
     logger.setLevel(logging_level)
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)
 
     train_metrics = get_metrics(train_metrics)
@@ -206,6 +212,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
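
Note: the old write_metric wrote train and eval scalars in a single call per epoch; splitting it into write_train_metric and write_eval_metric lets the script log training scalars every logging_steps and eval scalars once per epoch. A rough sketch of what the eval writer boils down to, assuming flax's TensorBoard wrapper and made-up metric values:

    from flax.metrics.tensorboard import SummaryWriter

    summary_writer = SummaryWriter(log_dir="./runs")  # log dir is illustrative

    # write_eval_metric: one scalar per aggregated eval metric, tagged eval_*.
    eval_metrics = {"loss": 2.10, "codevector_perplexity": 310.0}  # example values
    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, 100)  # step=100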
@@ -342,9 +350,7 @@ def main():
             "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
         )
 
-    model = FlaxWav2Vec2ForPreTraining(
-        config, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-    )
+    model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
 
     data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
         model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
@@ -501,11 +507,11 @@ def main():
     state = jax_utils.replicate(state)
 
     train_time = 0
-    train_metrics = []
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
+        train_metrics = []
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
@@ -516,7 +522,7 @@ def main():
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
 
         # Gather the indexes for creating the batch and do a training step
-        for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
             samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples)
             model_inputs = shard(model_inputs.data)
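
Note: shard here is flax.training.common_utils.shard, which reshapes every array in the collated batch from (batch, ...) to (local_devices, batch // local_devices, ...) so the pmap-ed train step can consume it. A small sketch with a dummy batch (shapes are illustrative):

    import numpy as np
    from flax.training.common_utils import shard

    batch = {"input_values": np.zeros((16, 4000), dtype=np.float32)}
    sharded = shard(batch)
    # With 8 local devices this prints (8, 2, 4000); with 1 device, (1, 16, 4000).
    print(sharded["input_values"].shape)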
@@ -527,11 +533,20 @@ def main():
             )
             train_metrics.append(train_metric)
 
-        train_time += time.time() - train_start
+            cur_step = epoch * (num_train_samples // train_batch_size) + step
 
-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
-        )
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = jax_utils.unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+                )
+
+                train_metrics = []
 
         # ======================== Evaluating ==============================
         num_eval_samples = len(vectorized_datasets["validation"])
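
Note: cur_step converts the per-epoch loop counter into a global step, so TensorBoard x-axes keep increasing across epochs, and the accumulated train_metrics list is flushed every logging_steps. A worked example with hypothetical sizes:

    num_train_samples = 1_000  # hypothetical dataset size
    train_batch_size = 8       # hypothetical global batch size
    steps_per_epoch = num_train_samples // train_batch_size  # 125

    epoch, step = 2, 10        # step 10 of the third epoch (0-indexed)
    cur_step = epoch * steps_per_epoch + step
    assert cur_step == 260

    logging_steps = 20         # hypothetical TrainingArguments.logging_steps
    assert cur_step % logging_steps == 0 and cur_step > 0  # this step logs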
@@ -560,7 +575,7 @@ def main():
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0: