"...resnet50_tensorflow.git" did not exist on "93e0022d69993032d211ea5e786cca92eda26dc6"
Unverified commit 9f7b2b24, authored by Stas Bekman, committed by GitHub

[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)

parent dc552b9b
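The pattern this commit applies throughout, shown as a minimal sketch (the class, test names and file contents below are illustrative, not taken from the diff): a module-level pytest function that manages its own temp dir with tempfile.mkdtemp becomes a method on a TestCasePlus subclass and requests an auto-deleted directory instead.

# Before (sketch): module-level pytest function; the temp dir is left behind
# unless the test cleans it up by hand.
import os
import tempfile


def test_writes_output_old_style():  # illustrative name
    output_dir = tempfile.mkdtemp(prefix="output_")
    open(os.path.join(output_dir, "metrics.json"), "w").close()


# After (sketch): unittest-style method; the directory returned by
# get_auto_remove_tmp_dir() is deleted automatically when the test finishes.
from transformers.testing_utils import TestCasePlus


class TestExample(TestCasePlus):  # illustrative class
    def test_writes_output(self):
        output_dir = self.get_auto_remove_tmp_dir()
        open(os.path.join(output_dir, "metrics.json"), "w").close()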
First file in the diff: the tests that exercise the train_mbart_cc25_enro.sh and distil_marian_no_teacher.sh example scripts. Before this commit the file imported tempfile, imported only slow from transformers.testing_utils, defined the tests as module-level functions, and created output directories with tempfile.mkdtemp(prefix="output_mbart") and tempfile.mkdtemp(prefix="marian_output"). After the commit it imports TestCasePlus, moves the tests into a TestAll(TestCasePlus) class, and asks the test case for auto-removed temp dirs instead. The new version of the hunks shown in the diff (unchanged context between hunks is omitted):

@@ -3,7 +3,6 @@
import argparse
import os
import sys
from pathlib import Path
from unittest.mock import patch

@@ -16,7 +15,7 @@
from distillation import BartSummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
from transformers import BartForConditionalGeneration, MarianMTModel
from transformers.testing_utils import TestCasePlus, slow
from utils import load_json

@@ -24,163 +23,164 @@
MODEL_NAME = MBART_TINY
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


class TestAll(TestCasePlus):
    @slow
    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
    def test_model_download(self):
        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
        MarianMTModel.from_pretrained(MARIAN_MODEL)

    @timeout_decorator.timeout(120)
    @slow
    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
    def test_train_mbart_cc25_enro_script(self):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 4,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "facebook/mbart-large-cc25": MODEL_NAME,
            # Download is 120MB in previous test.
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script
        bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()

        bash_script = bash_script.replace("--fp16 ", "")
        testargs = (
            ["finetune.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=3e-1",
                "--warmup_steps=0",
                "--val_check_interval=1.0",
                "--tokenizer_name=facebook/mbart-large-en-ro",
            ]
        )
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            args.do_predict = False
            # assert args.gpus == gpus THIS BREAKS for multigpu

            model = main(args)

            # Check metrics
            metrics = load_json(model.metrics_save_path)
            first_step_stats = metrics["val"][0]
            last_step_stats = metrics["val"][-1]
            assert (
                len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
            )  # +1 accounts for val_sanity_check

            assert last_step_stats["val_avg_gen_time"] >= 0.01

            assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
            assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
            assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

            # check lightning ckpt can be loaded and has a reasonable statedict
            contents = os.listdir(output_dir)
            ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
            full_path = os.path.join(args.output_dir, ckpt_path)
            ckpt = torch.load(full_path, map_location="cpu")
            expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
            assert expected_key in ckpt["state_dict"]
            assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32

            # TODO: turn on args.do_predict when PL bug fixed.
            if args.do_predict:
                contents = {os.path.basename(p) for p in contents}
                assert "test_generations.txt" in contents
                assert "test_results.txt" in contents
                # assert len(metrics["val"]) == desired_n_evals
                assert len(metrics["test"]) == 1

    @timeout_decorator.timeout(600)
    @slow
    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
    def test_opus_mt_distill_script(self):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 16,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "$m": "sshleifer/student_marian_en_ro_6_1",
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script
        bash_script = (
            Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
        )
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        bash_script = bash_script.replace("--fp16 ", " ")

        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
        bash_script = bash_script.replace("--fp16", "")
        epochs = 6
        testargs = (
            ["distillation.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=1e-3",
                f"--num_train_epochs={epochs}",
                "--warmup_steps=10",
                "--val_check_interval=1.0",
            ]
        )
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            args.do_predict = False
            # assert args.gpus == gpus THIS BREAKS for multigpu

            model = distill_main(args)

            # Check metrics
            metrics = load_json(model.metrics_save_path)
            first_step_stats = metrics["val"][0]
            last_step_stats = metrics["val"][-1]
            assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check

            assert last_step_stats["val_avg_gen_time"] >= 0.01

            assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
            assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
            assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

            # check lightning ckpt can be loaded and has a reasonable statedict
            contents = os.listdir(output_dir)
            ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
            full_path = os.path.join(args.output_dir, ckpt_path)
            ckpt = torch.load(full_path, map_location="cpu")
            expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
            assert expected_key in ckpt["state_dict"]
            assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32

            # TODO: turn on args.do_predict when PL bug fixed.
            if args.do_predict:
                contents = {os.path.basename(p) for p in contents}
                assert "test_generations.txt" in contents
                assert "test_results.txt" in contents
                # assert len(metrics["val"]) == desired_n_evals
                assert len(metrics["test"]) == 1
Second file in the diff: the Seq2SeqDataset, dataset-packing and sampler tests. Before this commit the file imported tempfile, imported only slow from transformers.testing_utils, parametrized via @pytest.mark.parametrize("tok_name", [...]) and @pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED]), called make_test_data_dir() with no arguments, used save_dir = Path(tempfile.mkdtemp(prefix="packed_")) in test_pack_dataset, and defined _get_dataset and the tests at module level. After the commit it adds from parameterized import parameterized, imports TestCasePlus, wraps everything in class TestAll(TestCasePlus), switches to @parameterized.expand, and passes tmp_dir=self.get_auto_remove_tmp_dir() to make_test_data_dir. The new version of the hunks shown in the diff (unchanged context between hunks is omitted):

import os
from pathlib import Path

import numpy as np

@@ -7,11 +6,12 @@
import pytest
from torch.utils.data import DataLoader

from pack_dataset import pack_data_dir
from parameterized import parameterized
from save_len_file import save_len_file
from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
from transformers import AutoTokenizer
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import TestCasePlus, slow
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset

@@ -19,202 +19,198 @@
BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM = "google/pegasus-xsum"


class TestAll(TestCasePlus):
    @parameterized.expand(
        [
            MBART_TINY,
            MARIAN_TINY,
            T5_TINY,
            BART_TINY,
            PEGASUS_XSUM,
        ],
    )
    @slow
    def test_seq2seq_dataset_truncation(self, tok_name):
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
        max_src_len = 4
        max_tgt_len = 8
        assert max_len_target > max_src_len  # Will be truncated
        assert max_len_source > max_src_len  # Will be truncated
        src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=max_src_len,
            max_target_length=max_tgt_len,  # ignored
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            assert isinstance(batch, dict)
            assert batch["attention_mask"].shape == batch["input_ids"].shape
            # show that articles were trimmed.
            assert batch["input_ids"].shape[1] == max_src_len
            # show that targets are the same len
            assert batch["labels"].shape[1] == max_tgt_len
            if tok_name != MBART_TINY:
                continue
            # check language codes in correct place
            batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
            assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
            assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
            assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
            assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
            break  # No need to test every batch

    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
    def test_legacy_dataset_truncation(self, tok):
        tokenizer = AutoTokenizer.from_pretrained(tok)
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
        trunc_target = 4
        train_dataset = LegacySeq2SeqDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=20,
            max_target_length=trunc_target,
        )
        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            assert batch["attention_mask"].shape == batch["input_ids"].shape
            # show that articles were trimmed.
            assert batch["input_ids"].shape[1] == max_len_source
            assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
            # show that targets were truncated
            assert batch["labels"].shape[1] == trunc_target  # Truncated
            assert max_len_target > trunc_target  # Truncated
            break  # No need to test every batch

    def test_pack_dataset(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
        orig_examples = tmp_dir.joinpath("train.source").open().readlines()
        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
        pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
        orig_paths = {x.name for x in tmp_dir.iterdir()}
        new_paths = {x.name for x in save_dir.iterdir()}
        packed_examples = save_dir.joinpath("train.source").open().readlines()
        # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
        # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
        assert len(packed_examples) < len(orig_examples)
        assert len(packed_examples) == 1
        assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
        assert orig_paths == new_paths

    @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
    def test_dynamic_batch_size(self):
        if not FAIRSEQ_AVAILABLE:
            return
        ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
        required_batch_size_multiple = 64
        batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
        batch_sizes = [len(x) for x in batch_sampler]
        assert len(set(batch_sizes)) > 1  # it's not dynamic batch size if every batch is the same length
        assert sum(batch_sizes) == len(ds)  # no dropped or added examples
        data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
        failures = []
        num_src_per_batch = []
        for batch in data_loader:
            src_shape = batch["input_ids"].shape
            bs = src_shape[0]
            assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
            num_src_tokens = np.product(batch["input_ids"].shape)
            num_src_per_batch.append(num_src_tokens)
            if num_src_tokens > (max_tokens * 1.1):
                failures.append(num_src_tokens)
        assert num_src_per_batch[0] == max(num_src_per_batch)
        if failures:
            raise AssertionError(f"too many tokens in {len(failures)} batches")

    def test_sortish_sampler_reduces_padding(self):
        ds, _, tokenizer = self._get_dataset(max_len=512)
        bs = 2
        sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)

        naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
        sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)

        pad = tokenizer.pad_token_id

        def count_pad_tokens(data_loader, k="input_ids"):
            return [batch[k].eq(pad).sum().item() for batch in data_loader]

        assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
        assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
        assert len(sortish_dl) == len(naive_dl)

    def _get_dataset(self, n_obs=1000, max_len=128):
        if os.getenv("USE_REAL_DATA", False):
            data_dir = "examples/seq2seq/wmt_en_ro"
            max_tokens = max_len * 2 * 64
            if not Path(data_dir).joinpath("train.len").exists():
                save_len_file(MARIAN_TINY, data_dir)
        else:
            data_dir = "examples/seq2seq/test_data/wmt_en_ro"
            max_tokens = max_len * 4
            save_len_file(MARIAN_TINY, data_dir)
        tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
        ds = Seq2SeqDataset(
            tokenizer,
            data_dir=data_dir,
            type_path="train",
            max_source_length=max_len,
            max_target_length=max_len,
            n_obs=n_obs,
        )
        return ds, max_tokens, tokenizer

    def test_distributed_sortish_sampler_splits_indices_between_procs(self):
        ds, max_tokens, tokenizer = self._get_dataset()
        ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
        ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
        assert ids1.intersection(ids2) == set()

    @parameterized.expand(
        [
            MBART_TINY,
            MARIAN_TINY,
            T5_TINY,
            BART_TINY,
            PEGASUS_XSUM,
        ],
    )
    def test_dataset_kwargs(self, tok_name):
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        if tok_name == MBART_TINY:
            train_dataset = Seq2SeqDataset(
                tokenizer,
                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
                type_path="train",
                max_source_length=4,
                max_target_length=8,
                src_lang="EN",
                tgt_lang="FR",
            )
            kwargs = train_dataset.dataset_kwargs
            assert "src_lang" in kwargs and "tgt_lang" in kwargs
        else:
            train_dataset = Seq2SeqDataset(
                tokenizer,
                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
                type_path="train",
                max_source_length=4,
                max_target_length=8,
            )
            kwargs = train_dataset.dataset_kwargs
            assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
            assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
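The switch from @pytest.mark.parametrize to @parameterized.expand is what lets the parametrization survive the move into a class: pytest's parametrize marker does not work on methods of unittest.TestCase subclasses, whereas parameterized.expand generates one plain test method per argument. A minimal sketch (the class name and arguments are illustrative):

import unittest

from parameterized import parameterized


class TestTokenizerNames(unittest.TestCase):  # illustrative class
    @parameterized.expand(["model-a", "model-b"])  # one generated test method per entry
    def test_loads(self, tok_name):
        self.assertIsInstance(tok_name, str)


if __name__ == "__main__":
    unittest.main()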
Third file in the diff: the finetune_trainer tests. Before this commit the file imported tempfile, imported only slow from transformers.testing_utils, defined test_finetune_trainer, test_finetune_trainer_slow and run_trainer at module level, and used output_dir = tempfile.mkdtemp(prefix="test_output"). After the commit it imports TestCasePlus, moves everything into class TestFinetuneTrainer(TestCasePlus), calls self.run_trainer(...), and uses self.get_auto_remove_tmp_dir(). The new version of the hunks shown in the diff (unchanged context between hunks is omitted):

import os
import sys
from unittest.mock import patch

from transformers.testing_utils import TestCasePlus, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

@@ -15,72 +14,71 @@
set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


class TestFinetuneTrainer(TestCasePlus):
    def test_finetune_trainer(self):
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        assert "eval_bleu" in first_step_stats

    @slow
    def test_finetune_trainer_slow(self):
        # There is a missing call to __init__process_group somewhere
        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.json" in contents

    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        argv = f"""
            --model_name_or_path {model_name}
            --data_dir {data_dir}
            --output_dir {output_dir}
            --overwrite_output_dir
            --n_train 8
            --n_val 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate 3e-4
            --warmup_steps 8
            --evaluate_during_training
            --predict_with_generate
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --eval_steps {str(eval_steps)}
            --sortish_sampler
            --label_smoothing 0.1
            --adafactor
            --task translation
            --tgt_lang ro_RO
            --src_lang en_XX
        """.split()
        # --eval_beams 2

        testargs = ["finetune_trainer.py"] + argv
        with patch.object(sys, "argv", testargs):
            main()

        return output_dir
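run_trainer assembles the trainer command line from a triple-quoted f-string and whitespace-splits it into argv. A small sketch of that idiom (the flag values are illustrative); because split() breaks on any whitespace, individual argument values must not contain spaces:

model_name = "sshleifer/student_marian_en_ro_6_1"
output_dir = "/tmp/test_output"  # illustrative path
argv = f"""
    --model_name_or_path {model_name}
    --output_dir {output_dir}
    --num_train_epochs 1
    """.split()
# split() collapses the newlines and indentation into a flat argv list
assert argv[:2] == ["--model_name_or_path", model_name]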
(The diff for one additional file in this commit is collapsed and not shown here.)