Unverified Commit 023f0f37 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[s2s trainer] tests to use distributed on multi-gpu machine (#7965)

parent 64b24bb3
import os import os
import sys import sys
from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest
from transformers import is_torch_available
from transformers.testing_utils import TestCasePlus, slow from transformers.testing_utils import TestCasePlus, slow
from transformers.trainer_callback import TrainerState from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed from transformers.trainer_utils import set_seed
from .finetune_trainer import main from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY from .test_seq2seq_examples import MBART_TINY
from .utils import execute_async_std
if is_torch_available():
import torch
set_seed(42) set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1" MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
...@@ -25,7 +33,7 @@ class TestFinetuneTrainer(TestCasePlus): ...@@ -25,7 +33,7 @@ class TestFinetuneTrainer(TestCasePlus):
@slow @slow
def test_finetune_trainer_slow(self): def test_finetune_trainer_slow(self):
# There is a missing call to __init__process_group somewhere # There is a missing call to __init__process_group somewhere
output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3) output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10)
# Check metrics # Check metrics
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
...@@ -43,6 +51,8 @@ class TestFinetuneTrainer(TestCasePlus): ...@@ -43,6 +51,8 @@ class TestFinetuneTrainer(TestCasePlus):
assert "test_results.json" in contents assert "test_results.json" in contents
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int): def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
# XXX: remove hardcoded path
data_dir = "examples/seq2seq/test_data/wmt_en_ro" data_dir = "examples/seq2seq/test_data/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir() output_dir = self.get_auto_remove_tmp_dir()
argv = f""" argv = f"""
...@@ -77,8 +87,34 @@ class TestFinetuneTrainer(TestCasePlus): ...@@ -77,8 +87,34 @@ class TestFinetuneTrainer(TestCasePlus):
""".split() """.split()
# --eval_beams 2 # --eval_beams 2
testargs = ["finetune_trainer.py"] + argv n_gpu = torch.cuda.device_count()
with patch.object(sys, "argv", testargs): if n_gpu > 1:
main()
path = Path(__file__).resolve()
cur_path = path.parents[0]
path = Path(__file__).resolve()
examples_path = path.parents[1]
src_path = f"{path.parents[2]}/src"
env = os.environ.copy()
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
distributed_args = (
f"-m torch.distributed.launch --nproc_per_node={n_gpu} {cur_path}/finetune_trainer.py".split()
)
cmd = [sys.executable] + distributed_args + argv
print("\nRunning: ", " ".join(cmd))
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
assert result.stdout, "produced no output"
if result.returncode > 0:
pytest.fail(f"failed with returncode {result.returncode}")
else:
# 0 or 1 gpu
testargs = ["finetune_trainer.py"] + argv
with patch.object(sys, "argv", testargs):
main()
return output_dir return output_dir
...@@ -6,11 +6,15 @@ import sys ...@@ -6,11 +6,15 @@ import sys
from pathlib import Path from pathlib import Path
import pytest import pytest
import torch
from transformers import is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch_multigpu from transformers.testing_utils import TestCasePlus, require_torch_multigpu
from .utils import load_json from .utils import execute_async_std, load_json
if is_torch_available():
import torch
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
...@@ -106,73 +110,6 @@ def make_test_data_dir(tmp_dir): ...@@ -106,73 +110,6 @@ def make_test_data_dir(tmp_dir):
return tmp_dir return tmp_dir
# XXX: a candidate for testing_utils (python>=3.6)
# https://stackoverflow.com/a/59041913/9201239
import asyncio # noqa
class RunOutput:
    """Result of a streamed subprocess run: exit code plus captured output lines."""

    def __init__(self, returncode, stdout, stderr):
        # returncode: integer exit status of the child process
        # stdout / stderr: lists of already-decoded, rstripped output lines
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr

    def __repr__(self):
        # Debug-friendly summary; the output lists can be long, so show sizes only.
        return (
            f"{type(self).__name__}(returncode={self.returncode}, "
            f"stdout={len(self.stdout)} lines, stderr={len(self.stderr)} lines)"
        )
