"...resnet50_tensorflow.git" did not exist on "83f56818af03cf70ec8e99d9c8c955ca98dead6e"
Unverified Commit 8b381733 authored by Stas Bekman, committed by GitHub

[seq2seq testing] multigpu test run via subprocess (#7281)


Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent f8d3695e
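
The new multi-GPU test added below is a regular pytest module, so on a machine with 2+ GPUs it can presumably be invoked directly; the module path here is an assumption, since file names are not preserved in this capture:

    pytest examples/seq2seq/test_seq2seq_examples_multi_gpu.py -k test_multigpu -sv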
@@ -170,7 +170,7 @@ class BaseTransformer(pl.LightningModule):
             self.dataset_size = len(self.test_dataloader().dataset)
         else:
             self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
-            self.dataset_size = len(self.train_loader.dataset)
+            self.dataset_size = len(self.train_dataloader().dataset)
 
     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
         raise NotImplementedError("You must implement this for your task")
@@ -17,7 +17,7 @@
 from finetune import main as ft_main
 from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
 from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
 from transformers.modeling_bart import shift_tokens_right
-from utils import calculate_bleu, freeze_params, label_smoothed_nll_loss, use_task_specific_params
+from utils import calculate_bleu, check_output_dir, freeze_params, label_smoothed_nll_loss, use_task_specific_params
 
 # need the parent dir module
@@ -266,8 +266,7 @@ def create_module(args):
 
 def distill_main(args):
     Path(args.output_dir).mkdir(exist_ok=True)
-    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
-        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
+    check_output_dir(args, expected_items=3)
     model = create_module(args)
     return ft_main(args, model=model)
@@ -25,6 +25,7 @@ from utils import (
     assert_all_frozen,
     calculate_bleu,
     calculate_rouge,
+    check_output_dir,
     flatten_list,
     freeze_embeds,
     freeze_params,
@@ -329,6 +330,7 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument("--freeze_encoder", action="store_true")
         parser.add_argument("--freeze_embeds", action="store_true")
         parser.add_argument("--sortish_sampler", action="store_true", default=False)
+        parser.add_argument("--overwrite_output_dir", action="store_true", default=False)
         parser.add_argument("--max_tokens_per_batch", type=int, default=None)
         parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
         parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -373,8 +375,8 @@ class TranslationModule(SummarizationModule):
 
 def main(args, model=None) -> SummarizationModule:
     Path(args.output_dir).mkdir(exist_ok=True)
-    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
-        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
+    check_output_dir(args, expected_items=3)
+
     if model is None:
         if "summarization" in args.task:
             model: SummarizationModule = SummarizationModule(args)
@@ -21,6 +21,7 @@ from utils import (
     Seq2SeqDataset,
     assert_all_frozen,
     build_compute_metrics_fn,
+    check_output_dir,
     freeze_embeds,
     freeze_params,
     lmap,
@@ -153,15 +154,7 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    check_output_dir(training_args)
 
     # Setup logging
     logging.basicConfig(
# As multi-GPU tests are complex they could impact other tests; to aid debugging we keep them in a separate module.

import logging
import os
import sys
from pathlib import Path

import pytest
import torch

from transformers.testing_utils import TestCasePlus, require_torch_multigpu

from .utils import load_json


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()

CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
    "max_tokens_per_batch": None,
    "supervise_forward": True,
    "normalize_hidden": True,
    "label_smoothing": 0.2,
    "eval_max_gen_length": None,
    "eval_beams": 1,
    "val_metric": "loss",
    "save_top_k": 1,
    "adafactor": True,
    "early_stopping_patience": 2,
    "logger_name": "default",
    "length_penalty": 0.5,
    "cache_dir": "",
    "task": "summarization",
    "num_workers": 2,
    "alpha_hid": 0,
    "freeze_embeds": True,
    "enc_only": False,
    "tgt_suffix": "",
    "resume_from_checkpoint": None,
    "sortish_sampler": True,
    "student_decoder_layers": 1,
    "val_check_interval": 1.0,
    "output_dir": "",
    "fp16": False,  # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
    "no_teacher": False,
    "fp16_opt_level": "O1",
    "gpus": 1 if CUDA_AVAILABLE else 0,
    "n_tpu_cores": 0,
    "max_grad_norm": 1.0,
    "do_train": True,
    "do_predict": True,
    "accumulate_grad_batches": 1,
    "server_ip": "",
    "server_port": "",
    "seed": 42,
    "model_name_or_path": "sshleifer/bart-tiny-random",
    "config_name": "",
    "tokenizer_name": "facebook/bart-large",
    "do_lower_case": False,
    "learning_rate": 0.3,
    "lr_scheduler": "linear",
    "weight_decay": 0.0,
    "adam_epsilon": 1e-08,
    "warmup_steps": 0,
    "max_epochs": 1,
    "train_batch_size": 2,
    "eval_batch_size": 2,
    "max_source_length": 12,
    "max_target_length": 12,
    "val_max_target_length": 12,
    "test_max_target_length": 12,
    "fast_dev_run": False,
    "no_cache": False,
    "n_train": -1,
    "n_val": -1,
    "n_test": -1,
    "student_encoder_layers": 1,
    "freeze_encoder": False,
    "auto_scale_batch_size": False,
}
def _dump_articles(path: Path, articles: list):
    content = "\n".join(articles)
    Path(path).open("w").writelines(content)


ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"

stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks


def make_test_data_dir(tmp_dir):
    for split in ["train", "val", "test"]:
        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
    return tmp_dir
# XXX: a candidate for testing_utils (python>=3.6)
# https://stackoverflow.com/a/59041913/9201239
import asyncio  # noqa


class RunOutput:
    def __init__(self, returncode, stdout, stderr):
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr


async def _read_stream(stream, callback):
    while True:
        line = await stream.readline()
        if line:
            callback(line)
        else:
            break
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
    if echo:
        print(cmd)

    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        line = line.decode("utf-8").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    await asyncio.wait(
        [
            _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout)),
            _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
        ],
        timeout=timeout,
    )

    # XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # If it starts hanging, we will need to switch s/wait/communicate/ - perhaps keep `wait` for
    # debugging, as it's easier to follow in real time, and use `communicate` for normal runs.
    return RunOutput(await p.wait(), out, err)
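
# A hedged sketch of the `communicate` alternative mentioned in the XXX note above (an
# assumption, not part of this commit): it buffers all output in memory, trading the
# real-time streaming for deadlock safety:
#
#   stdout, stderr = await p.communicate()
#   return RunOutput(p.returncode, stdout.decode().splitlines(), stderr.decode().splitlines())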
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
    loop = asyncio.get_event_loop()
    result = loop.run_until_complete(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )
    return result
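
# Minimal usage sketch (illustrative, not part of this commit): run a quick python
# one-liner through the streaming runner and inspect the captured result:
#
#   result = execute_async_std([sys.executable, "-c", "print('hi')"], timeout=30)
#   assert result.returncode == 0
#   assert result.stdout == ["hi"]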
class TestSummarizationDistillerMultiGPU(TestCasePlus):
    @classmethod
    def setUpClass(cls):
        logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
        return cls

    @require_torch_multigpu
    def test_multigpu(self):
        updates = dict(
            no_teacher=True,
            freeze_encoder=True,
            gpus=2,
            overwrite_output_dir=True,
            sortish_sampler=True,
        )
        self._test_distiller_cli_fork(updates, check_contents=False)
    def _test_distiller_cli_fork(self, updates, check_contents=True):
        default_updates = dict(
            label_smoothing=0.0,
            early_stopping_patience=-1,
            train_batch_size=1,
            eval_batch_size=2,
            max_epochs=2,
            alpha_mlm=0.2,
            alpha_ce=0.8,
            do_predict=True,
            model_name_or_path="sshleifer/tinier_bart",
            teacher=CHEAP_ARGS["model_name_or_path"],
            val_check_interval=0.5,
        )
        default_updates.update(updates)
        args_d: dict = CHEAP_ARGS.copy()
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        output_dir = self.get_auto_remove_tmp_dir()
        args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)

        def convert(k, v):
            if k in ["tgt_suffix", "server_ip", "server_port", "out", "n_tpu_cores"]:
                return ""
            if v is False or v is None:
                return ""
            if v is True:  # or len(str(v))==0:
                return f"--{k}"
            return f"--{k}={v}"

        cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
        cmd = [sys.executable, "./examples/seq2seq/distillation.py"] + cli_args
        print("\nRunning: ", " ".join(cmd))

        path = Path(__file__).resolve()
        examples_path = path.parents[1]
        src_path = f"{path.parents[2]}/src"
        env = os.environ.copy()
        env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
        result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
        assert result.stdout, "produced no output"
        if result.returncode > 0:
            pytest.fail(f"failed with returncode {result.returncode}")
        if not check_contents:
            return  # smoke test only: the remaining assertions inspect the output dir contents

        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        ckpt_files = [p for p in contents if p.endswith("ckpt")]
        assert len(ckpt_files) > 0

        self.assertIn("test_generations.txt", contents)
        self.assertIn("test_results.txt", contents)

        # get the following from the module (we don't have access to `model` here)
        metrics_save_path = os.path.join(output_dir, "metrics.json")
        val_metric = "rouge2"

        metrics = load_json(metrics_save_path)
        # {'test': [{'test_avg_loss': 10.63731575012207, 'test_avg_rouge1': 0.0, 'test_avg_rouge2': 0.0, 'test_avg_rougeL': 0.0, 'test_avg_gen_time': 0.1822289228439331, 'test_avg_gen_len': 142.0, 'step_count': 1}]}
        print(metrics)
        last_step_stats = metrics["val"][-1]
        self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
        self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
        self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
        self.assertEqual(len(metrics["test"]), 1)
        desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
        self.assertEqual(len(metrics["val"]), desired_n_evals)
@@ -619,3 +619,27 @@ def chunks(lst, n):
     """Yield successive n-sized chunks from lst."""
     for i in range(0, len(lst), n):
         yield lst[i : i + n]
+
+
+def check_output_dir(args, expected_items=0):
+    """
+    Checks whether to bail out if output_dir already exists and has more than expected_items in it.
+
+    `args`: needs to have the following attributes:
+        - output_dir
+        - do_train
+        - overwrite_output_dir
+
+    `expected_items`: normally 0 (default) - i.e. an empty dir, but in some cases a few files are expected (e.g. recovery from OOM)
+    """
+    if (
+        os.path.exists(args.output_dir)
+        and len(os.listdir(args.output_dir)) > expected_items
+        and args.do_train
+        and not args.overwrite_output_dir
+    ):
+        raise ValueError(
+            f"Output directory ({args.output_dir}) already exists and "
+            f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
+            "Use --overwrite_output_dir to overcome."
+        )
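
For reference, a minimal sketch of how the new helper is exercised (the Namespace values here are illustrative; any object with the three documented attributes works):

    from argparse import Namespace
    from utils import check_output_dir

    args = Namespace(output_dir="out", do_train=True, overwrite_output_dir=False)
    check_output_dir(args, expected_items=3)  # raises ValueError if "out" exists with more than 3 items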