Unverified commit 9f7b2b24, authored by Stas Bekman, committed by GitHub

[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)

parent dc552b9b
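The change applied throughout the diff below is the same everywhere: module-level pytest functions become methods on a `TestCasePlus` subclass, and ad-hoc `tempfile.mkdtemp()` calls become `self.get_auto_remove_tmp_dir()`, so scratch directories are deleted when each test finishes. A minimal stand-alone sketch of that pattern follows; the base class and helper here are illustrative re-implementations, not the actual `transformers.testing_utils.TestCasePlus`.

# Sketch only: a unittest base class whose temp dirs are removed automatically
# after each test, mimicking what TestCasePlus.get_auto_remove_tmp_dir() provides.
import shutil
import tempfile
import unittest
from pathlib import Path


class AutoTmpDirTestCase(unittest.TestCase):
    def get_auto_remove_tmp_dir(self):
        tmp_dir = tempfile.mkdtemp()
        # addCleanup runs even if the test fails, so the directory never leaks
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        return tmp_dir


class ExampleTest(AutoTmpDirTestCase):
    def test_output_dir_is_scratch(self):
        output_dir = self.get_auto_remove_tmp_dir()
        with open(f"{output_dir}/dummy.txt", "w") as f:
            f.write("scratch output")
        self.assertTrue(Path(output_dir, "dummy.txt").exists())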
@@ -3,7 +3,6 @@
 import argparse
 import os
 import sys
-import tempfile
 from pathlib import Path
 from unittest.mock import patch
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import load_json
@@ -24,163 +23,164 @@ MODEL_NAME = MBART_TINY
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


class TestAll(TestCasePlus):
    @slow
    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
    def test_model_download(self):
        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
        MarianMTModel.from_pretrained(MARIAN_MODEL)

    @timeout_decorator.timeout(120)
    @slow
    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
    def test_train_mbart_cc25_enro_script(self):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 4,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "facebook/mbart-large-cc25": MODEL_NAME,
            # Download is 120MB in previous test.
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script
        bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()

        bash_script = bash_script.replace("--fp16 ", "")
        testargs = (
            ["finetune.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=3e-1",
                "--warmup_steps=0",
                "--val_check_interval=1.0",
                "--tokenizer_name=facebook/mbart-large-en-ro",
            ]
        )
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            args.do_predict = False
            # assert args.gpus == gpus THIS BREAKS for multigpu

            model = main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        assert (
            len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
        )  # +1 accounts for val_sanity_check

        assert last_step_stats["val_avg_gen_time"] >= 0.01

        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
        assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) == desired_n_evals
            assert len(metrics["test"]) == 1

    @timeout_decorator.timeout(600)
    @slow
    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
    def test_opus_mt_distill_script(self):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 16,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "$m": "sshleifer/student_marian_en_ro_6_1",
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script
        bash_script = (
            Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
        )
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        bash_script = bash_script.replace("--fp16 ", " ")

        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
        bash_script = bash_script.replace("--fp16", "")
        epochs = 6
        testargs = (
            ["distillation.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=1e-3",
                f"--num_train_epochs={epochs}",
                "--warmup_steps=10",
                "--val_check_interval=1.0",
            ]
        )
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            args.do_predict = False
            # assert args.gpus == gpus THIS BREAKS for multigpu

            model = distill_main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check

        assert last_step_stats["val_avg_gen_time"] >= 0.01

        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
        assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) == desired_n_evals
            assert len(metrics["test"]) == 1
 import os
-import tempfile
 from pathlib import Path
 import numpy as np
@@ -7,11 +6,12 @@ import pytest
 from torch.utils.data import DataLoader

 from pack_dataset import pack_data_dir
+from parameterized import parameterized
 from save_len_file import save_len_file
 from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
 from transformers import AutoTokenizer
 from transformers.modeling_bart import shift_tokens_right
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
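`parameterized.expand` replaces `pytest.mark.parametrize` in the hunks below because pytest's parametrization is ignored on `unittest.TestCase` methods. A small sketch of how the expanded tests behave; the model-name strings are illustrative only.

# Sketch: one test method is generated per entry in the list, and each generated
# method receives that entry as its single positional argument.
import unittest

from parameterized import parameterized


class TestExpansion(unittest.TestCase):
    @parameterized.expand(["sshleifer/tiny-mbart", "sshleifer/bart-tiny-random"])
    def test_name_mentions_tiny(self, model_name):
        self.assertIn("tiny", model_name)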
@@ -19,202 +19,198 @@ BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM = "google/pegasus-xsum"


class TestAll(TestCasePlus):
    @parameterized.expand(
        [
            MBART_TINY,
            MARIAN_TINY,
            T5_TINY,
            BART_TINY,
            PEGASUS_XSUM,
        ],
    )
    @slow
    def test_seq2seq_dataset_truncation(self, tok_name):
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
        max_src_len = 4
        max_tgt_len = 8
        assert max_len_target > max_src_len  # Will be truncated
        assert max_len_source > max_src_len  # Will be truncated
        src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=max_src_len,
            max_target_length=max_tgt_len,  # ignored
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            assert isinstance(batch, dict)
            assert batch["attention_mask"].shape == batch["input_ids"].shape
            # show that articles were trimmed.
            assert batch["input_ids"].shape[1] == max_src_len
            # show that targets are the same len
            assert batch["labels"].shape[1] == max_tgt_len
            if tok_name != MBART_TINY:
                continue
            # check language codes in correct place
            batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
            assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
            assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
            assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
            assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
            break  # No need to test every batch

    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
    def test_legacy_dataset_truncation(self, tok):
        tokenizer = AutoTokenizer.from_pretrained(tok)
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
        trunc_target = 4
        train_dataset = LegacySeq2SeqDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=20,
            max_target_length=trunc_target,
        )
        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            assert batch["attention_mask"].shape == batch["input_ids"].shape
            # show that articles were trimmed.
            assert batch["input_ids"].shape[1] == max_len_source
            assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
            # show that targets were truncated
            assert batch["labels"].shape[1] == trunc_target  # Truncated
            assert max_len_target > trunc_target  # Truncated
            break  # No need to test every batch

    def test_pack_dataset(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
        orig_examples = tmp_dir.joinpath("train.source").open().readlines()
        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
        pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
        orig_paths = {x.name for x in tmp_dir.iterdir()}
        new_paths = {x.name for x in save_dir.iterdir()}
        packed_examples = save_dir.joinpath("train.source").open().readlines()
        # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
        # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
        assert len(packed_examples) < len(orig_examples)
        assert len(packed_examples) == 1
        assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
        assert orig_paths == new_paths

    @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
    def test_dynamic_batch_size(self):
        if not FAIRSEQ_AVAILABLE:
            return
        ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
        required_batch_size_multiple = 64
        batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
        batch_sizes = [len(x) for x in batch_sampler]
        assert len(set(batch_sizes)) > 1  # it's not dynamic batch size if every batch is the same length
        assert sum(batch_sizes) == len(ds)  # no dropped or added examples
        data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
        failures = []
        num_src_per_batch = []
        for batch in data_loader:
            src_shape = batch["input_ids"].shape
            bs = src_shape[0]
            assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
            num_src_tokens = np.product(batch["input_ids"].shape)
            num_src_per_batch.append(num_src_tokens)
            if num_src_tokens > (max_tokens * 1.1):
                failures.append(num_src_tokens)
        assert num_src_per_batch[0] == max(num_src_per_batch)
        if failures:
            raise AssertionError(f"too many tokens in {len(failures)} batches")

    def test_sortish_sampler_reduces_padding(self):
        ds, _, tokenizer = self._get_dataset(max_len=512)
        bs = 2
        sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)

        naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
        sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)

        pad = tokenizer.pad_token_id

        def count_pad_tokens(data_loader, k="input_ids"):
            return [batch[k].eq(pad).sum().item() for batch in data_loader]

        assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
        assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
        assert len(sortish_dl) == len(naive_dl)

    def _get_dataset(self, n_obs=1000, max_len=128):
        if os.getenv("USE_REAL_DATA", False):
            data_dir = "examples/seq2seq/wmt_en_ro"
            max_tokens = max_len * 2 * 64
            if not Path(data_dir).joinpath("train.len").exists():
                save_len_file(MARIAN_TINY, data_dir)
        else:
            data_dir = "examples/seq2seq/test_data/wmt_en_ro"
            max_tokens = max_len * 4
            save_len_file(MARIAN_TINY, data_dir)
        tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
        ds = Seq2SeqDataset(
            tokenizer,
            data_dir=data_dir,
            type_path="train",
            max_source_length=max_len,
            max_target_length=max_len,
            n_obs=n_obs,
        )
        return ds, max_tokens, tokenizer

    def test_distributed_sortish_sampler_splits_indices_between_procs(self):
        ds, max_tokens, tokenizer = self._get_dataset()
        ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
        ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
        assert ids1.intersection(ids2) == set()

    @parameterized.expand(
        [
            MBART_TINY,
            MARIAN_TINY,
            T5_TINY,
            BART_TINY,
            PEGASUS_XSUM,
        ],
    )
    def test_dataset_kwargs(self, tok_name):
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        if tok_name == MBART_TINY:
            train_dataset = Seq2SeqDataset(
                tokenizer,
                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
                type_path="train",
                max_source_length=4,
                max_target_length=8,
                src_lang="EN",
                tgt_lang="FR",
            )
            kwargs = train_dataset.dataset_kwargs
            assert "src_lang" in kwargs and "tgt_lang" in kwargs
        else:
            train_dataset = Seq2SeqDataset(
                tokenizer,
                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
                type_path="train",
                max_source_length=4,
                max_target_length=8,
            )
            kwargs = train_dataset.dataset_kwargs
            assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
            assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
 import os
 import sys
-import tempfile
 from unittest.mock import patch

-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
@@ -15,72 +14,71 @@ set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


class TestFinetuneTrainer(TestCasePlus):
    def test_finetune_trainer(self):
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        assert "eval_bleu" in first_step_stats

    @slow
    def test_finetune_trainer_slow(self):
        # There is a missing call to __init__process_group somewhere
        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.json" in contents

    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        argv = f"""
            --model_name_or_path {model_name}
            --data_dir {data_dir}
            --output_dir {output_dir}
            --overwrite_output_dir
            --n_train 8
            --n_val 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate 3e-4
            --warmup_steps 8
            --evaluate_during_training
            --predict_with_generate
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --eval_steps {str(eval_steps)}
            --sortish_sampler
            --label_smoothing 0.1
            --adafactor
            --task translation
            --tgt_lang ro_RO
            --src_lang en_XX
        """.split()
        # --eval_beams 2

        testargs = ["finetune_trainer.py"] + argv
        with patch.object(sys, "argv", testargs):
            main()

        return output_dir
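All of these CLI tests share one mechanism: compose a fake command line, patch `sys.argv`, and call the script's `main()` in-process. A reduced sketch of that pattern; the `cli_main` stand-in below is not one of the real example scripts.

# Sketch of the sys.argv patching used by run_trainer() and the finetune tests above.
import sys
from unittest.mock import patch


def cli_main():
    # stand-in for finetune.py / finetune_trainer.py reading its own argv
    return sys.argv[1:]


def test_cli_sees_patched_argv():
    testargs = ["finetune_trainer.py", "--output_dir", "/tmp/out", "--n_train", "8"]
    with patch.object(sys, "argv", testargs):
        assert cli_main() == ["--output_dir", "/tmp/out", "--n_train", "8"]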
@@ -3,7 +3,6 @@ import logging
 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch
@@ -15,11 +14,12 @@ import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
 from distillation import distill_main
 from finetune import SummarizationModule, main
+from parameterized import parameterized
 from run_eval import generate_summaries_or_translations, run_generate
 from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
-from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
 from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
@@ -52,7 +52,7 @@ CHEAP_ARGS = {
     "student_decoder_layers": 1,
     "val_check_interval": 1.0,
     "output_dir": "",
-    "fp16": False,  # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
+    "fp16": False,  # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
     "no_teacher": False,
     "fp16_opt_level": "O1",
     "gpus": 1 if CUDA_AVAILABLE else 0,
@@ -88,6 +88,7 @@ CHEAP_ARGS = {
     "student_encoder_layers": 1,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
+    "overwrite_output_dir": False,
 }
@@ -110,15 +111,14 @@ logger.addHandler(stream_handler)
 logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks


-def make_test_data_dir(**kwargs):
-    tmp_dir = Path(tempfile.mkdtemp(**kwargs))
+def make_test_data_dir(tmp_dir):
     for split in ["train", "val", "test"]:
-        _dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
-        _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
     return tmp_dir


-class TestSummarizationDistiller(unittest.TestCase):
+class TestSummarizationDistiller(TestCasePlus):
     @classmethod
     def setUpClass(cls):
         logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
@@ -143,17 +143,6 @@ class TestSummarizationDistiller(unittest.TestCase):
             failures.append(m)
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"

-    @require_multigpu
-    @unittest.skip("Broken at the moment")
-    def test_multigpu(self):
-        updates = dict(
-            no_teacher=True,
-            freeze_encoder=True,
-            gpus=2,
-            sortish_sampler=True,
-        )
-        self._test_distiller_cli(updates, check_contents=False)
-
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
         self._test_distiller_cli(updates)
@@ -173,12 +162,12 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertEqual(1, len(ckpts))
         transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
         self.assertEqual(len(transformer_ckpts), 2)
-        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
-        out_path = tempfile.mktemp()
+        examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
+        out_path = tempfile.mktemp()  # XXX: not being cleaned up
         generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
         self.assertTrue(Path(out_path).exists())

-        out_path_new = tempfile.mkdtemp()
+        out_path_new = self.get_auto_remove_tmp_dir()
         convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
         assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
@@ -253,8 +242,8 @@ class TestSummarizationDistiller(unittest.TestCase):
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
-        tmp_dir = make_test_data_dir()
-        output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
         model = distill_main(argparse.Namespace(**args_d))
@@ -279,256 +268,253 @@ class TestSummarizationDistiller(unittest.TestCase):
        return model


class TestTheRest(TestCasePlus):
    def run_eval_tester(self, model):
        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
        output_file_name = input_file_name.parent / "utest_output.txt"
        assert not output_file_name.exists()
        articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
        _dump_articles(input_file_name, articles)

        score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json")
        task = "translation_en_to_de" if model == T5_TINY else "summarization"
        testargs = f"""
            run_eval_search.py
            {model}
            {input_file_name}
            {output_file_name}
            --score_path {score_path}
            --task {task}
            --num_beams 2
            --length_penalty 2.0
        """.split()

        with patch.object(sys, "argv", testargs):
            run_generate()
            assert Path(output_file_name).exists()
            # os.remove(Path(output_file_name))

    # test one model to quickly (no-@slow) catch simple problems and do an
    # extensive testing of functionality with multiple models as @slow separately
    def test_run_eval(self):
        self.run_eval_tester(T5_TINY)

    # any extra models should go into the list here - can be slow
    @parameterized.expand([BART_TINY, MBART_TINY])
    @slow
    def test_run_eval_slow(self, model):
        self.run_eval_tester(model)

    # testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
    @parameterized.expand([T5_TINY, MBART_TINY])
    @slow
    def test_run_eval_search(self, model):
        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
        output_file_name = input_file_name.parent / "utest_output.txt"
        assert not output_file_name.exists()

        text = {
            "en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
            "de": [
                "Maschinelles Lernen ist großartig, oder?",
                "Ich esse gerne Bananen",
                "Morgen ist wieder ein toller Tag!",
            ],
        }

        tmp_dir = Path(self.get_auto_remove_tmp_dir())
        score_path = str(tmp_dir / "scores.json")
        reference_path = str(tmp_dir / "val.target")
        _dump_articles(input_file_name, text["en"])
        _dump_articles(reference_path, text["de"])
        task = "translation_en_to_de" if model == T5_TINY else "summarization"
        testargs = f"""
            run_eval_search.py
            {model}
            {str(input_file_name)}
            {str(output_file_name)}
            --score_path {score_path}
            --reference_path {reference_path}
            --task {task}
        """.split()
        testargs.extend(["--search", "num_beams=1:2 length_penalty=0.9:1.0"])

        with patch.object(sys, "argv", testargs):
            with CaptureStdout() as cs:
                run_search()
            expected_strings = [" num_beams | length_penalty", model, "Best score args"]
            un_expected_strings = ["Info"]
            if "translation" in task:
                expected_strings.append("bleu")
            else:
                expected_strings.extend(ROUGE_KEYS)
            for w in expected_strings:
                assert w in cs.out
            for w in un_expected_strings:
                assert w not in cs.out
            assert Path(output_file_name).exists()
            os.remove(Path(output_file_name))

    @parameterized.expand(
        [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
    )
    def test_finetune(self, model):
        args_d: dict = CHEAP_ARGS.copy()
        task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
        args_d["label_smoothing"] = 0.1 if task == "translation" else 0

        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        output_dir = self.get_auto_remove_tmp_dir()
        args_d.update(
            data_dir=tmp_dir,
            model_name_or_path=model,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            output_dir=output_dir,
            do_predict=True,
            task=task,
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )
        assert "n_train" in args_d
        args = argparse.Namespace(**args_d)
        module = main(args)

        input_embeds = module.model.get_input_embeddings()
        assert not input_embeds.weight.requires_grad
        if model == T5_TINY:
            lm_head = module.model.lm_head
            assert not lm_head.weight.requires_grad
            assert (lm_head.weight == input_embeds.weight).all().item()
        elif model == FSMT_TINY:
            fsmt = module.model.model
            embed_pos = fsmt.decoder.embed_positions
            assert not embed_pos.weight.requires_grad
            assert not fsmt.decoder.embed_tokens.weight.requires_grad
            # check that embeds are not the same
            assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
        else:
            bart = module.model.model
            embed_pos = bart.decoder.embed_positions
            assert not embed_pos.weight.requires_grad
            assert not bart.shared.weight.requires_grad
            # check that embeds are the same
            assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
            assert bart.decoder.embed_tokens == bart.shared

        example_batch = load_json(module.output_dir / "text_batch.json")
        assert isinstance(example_batch, dict)
        assert len(example_batch) >= 4

    def test_finetune_extra_model_args(self):
        args_d: dict = CHEAP_ARGS.copy()

        task = "summarization"
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

        args_d.update(
            data_dir=tmp_dir,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            do_predict=False,
            task=task,
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )

        # test models whose config includes the extra_model_args
        model = BART_TINY
        output_dir = self.get_auto_remove_tmp_dir()
        args_d1 = args_d.copy()
        args_d1.update(
            model_name_or_path=model,
            output_dir=output_dir,
        )
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
        for p in extra_model_params:
            args_d1[p] = 0.5
        args = argparse.Namespace(**args_d1)
        model = main(args)
        for p in extra_model_params:
            assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"

        # test models whose config doesn't include the extra_model_args
        model = T5_TINY
        output_dir = self.get_auto_remove_tmp_dir()
        args_d2 = args_d.copy()
        args_d2.update(
            model_name_or_path=model,
            output_dir=output_dir,
        )
        unsupported_param = "encoder_layerdrop"
        args_d2[unsupported_param] = 0.5
        args = argparse.Namespace(**args_d2)
        with pytest.raises(Exception) as excinfo:
            model = main(args)
        assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"

    def test_finetune_lr_schedulers(self):
        args_d: dict = CHEAP_ARGS.copy()

        task = "summarization"
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

        model = BART_TINY
        output_dir = self.get_auto_remove_tmp_dir()

        args_d.update(
            data_dir=tmp_dir,
            model_name_or_path=model,
            output_dir=output_dir,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            do_predict=False,
            task=task,
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )

        # emulate finetune.py
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
        args = {"--help": True}

        # --help test
        with pytest.raises(SystemExit) as excinfo:
            with CaptureStdout() as cs:
                args = parser.parse_args(args)
            assert False, "--help is expected to sys.exit"
        assert excinfo.type == SystemExit
        expected = lightning_base.arg_to_scheduler_metavar
        assert expected in cs.out, "--help is expected to list the supported schedulers"

        # --lr_scheduler=non_existing_scheduler test
        unsupported_param = "non_existing_scheduler"
        args = {f"--lr_scheduler={unsupported_param}"}
        with pytest.raises(SystemExit) as excinfo:
            with CaptureStderr() as cs:
                args = parser.parse_args(args)
            assert False, "invalid argument is expected to sys.exit"
        assert excinfo.type == SystemExit
        expected = f"invalid choice: '{unsupported_param}'"
        assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

        # --lr_scheduler=existing_scheduler test
        supported_param = "cosine"
        args_d1 = args_d.copy()
        args_d1["lr_scheduler"] = supported_param
        args = argparse.Namespace(**args_d1)
        model = main(args)
        assert (
            getattr(model.hparams, "lr_scheduler") == supported_param
        ), f"lr_scheduler={supported_param} shouldn't fail"