Unverified Commit c5bd732a authored by Suraj Patil, committed by GitHub

Add Flax example tests (#14599)

* add test for glue

* add tests for clm

* fix clm test

* add summarization tests

* more tests

* fix few tests

* add test for t5 mlm

* fix t5 mlm test

* fix tests for multi device

* cleanup

* ci job

* fix metric file name

* make t5 more robust
parent 803a8cd1
...@@ -613,6 +613,69 @@ jobs:
- store_artifacts:
path: ~/transformers/reports
run_examples_flax:
working_directory: ~/transformers
docker:
- image: circleci/python:3.7
environment:
OMP_NUM_THREADS: 1
TRANSFORMERS_IS_CI: yes
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.4-flax_examples-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: sudo pip install .[flax,testing,sentencepiece]
- run: pip install -r examples/flax/_tests_requirements.txt
- save_cache:
key: v0.4-flax_examples-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: python utils/tests_fetcher.py --filters examples tests | tee test_preparation.txt
- store_artifacts:
path: ~/transformers/test_preparation.txt
- run: |
if [ -f test_list.txt ]; then
python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee tests_output.txt
fi
- store_artifacts:
path: ~/transformers/flax_examples_output.txt
- store_artifacts:
path: ~/transformers/reports
run_examples_flax_all:
working_directory: ~/transformers
docker:
- image: circleci/python:3.7
environment:
OMP_NUM_THREADS: 1
TRANSFORMERS_IS_CI: yes
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.4-flax_examples-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: sudo pip install .[flax,testing,sentencepiece]
- run: pip install -r examples/flax/_tests_requirements.txt
- save_cache:
key: v0.4-flax_examples-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: |
TRANSFORMERS_IS_CI=1 python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee examples_output.txt
- store_artifacts:
path: ~/transformers/flax_examples_output.txt
- store_artifacts:
path: ~/transformers/reports
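The two jobs above differ only in gating: run_examples_flax runs pytest only if utils/tests_fetcher.py has produced a test_list.txt (i.e. the changed files touch the examples or tests), while run_examples_flax_all always runs the whole examples/flax suite. A hedged Python sketch of that gating, mirroring the shell step shown above:

# Hedged sketch of the gating in run_examples_flax: pytest runs only when
# tests_fetcher.py has left a test_list.txt behind; run_examples_flax_all
# skips the check and always runs the suite.
import os
import subprocess

if os.path.exists("test_list.txt"):
    subprocess.run(
        [
            "python", "-m", "pytest", "-n", "8", "--dist=loadfile", "-s",
            "--make-reports=examples_flax", "./examples/flax/",
        ],
        check=True,
    )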
run_tests_hub:
working_directory: ~/transformers
docker:
...
datasets >= 1.1.3
pytest
conllu
nltk
rouge-score
seqeval
tensorboard
\ No newline at end of file
...@@ -21,6 +21,7 @@ https://huggingface.co/models?filter=causal-lm
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
import json
import logging
import math
import os
...@@ -672,6 +673,32 @@ def main():
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
# Eval after training
if training_args.do_eval:
eval_metrics = []
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
eval_steps = len(eval_dataset) // eval_batch_size
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
batch = shard(next(eval_loader))
metrics = p_eval_step(state.params, batch)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
except OverflowError:
eval_metrics["perplexity"] = float("inf")
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
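The perplexity written to eval_results.json is simply the exponential of the mean eval loss, guarded against overflow. A minimal sketch with a hypothetical loss value:

# Minimal sketch of the perplexity computation above (the loss value is hypothetical).
import math

eval_loss = 3.2  # mean token-level cross-entropy over the eval set
try:
    perplexity = math.exp(eval_loss)  # ~24.5
except OverflowError:
    perplexity = float("inf")
print({"eval_loss": eval_loss, "eval_perplexity": perplexity})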
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -20,7 +20,9 @@ text file or a dataset.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=masked-lm
"""
import json
import logging
import math
import os
import sys
import time
...@@ -271,7 +273,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
summary_writer.scalar(f"eval_{metric_name}", value, step)
-if __name__ == "__main__":
+def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
...@@ -700,3 +702,41 @@ if __name__ == "__main__":
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward
model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
try:
perplexity = math.exp(eval_metrics["loss"])
except OverflowError:
perplexity = float("inf")
eval_metrics["perplexity"] = perplexity
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
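Unlike the causal LM script, the MLM eval sums the per-batch metrics and divides by a summed normalizer (here the masked-token count), so batches that contribute different numbers of tokens are weighted correctly. A hedged sketch of that pattern with hypothetical per-batch values:

# Hedged sketch of the normalizer pattern above, with hypothetical numbers.
import jax
import jax.numpy as jnp

batch_metrics = [
    {"loss": jnp.array(6.0), "normalizer": jnp.array(2.0)},  # 2 masked tokens, summed loss 6
    {"loss": jnp.array(9.0), "normalizer": jnp.array(3.0)},  # 3 masked tokens, summed loss 9
]
summed = jax.tree_map(lambda *leaves: jnp.sum(jnp.stack(leaves)), *batch_metrics)
normalizer = summed.pop("normalizer")                         # 5.0
mean_metrics = jax.tree_map(lambda x: (x / normalizer).item(), summed)
print(mean_metrics)  # {'loss': 3.0}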
if __name__ == "__main__":
main()
...@@ -20,6 +20,7 @@ Here is the full list of checkpoints on the hub that can be pretrained by this s
https://huggingface.co/models?filter=t5
"""
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
import json
import logging
import os
import sys
...@@ -401,7 +402,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
summary_writer.scalar(f"eval_{metric_name}", value, step)
-if __name__ == "__main__":
+def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
...@@ -522,9 +523,7 @@ if __name__ == "__main__":
model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
)
elif model_args.model_name_or_path:
-config = T5Config.from_pretrained(
-model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
-)
+config = T5Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
...@@ -617,6 +616,7 @@ if __name__ == "__main__":
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
)
else:
config.vocab_size = len(tokenizer)
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
# Data collator
# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)
# Model forward
model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics)
# get eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()
...@@ -18,6 +18,7 @@ Fine-tuning the library models for question answering.
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
import json
import logging
import os
import random
...@@ -911,6 +912,58 @@ def main():
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
# endregion
# Eval after training
if training_args.do_eval:
eval_metrics = {}
all_start_logits = []
all_end_logits = []
eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
predictions = p_eval_step(state, batch)
start_logits = np.array([pred for pred in chain(*predictions[0])])
end_logits = np.array([pred for pred in chain(*predictions[1])])
all_start_logits.append(start_logits)
all_end_logits.append(end_logits)
# evaluate also on leftover examples (not divisible by batch_size)
num_leftover_samples = len(eval_dataset) % eval_batch_size
# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
# take leftover samples
batch = eval_dataset[-num_leftover_samples:]
batch = {k: np.array(v) for k, v in batch.items()}
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
predictions = eval_step(unreplicate(state), batch)
start_logits = np.array([pred for pred in predictions[0]])
end_logits = np.array([pred for pred in predictions[1]])
all_start_logits.append(start_logits)
all_end_logits.append(end_logits)
max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
# concatenate the numpy array
start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len)
end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len)
# delete the list of numpy arrays
del all_start_logits
del all_end_logits
outputs_numpy = (start_logits_concat, end_logits_concat)
prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
eval_metrics = compute_metrics(prediction)
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
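Eval batches can be padded to different sequence lengths, so the collected start/end logits are padded to a common maximum length before being concatenated and post-processed. A hedged, purely illustrative re-implementation of what a helper like create_and_fill_np_array does (the real helper is defined in the script and also takes the dataset):

# Hedged, illustrative re-implementation of the padding/concatenation step above.
import numpy as np

def pad_and_concat(batched_logits, pad_value=-10000.0):
    # pad every (batch, seq_len) logits array to the global max seq_len,
    # then concatenate along the example axis
    max_len = max(x.shape[1] for x in batched_logits)
    padded = [
        np.pad(x, ((0, 0), (0, max_len - x.shape[1])), constant_values=pad_value)
        for x in batched_logits
    ]
    return np.concatenate(padded, axis=0)

# e.g. two eval batches whose inputs were padded to different lengths
print(pad_and_concat([np.zeros((4, 120)), np.zeros((3, 128))]).shape)  # (7, 128)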
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -18,6 +18,7 @@ Fine-tuning the library models for summarization.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import json
import logging
import os
import sys
...@@ -816,6 +817,13 @@ def main():
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
logger.info(desc)
# save final metrics in json
if jax.process_index() == 0:
rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
path = os.path.join(training_args.output_dir, "test_results.json")
with open(path, "w") as f:
json.dump(rouge_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
import sys
from unittest.mock import patch
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow
SRC_DIRS = [
os.path.join(os.path.dirname(__file__), dirname)
for dirname in [
"text-classification",
"language-modeling",
"summarization",
"token-classification",
"question-answering",
]
]
sys.path.extend(SRC_DIRS)
if SRC_DIRS is not None:
import run_clm_flax
import run_flax_glue
import run_flax_ner
import run_mlm_flax
import run_qa
import run_summarization_flax
import run_t5_mlm_flax
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
def get_setup_file():
parser = argparse.ArgumentParser()
parser.add_argument("-f")
args = parser.parse_args()
return args.f
def get_results(output_dir, split="eval"):
results = {}
path = os.path.join(output_dir, f"{split}_results.json")
if os.path.exists(path):
with open(path, "r") as f:
results = json.load(f)
else:
raise ValueError(f"can't find {path}")
return results
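get_results is the bridge between the eval_results.json / test_results.json files written by the example scripts above and the assertions in the tests below. A hedged usage sketch with hypothetical metric values (it reuses the get_results defined just above):

# Hedged usage sketch for get_results; the directory and metric values are hypothetical.
import json
import os
import tempfile

output_dir = tempfile.mkdtemp()
with open(os.path.join(output_dir, "eval_results.json"), "w") as f:
    json.dump({"eval_accuracy": 0.83, "eval_loss": 0.41}, f, indent=4, sort_keys=True)

results = get_results(output_dir)        # reads eval_results.json (split="eval" by default)
assert results["eval_accuracy"] >= 0.75  # the kind of threshold the tests below assert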
class ExamplesTests(TestCasePlus):
def test_run_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_glue.py
--model_name_or_path distilbert-base-uncased
--output_dir {tmp_dir}
--train_file ./tests/fixtures/tests_samples/MRPC/train.csv
--validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--learning_rate=1e-4
--max_train_steps=10
--num_warmup_steps=2
--seed=42
--max_length=128
""".split()
with patch.object(sys, "argv", testargs):
run_flax_glue.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
def test_run_clm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_clm_flax.py
--model_name_or_path distilgpt2
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--do_train
--do_eval
--block_size 128
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--num_train_epochs 2
--logging_steps 2 --eval_steps 2
--output_dir {tmp_dir}
--overwrite_output_dir
""".split()
with patch.object(sys, "argv", testargs):
run_clm_flax.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_perplexity"], 100)
@slow
def test_run_summarization(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_summarization.py
--model_name_or_path t5-small
--train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json
--test_file tests/fixtures/tests_samples/xsum/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=50
--warmup_steps=8
--do_train
--do_eval
--do_predict
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--predict_with_generate
""".split()
with patch.object(sys, "argv", testargs):
run_summarization_flax.main()
result = get_results(tmp_dir, split="test")
self.assertGreaterEqual(result["test_rouge1"], 10)
self.assertGreaterEqual(result["test_rouge2"], 2)
self.assertGreaterEqual(result["test_rougeL"], 7)
self.assertGreaterEqual(result["test_rougeLsum"], 7)
def test_run_mlm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_mlm.py
--model_name_or_path distilroberta-base
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--output_dir {tmp_dir}
--overwrite_output_dir
--max_seq_length 128
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--logging_steps 2 --eval_steps 2
--do_train
--do_eval
--num_train_epochs=1
""".split()
with patch.object(sys, "argv", testargs):
run_mlm_flax.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_perplexity"], 42)
@slow
def test_run_t5_mlm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_t5_mlm_flax.py
--model_name_or_path t5-small
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--do_train
--do_eval
--max_seq_length 128
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--num_train_epochs 2
--logging_steps 2 --eval_steps 2
--output_dir {tmp_dir}
--overwrite_output_dir
""".split()
with patch.object(sys, "argv", testargs):
run_t5_mlm_flax.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.42)
def test_run_ner(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs = 7 if get_gpu_count() > 1 else 2
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_flax_ner.py
--model_name_or_path bert-base-uncased
--train_file tests/fixtures/tests_samples/conll/sample.json
--validation_file tests/fixtures/tests_samples/conll/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--do_train
--do_eval
--warmup_steps=2
--learning_rate=2e-4
--logging_steps 2 --eval_steps 2
--per_device_train_batch_size=2
--per_device_eval_batch_size=2
--num_train_epochs={epochs}
--seed 7
""".split()
with patch.object(sys, "argv", testargs):
run_flax_ner.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
self.assertGreaterEqual(result["eval_f1"], 0.3)
def test_run_qa(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_qa.py
--model_name_or_path bert-base-uncased
--version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=10
--warmup_steps=2
--do_train
--do_eval
--logging_steps 2 --eval_steps 2
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
""".split()
with patch.object(sys, "argv", testargs):
run_qa.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_f1"], 30)
self.assertGreaterEqual(result["eval_exact"], 30)
...@@ -15,6 +15,7 @@
# limitations under the License.
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
import argparse
import json
import logging
import os
import random
...@@ -522,6 +523,13 @@ def main():
if args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
# save the eval metrics in json
if jax.process_index() == 0:
eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
path = os.path.join(args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metric, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()
...@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)"""
import json
import logging
import os
import random
...@@ -675,6 +676,42 @@ def main():
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
# Eval after training
if training_args.do_eval:
eval_metrics = {}
eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
labels = batch.pop("labels")
predictions = p_eval_step(state, batch)
predictions = np.array([pred for pred in chain(*predictions)])
labels = np.array([label for label in chain(*labels)])
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
preds, refs = get_labels(predictions, labels)
metric.add_batch(predictions=preds, references=refs)
# evaluate also on leftover examples (not divisible by batch_size)
num_leftover_samples = len(eval_dataset) % eval_batch_size
# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
# take leftover samples
batch = eval_dataset[-num_leftover_samples:]
batch = {k: np.array(v) for k, v in batch.items()}
labels = np.array(batch.pop("labels"))
predictions = eval_step(unreplicate(state), batch)
labels[np.array(batch["attention_mask"]) == 0] = -100
preds, refs = get_labels(predictions, labels)
metric.add_batch(predictions=preds, references=refs)
eval_metrics = compute_metrics()
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
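As in the QA script, the pmapped p_eval_step only sees full batches that can be sharded across devices, so the remainder (len(eval_dataset) % eval_batch_size samples) is evaluated once more on a single device with the unreplicated state. A small sketch of the split with hypothetical sizes:

# Hedged sketch (hypothetical sizes) of the leftover handling above.
len_eval_dataset = 53
eval_batch_size = 16  # per-device batch size * device count

num_full_batches = len_eval_dataset // eval_batch_size     # 3 batches go through p_eval_step
num_leftover_samples = len_eval_dataset % eval_batch_size  # 5 samples go through eval_step
print(num_full_batches, num_leftover_samples)  # 3 5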
if __name__ == "__main__":
main()
...@@ -600,7 +600,7 @@ def require_deepspeed(test_case):
def get_gpu_count():
"""
-Return the number of available gpus (regardless of whether torch or tf is used)
+Return the number of available gpus (regardless of whether torch, tf or jax is used)
"""
if is_torch_available():
import torch
...@@ -610,6 +610,10 @@ def get_gpu_count():
import tensorflow as tf
return len(tf.config.list_physical_devices("GPU"))
elif is_flax_available():
import jax
return jax.device_count()
else:
return 0
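With this branch, get_gpu_count also covers JAX-only environments; a hedged usage sketch:

# Hedged sketch: in a Flax/JAX-only environment, jax.device_count() reports the
# number of visible devices (a CPU-only host reports 1, a TPU v3-8 reports 8),
# which is what the branch added above returns.
import jax

print(jax.device_count())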
...