async def _read_stream(stream, callback):
while True:
line = await stream.readline()
if line:
callback(line)
else:
break
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
    """Spawn `cmd` and stream its stdout/stderr live while also capturing them.

    Args:
        cmd: argv list; ``cmd[0]`` is the executable.
        env: environment mapping for the child (``None`` inherits ours).
        stdin: stdin handle for the child (``None`` means no redirection).
        timeout: max seconds to wait for the output streams to drain.
        quiet: when True, capture only — do not mirror lines to our own streams.
        echo: when True, print the command before running it.

    Returns:
        RunOutput with the child's exit code and captured stdout/stderr lines.
    """
    if echo:
        print(cmd)

    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        # Decode, capture, and (unless quiet) mirror the line to our own stream.
        line = line.decode("utf-8").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    # Wrap the readers in Tasks: passing bare coroutines to asyncio.wait is
    # deprecated since Python 3.8 and raises TypeError from Python 3.11 on.
    await asyncio.wait(
        [
            asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout))),
            asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))),
        ],
        timeout=timeout,
    )

    # XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
    # `wait` as it's easier to see in real time, but for normal runs use `communicate`
    return RunOutput(await p.wait(), out, err)
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
    """Blocking front-end for `_stream_subprocess`.

    Runs `cmd`, streaming and capturing its output, and returns the resulting
    RunOutput (returncode, stdout lines, stderr lines).
    """
    # asyncio.run creates — and cleanly closes — a fresh event loop for each call;
    # asyncio.get_event_loop() without a running loop is deprecated.
    return asyncio.run(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )
class TestSummarizationDistillerMultiGPU(TestCasePlus): class TestSummarizationDistillerMultiGPU(TestCasePlus):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -220,17 +157,18 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus): ...@@ -220,17 +157,18 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
return f"--{k}" return f"--{k}"
return f"--{k}={v}" return f"--{k}={v}"
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
cmd = [sys.executable, "./examples/seq2seq/distillation.py"] + cli_args
print("\nRunning: ", " ".join(cmd))
path = Path(__file__).resolve() path = Path(__file__).resolve()
cur_path = path.parents[0]
examples_path = path.parents[1] examples_path = path.parents[1]
src_path = f"{path.parents[2]}/src" src_path = f"{path.parents[2]}/src"
env = os.environ.copy() env = os.environ.copy()
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}" env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
cmd = [sys.executable, f"{cur_path}/distillation.py"] + cli_args
print("\nRunning: ", " ".join(cmd))
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False) result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
assert result.stdout, "produced no output" assert result.stdout, "produced no output"
......
...@@ -5,6 +5,7 @@ import math ...@@ -5,6 +5,7 @@ import math
import os import os
import pickle import pickle
import socket import socket
import sys
from logging import getLogger from logging import getLogger
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union from typing import Callable, Dict, Iterable, List, Tuple, Union
...@@ -643,3 +644,71 @@ def check_output_dir(args, expected_items=0): ...@@ -643,3 +644,71 @@ def check_output_dir(args, expected_items=0):
"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). " "has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
# the following code deals with async io between processes
# adapted from https://stackoverflow.com/a/59041913/9201239
import asyncio # noqa
class _RunOutput:
def __init__(self, returncode, stdout, stderr):
self.returncode = returncode
self.stdout = stdout
self.stderr = stderr
async def _read_stream(stream, callback):
while True:
line = await stream.readline()
if line:
callback(line)
else:
break
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    """Spawn `cmd` and stream its stdout/stderr live while also capturing them.

    Args:
        cmd: argv list; ``cmd[0]`` is the executable.
        env: environment mapping for the child (``None`` inherits ours).
        stdin: stdin handle for the child (``None`` means no redirection).
        timeout: max seconds to wait for the output streams to drain.
        quiet: when True, capture only — do not mirror lines to our own streams.
        echo: when True, print the command before running it.

    Returns:
        _RunOutput with the child's exit code and captured stdout/stderr lines.
    """
    if echo:
        print(cmd)

    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        # Decode, capture, and (unless quiet) mirror the line to our own stream.
        line = line.decode("utf-8").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    # Wrap the readers in Tasks: passing bare coroutines to asyncio.wait is
    # deprecated since Python 3.8 and raises TypeError from Python 3.11 on.
    await asyncio.wait(
        [
            asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout))),
            asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))),
        ],
        timeout=timeout,
    )

    # XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
    # `wait` as it's easier to see in real time, but for normal runs use `communicate`
    return _RunOutput(await p.wait(), out, err)
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    """Blocking front-end for `_stream_subprocess`.

    Runs `cmd`, streaming and capturing its output, and returns the resulting
    _RunOutput (returncode, stdout lines, stderr lines).
    """
    # asyncio.run creates — and cleanly closes — a fresh event loop for each call;
    # asyncio.get_event_loop() without a running loop is deprecated.
    return asyncio.run(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment