"docs/vscode:/vscode.git/clone" did not exist on "8bbd41369fc9944a1a2bf6c4ff1053c4648f42aa"
Unverified Commit 01b14669 authored by Patrick von Platen, committed by GitHub

[TPU tests] Enable first TPU examples pytorch (#14121)

* up

* up

* fix

* up

* Update examples/pytorch/test_xla_examples.py

* correct labels

* up

* up

* up

* up

* up

* up
parent 232822f3
@@ -181,6 +181,45 @@ jobs:
           name: run_all_tests_tf_gpu_test_reports
           path: reports
 
+  run_all_examples_torch_xla_tpu:
+    runs-on: [self-hosted, docker-tpu-test, tpu-v3-8]
+    container:
+      image: gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
+      options: --privileged -v "/lib/libtpu.so:/lib/libtpu.so" -v /mnt/cache/.cache/huggingface:/mnt/cache/ --shm-size 16G
+
+    steps:
+      - name: Launcher docker
+        uses: actions/checkout@v2
+
+      - name: Install dependencies
+        run: |
+          pip install --upgrade pip
+          pip install .[testing]
+
+      - name: Are TPUs recognized by our DL frameworks
+        env:
+          XRT_TPU_CONFIG: localservice;0;localhost:51011
+        run: |
+          python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
+
+      - name: Run example tests on TPU
+        env:
+          XRT_TPU_CONFIG: "localservice;0;localhost:51011"
+          MKL_SERVICE_FORCE_INTEL: "1"  # See: https://github.com/pytorch/pytorch/issues/37377
+        run: |
+          python -m pytest -n 1 -v --dist=loadfile --make-reports=tests_torch_xla_tpu examples/pytorch/test_xla_examples.py
+
+      - name: Failure short reports
+        if: ${{ always() }}
+        run: cat reports/tests_torch_xla_tpu_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v2
+        with:
+          name: run_all_examples_torch_xla_tpu
+          path: reports
+
   run_all_tests_torch_multi_gpu:
     runs-on: [self-hosted, docker-gpu, multi-gpu]
     container:
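The "Are TPUs recognized by our DL frameworks" step above boils down to asking torch_xla for a device. A minimal sketch of that check, assuming torch_xla is installed (as in the gcr.io/tpu-pytorch/xla image); the tensor op at the end is added here purely for illustration and is not part of the workflow:

# XRT_TPU_CONFIG must be set before torch_xla talks to the runtime,
# mirroring the env block in the job above.
import os
os.environ.setdefault("XRT_TPU_CONFIG", "localservice;0;localhost:51011")

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # raises if no XLA/TPU device is reachable
print(device)             # e.g. "xla:1" on a TPU VM

# Placing a small tensor confirms ops can actually run on the device.
print(torch.ones(2, 2, device=device).sum().item())  # 4.0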
examples/pytorch/test_xla_examples.py:

@@ -14,13 +14,14 @@
 # limitations under the License.
 
+import json
 import logging
+import os
 import sys
-import unittest
 from time import time
 from unittest.mock import patch
 
-from transformers.testing_utils import require_torch_tpu
+from transformers.testing_utils import TestCasePlus, require_torch_tpu
 
 
 logging.basicConfig(level=logging.DEBUG)
@@ -28,66 +29,65 @@ logging.basicConfig(level=logging.DEBUG)
 
 logger = logging.getLogger()
 
 
+def get_results(output_dir):
+    results = {}
+    path = os.path.join(output_dir, "all_results.json")
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            results = json.load(f)
+    else:
+        raise ValueError(f"can't find {path}")
+    return results
+
+
 @require_torch_tpu
-class TorchXLAExamplesTests(unittest.TestCase):
+class TorchXLAExamplesTests(TestCasePlus):
     def test_run_glue(self):
         import xla_spawn
 
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
 
-        output_directory = "run_glue_output"
+        tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
             --num_cores=8
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
+            --model_name_or_path distilbert-base-uncased
+            --output_dir {tmp_dir}
+            --overwrite_output_dir
+            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
+            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
             --do_train
             --do_eval
-            --task_name=mrpc
-            --cache_dir=./cache_dir
-            --num_train_epochs=1
+            --debug tpu_metrics_debug
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+            --learning_rate=1e-4
+            --max_steps=10
+            --warmup_steps=2
+            --seed=42
             --max_seq_length=128
-            --learning_rate=3e-5
-            --output_dir={output_directory}
-            --overwrite_output_dir
-            --logging_steps=5
-            --save_steps=5
-            --overwrite_cache
-            --tpu_metrics_debug
-            --model_name_or_path=bert-base-cased
-            --per_device_train_batch_size=64
-            --per_device_eval_batch_size=64
-            --evaluation_strategy steps
-            --overwrite_cache
         """.split()
 
         with patch.object(sys, "argv", testargs):
             start = time()
             xla_spawn.main()
             end = time()
 
-            result = {}
-            with open(f"{output_directory}/eval_results_mrpc.txt") as f:
-                lines = f.readlines()
-                for line in lines:
-                    key, value = line.split(" = ")
-                    result[key] = float(value)
-
-            del result["eval_loss"]
-            for value in result.values():
-                # Assert that the model trains
-                self.assertGreaterEqual(value, 0.70)
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
 
-            # Assert that the script takes less than 300 seconds to make sure it doesn't hang.
+            # Assert that the script takes less than 500 seconds to make sure it doesn't hang.
             self.assertLess(end - start, 500)
 
     def test_trainer_tpu(self):
         import xla_spawn
 
         testargs = """
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
            --num_cores=8
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
        """.split()
 
         with patch.object(sys, "argv", testargs):
             xla_spawn.main()
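The rewritten test leans on two conventions: Trainer-based example scripts write their final metrics to all_results.json inside --output_dir, and TestCasePlus.get_auto_remove_tmp_dir() hands out a temporary directory that is cleaned up when the test ends. A sketch of how the new get_results helper is consumed; the helper is condensed from the diff above, and the metric values are made up for illustration:

import json
import os
import tempfile

def get_results(output_dir):
    # Condensed version of the helper added in this diff.
    path = os.path.join(output_dir, "all_results.json")
    if not os.path.exists(path):
        raise ValueError(f"can't find {path}")
    with open(path, "r") as f:
        return json.load(f)

# Fake the file a run_glue.py run would leave behind (illustrative values).
tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, "all_results.json"), "w") as f:
    json.dump({"eval_accuracy": 0.79, "train_runtime": 123.4}, f)

assert get_results(tmp_dir)["eval_accuracy"] >= 0.75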
tests/test_trainer_tpu.py:

@@ -99,7 +99,7 @@ def main():
         p = trainer.predict(dataset)
 
         logger.info(p.metrics)
 
-        if p.metrics["eval_success"] is not True:
             logger.error(p.metrics)
             exit(1)
@@ -113,7 +113,7 @@ def main():
         p = trainer.predict(dataset)
 
         logger.info(p.metrics)
 
-        if p.metrics["eval_success"] is not True:
+        if p.metrics["test_success"] is not True:
             logger.error(p.metrics)
             exit(1)
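The eval_success to test_success rename tracks Trainer's metric prefixes: evaluate() defaults to metric_key_prefix="eval", while predict() defaults to "test", so a compute_metrics function returning {"success": ...} comes back from predict() as {"test_success": ...}. Below, prefix_metrics is a hypothetical stand-in for that internal renaming, shown only to make the key change concrete:

# prefix_metrics is a made-up helper mirroring what Trainer does to
# compute_metrics output before returning it from evaluate()/predict().
def prefix_metrics(metrics: dict, metric_key_prefix: str) -> dict:
    return {f"{metric_key_prefix}_{k}": v for k, v in metrics.items()}

raw = {"success": True}  # what compute_metrics yields in test_trainer_tpu.py
assert prefix_metrics(raw, "eval") == {"eval_success": True}  # evaluate()
assert prefix_metrics(raw, "test") == {"test_success": True}  # predict()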