"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "36ee128375e41bbed0bf2aa4aa4862af8dc9e908"
Unverified commit bb4ac2b5, authored by Patrick von Platen, committed by GitHub

[Flax] Correct flax training scripts (#12514)

* fix_torch_device_generate_test

* remove @

* add logging steps

* correct training scripts

* correct readme

* correct
parent ea556750
@@ -137,10 +137,10 @@ Next we can run the example script to pretrain the model:
     --learning_rate="3e-4" \
     --warmup_steps="1000" \
     --overwrite_output_dir \
-    --pad_to_max_length \
     --num_train_epochs="18" \
     --adam_beta1="0.9" \
     --adam_beta2="0.98" \
+    --logging_steps="500" \
     --push_to_hub
 ```
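For context, `--logging_steps` is consumed through `HfArgumentParser`; a minimal sketch of how such a flag reaches the training loop, assuming the example reuses the stock `TrainingArguments` dataclass (the values here are illustrative, not from the README):

```python
from transformers import HfArgumentParser, TrainingArguments

# Parse CLI flags such as --logging_steps="500" into a TrainingArguments
# dataclass; output_dir is the only required argument, the rest use defaults.
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "./norwegian-gpt2", "--logging_steps", "500"]
)
print(training_args.logging_steps)  # 500
```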
@@ -233,6 +233,7 @@ Next we can run the example script to pretrain the model:
     --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
     --overwrite_output_dir \
     --num_train_epochs="20" \
+    --logging_steps="500" \
     --push_to_hub
 ```
@@ -368,6 +369,7 @@ Next we can run the example script to pretrain the model:
     --warmup_steps="5000" \
     --overwrite_output_dir \
     --num_train_epochs="10" \
+    --logging_steps="500" \
     --push_to_hub
 ```
...
@@ -57,22 +57,6 @@ from transformers.testing_utils import CaptureLogger

 logger = logging.getLogger(__name__)

-# Cache the result
-has_tensorboard = is_tensorboard_available()
-
-if has_tensorboard:
-    try:
-        from flax.metrics.tensorboard import SummaryWriter
-    except ImportError as ie:
-        has_tensorboard = False
-        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
-
-else:
-    print(
-        "Unable to display metrics through TensorBoard because the package is not installed: "
-        "Please run pip install tensorboard to enable."
-    )
-
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
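The hunk above removes an import-time TensorBoard probe that `print`ed its complaints; the commit re-creates it inside `main()` with `logger.warning` (see the later hunks). As a standalone illustration of that deferred-import pattern, a minimal sketch (the helper name is ours, not the script's):

```python
from pathlib import Path

from transformers import is_tensorboard_available

def maybe_create_summary_writer(output_dir):
    """Return a flax SummaryWriter, or None when TensorBoard is unusable."""
    if not is_tensorboard_available():
        return None
    try:
        # Deferred optional import: the script now imports cleanly even
        # when tensorboard is not installed.
        from flax.metrics.tensorboard import SummaryWriter
    except ImportError:
        return None
    return SummaryWriter(log_dir=Path(output_dir))
```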
@@ -214,7 +198,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         yield batch

-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)

     train_metrics = get_metrics(train_metrics)
@@ -223,6 +207,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)

+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
@@ -450,8 +436,22 @@ def main():
         eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

     # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
-        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )

     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
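The `jax.process_index() == 0` guard keeps exactly one `SummaryWriter` per pod in multi-host training; a toy check:

```python
import jax

# On a single machine this prints "process 0 of 1"; in a multi-host pod
# only host 0 passes the guard and owns the SummaryWriter.
if jax.process_index() == 0:
    print(f"process 0 of {jax.process_count()}: metrics logging enabled")
```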
@@ -554,6 +554,7 @@ def main():
     logger.info(f"  Total optimization steps = {total_train_steps}")

     train_time = 0
+    train_metrics = []
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     for epoch in epochs:
         # ======================== Training ================================
@@ -561,24 +562,30 @@
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
-        train_metrics = []

         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
         steps_per_epoch = len(train_dataset) // train_batch_size
+
         # train
-        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
             train_metrics.append(train_metric)

-        train_time += time.time() - train_start
+            cur_step = epoch * (len(train_dataset) // train_batch_size) + step

-        train_metric = unreplicate(train_metric)
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-        )
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+                )
+
+                train_metrics = []

         # ======================== Evaluating ==============================
         eval_metrics = []
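The new control flow buffers per-step metrics and flushes them every `logging_steps` global steps; a self-contained toy of that cadence (all sizes illustrative):

```python
# Stand-in for the loop above: cur_step is the global step across epochs,
# and the metric buffer is flushed and reset at every logging boundary.
num_epochs, steps_per_epoch, logging_steps = 2, 1200, 500
train_metrics = []
for epoch in range(num_epochs):
    for step in range(steps_per_epoch):
        train_metrics.append({"loss": 0.0})  # placeholder for p_train_step output
        cur_step = epoch * steps_per_epoch + step
        if cur_step % logging_steps == 0 and cur_step > 0:
            print(f"step {cur_step}: writing {len(train_metrics)} buffered metrics")
            train_metrics = []
```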
@@ -608,7 +615,7 @@ def main():
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)

         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
...
@@ -56,22 +56,6 @@ from transformers import (
 )

-# Cache the result
-has_tensorboard = is_tensorboard_available()
-
-if has_tensorboard:
-    try:
-        from flax.metrics.tensorboard import SummaryWriter
-    except ImportError as ie:
-        has_tensorboard = False
-        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
-
-else:
-    print(
-        "Unable to display metrics through TensorBoard because the package is not installed: "
-        "Please run pip install tensorboard to enable."
-    )
-
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -269,7 +253,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
     return batch_idx

-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)

     train_metrics = get_metrics(train_metrics)
@@ -278,6 +262,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)

+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
@@ -315,10 +301,6 @@ if __name__ == "__main__":
     # Log on each process the small summary:
     logger = logging.getLogger(__name__)
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )

     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")
@@ -471,8 +453,22 @@ if __name__ == "__main__":
     )

     # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
-        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )

     # Data collator
     # This one will take care of randomly masking the tokens.
@@ -601,7 +597,7 @@ if __name__ == "__main__":
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

         # Gather the indexes for creating the batch and do a training step
-        for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)
@@ -610,11 +606,20 @@ if __name__ == "__main__":
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)

-        train_time += time.time() - train_start
+            cur_step = epoch * num_train_samples + step

-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-        )
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = jax_utils.unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                )
+
+                train_metrics = []

         # ======================== Evaluating ==============================
         num_eval_samples = len(tokenized_datasets["validation"])
@@ -645,7 +650,7 @@ if __name__ == "__main__":
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)

         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
...
@@ -382,7 +382,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
     return batch_idx

-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)

     train_metrics = get_metrics(train_metrics)
@@ -391,6 +391,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)

+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
@@ -711,7 +713,7 @@ if __name__ == "__main__":
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

         # Gather the indexes for creating the batch and do a training step
-        for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples)
@@ -720,11 +722,20 @@ if __name__ == "__main__":
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)

-        train_time += time.time() - train_start
+            cur_step = epoch * num_train_samples + step

-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
-        )
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = jax_utils.unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+                )
+
+                train_metrics = []

         # ======================== Evaluating ==============================
         num_eval_samples = len(tokenized_datasets["validation"])
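The T5 script keeps `.mean()` on the logged values because metric leaves coming out of `pmap` carry one entry per local device; `.mean()` collapses them, and is a harmless no-op once the metric has been unreplicated. A toy illustration:

```python
import jax
import jax.numpy as jnp

# A device-replicated scalar metric has shape (local_device_count,);
# .mean() collapses it to a single value for printing.
replicated_loss = jnp.full((jax.local_device_count(),), 2.5)
print(float(replicated_loss.mean()))  # 2.5
```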
@@ -753,7 +764,7 @@ if __name__ == "__main__":
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)

         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
...