Unverified Commit 783d7d26 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Reorganize examples (#9010)



* Reorganize example folder

* Continue reorganization

* Change requirements for tests

* Final cleanup

* Finish regroup with tests all passing

* Copyright

* Requirements and readme

* Make a full link for the documentation

* Address review comments

* Apply suggestions from code review
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* Add symlink

* Reorg again

* Apply suggestions from code review
Co-authored-by: default avatarThomas Wolf <thomwolf@users.noreply.github.com>

* Adapt title

* Update to new strucutre

* Remove test

* Update READMEs
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
Co-authored-by: default avatarThomas Wolf <thomwolf@users.noreply.github.com>
parent 86896de0
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
import itertools import itertools
......
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fire import fire
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
......
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fire import fire
......
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re import re
from filelock import FileLock from filelock import FileLock
......
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
......
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
......
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
......
seq2seq/test_data
\ No newline at end of file
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from pathlib import Path from pathlib import Path
...@@ -8,7 +22,6 @@ from torch.utils.data import DataLoader ...@@ -8,7 +22,6 @@ from torch.utils.data import DataLoader
from pack_dataset import pack_data_dir from pack_dataset import pack_data_dir
from parameterized import parameterized from parameterized import parameterized
from save_len_file import save_len_file from save_len_file import save_len_file
from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
from transformers import AutoTokenizer from transformers import AutoTokenizer
from transformers.models.bart.modeling_bart import shift_tokens_right from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, slow from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, slow
...@@ -17,6 +30,24 @@ from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDat ...@@ -17,6 +30,24 @@ from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDat
BERT_BASE_CASED = "bert-base-cased" BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM = "google/pegasus-xsum" PEGASUS_XSUM = "google/pegasus-xsum"
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
def _dump_articles(path: Path, articles: list):
content = "\n".join(articles)
Path(path).open("w").writelines(content)
def make_test_data_dir(tmp_dir):
for split in ["train", "val", "test"]:
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
return tmp_dir
class TestAll(TestCasePlus): class TestAll(TestCasePlus):
......
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import sys import sys
from unittest.mock import patch from unittest.mock import patch
...@@ -17,11 +31,11 @@ from transformers.trainer_utils import set_seed ...@@ -17,11 +31,11 @@ from transformers.trainer_utils import set_seed
from .finetune_trainer import Seq2SeqTrainingArguments, main from .finetune_trainer import Seq2SeqTrainingArguments, main
from .seq2seq_trainer import Seq2SeqTrainer from .seq2seq_trainer import Seq2SeqTrainer
from .test_seq2seq_examples import MBART_TINY
set_seed(42) set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1" MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"
class TestFinetuneTrainer(TestCasePlus): class TestFinetuneTrainer(TestCasePlus):
......
import argparse # Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging import logging
import os import os
import sys import sys
import tempfile
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest
import pytorch_lightning as pl
import torch
import lightning_base
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
from distillation import distill_main
from finetune import SummarizationModule, main
from parameterized import parameterized from parameterized import parameterized
from run_eval import generate_summaries_or_translations, run_generate from run_eval import run_generate
from run_eval_search import run_search from run_eval_search import run_search
from transformers import AutoConfig, AutoModelForSeq2SeqLM from transformers.testing_utils import CaptureStdout, TestCasePlus, slow
from transformers.hf_api import HfApi from utils import ROUGE_KEYS
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger() logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"max_tokens_per_batch": None,
"supervise_forward": True,
"normalize_hidden": True,
"label_smoothing": 0.2,
"eval_max_gen_length": None,
"eval_beams": 1,
"val_metric": "loss",
"save_top_k": 1,
"adafactor": True,
"early_stopping_patience": 2,
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
"task": "summarization",
"num_workers": 2,
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
"tgt_suffix": "",
"resume_from_checkpoint": None,
"sortish_sampler": True,
"student_decoder_layers": 1,
"val_check_interval": 1.0,
"output_dir": "",
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
"no_teacher": False,
"fp16_opt_level": "O1",
"gpus": 1 if CUDA_AVAILABLE else 0,
"n_tpu_cores": 0,
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"accumulate_grad_batches": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
"model_name_or_path": "sshleifer/bart-tiny-random",
"config_name": "",
"tokenizer_name": "facebook/bart-large",
"do_lower_case": False,
"learning_rate": 0.3,
"lr_scheduler": "linear",
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"max_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
"max_target_length": 12,
"val_max_target_length": 12,
"test_max_target_length": 12,
"fast_dev_run": False,
"no_cache": False,
"n_train": -1,
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"freeze_encoder": False,
"auto_scale_batch_size": False,
"overwrite_output_dir": False,
"student": None,
}
def _dump_articles(path: Path, articles: list): def _dump_articles(path: Path, articles: list):
...@@ -98,187 +34,15 @@ def _dump_articles(path: Path, articles: list): ...@@ -98,187 +34,15 @@ def _dump_articles(path: Path, articles: list):
Path(path).open("w").writelines(content) Path(path).open("w").writelines(content)
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random" T5_TINY = "patrickvonplaten/t5-tiny-random"
T5_TINIER = "sshleifer/t5-tinier-random"
BART_TINY = "sshleifer/bart-tiny-random" BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart" MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
FSMT_TINY = "stas/tiny-wmt19-en-de"
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
def make_test_data_dir(tmp_dir):
for split in ["train", "val", "test"]:
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
return tmp_dir
class TestSummarizationDistiller(TestCasePlus):
@classmethod
def setUpClass(cls):
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls
@slow
@require_torch_gpu
def test_hub_configs(self):
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
model_list = HfApi().model_list()
org = "sshleifer"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
failures = []
for m in model_ids:
if m in allowed_to_be_broken:
continue
try:
AutoConfig.from_pretrained(m)
except Exception:
failures.append(m)
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
def test_distill_no_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
self._test_distiller_cli(updates)
def test_distill_checkpointing_with_teacher(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
max_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
)
model = self._test_distiller_cli(updates, check_contents=False)
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
self.assertEqual(len(transformer_ckpts), 2)
examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
out_path = tempfile.mktemp() # XXX: not being cleaned up
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
self.assertTrue(Path(out_path).exists())
out_path_new = self.get_auto_remove_tmp_dir()
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
def test_loss_fn(self):
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
lm_labels = target_ids[:, 1:].clone() # why clone?
model_computed_loss = model(
input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=lm_labels, use_cache=False
).loss
logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
smoothed_loss, nll_loss = label_smoothed_nll_loss(
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
)
with self.assertRaises(AssertionError):
# TODO: understand why this breaks
self.assertEqual(nll_loss, model_computed_loss)
def test_distill_mbart(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
task="translation",
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
tokenizer_name=MBART_TINY,
teacher=MBART_TINY,
src_lang="en_XX",
tgt_lang="ro_RO",
)
model = self._test_distiller_cli(updates, check_contents=False)
assert model.model.config.model_type == "mbart"
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
all_files = list(Path(model.output_dir).glob("best_tfmr/*"))
assert len(all_files) > 2
self.assertEqual(len(transformer_ckpts), 2)
def test_distill_t5(self):
updates = dict(
student_encoder_layers=1,
student_decoder_layers=1,
alpha_hid=2.0,
teacher=T5_TINY,
model_name_or_path=T5_TINY,
tokenizer_name=T5_TINY,
)
self._test_distiller_cli(updates)
def test_distill_different_base_models(self):
updates = dict(
teacher=T5_TINY,
student=T5_TINIER,
model_name_or_path=T5_TINIER,
tokenizer_name=T5_TINIER,
)
self._test_distiller_cli(updates)
def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
label_smoothing=0.0,
early_stopping_patience=-1,
train_batch_size=1,
eval_batch_size=2,
max_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
output_dir = self.get_auto_remove_tmp_dir()
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
model = distill_main(argparse.Namespace(**args_d))
if not check_contents:
return model
contents = os.listdir(output_dir)
contents = {os.path.basename(p) for p in contents}
ckpt_files = [p for p in contents if p.endswith("ckpt")]
assert len(ckpt_files) > 0
self.assertIn("test_generations.txt", contents)
self.assertIn("test_results.txt", contents)
metrics = load_json(model.metrics_save_path)
last_step_stats = metrics["val"][-1]
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["test"]), 1)
return model
class TestTheRest(TestCasePlus): class TestTheRest(TestCasePlus):
def run_eval_tester(self, model): def run_eval_tester(self, model):
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source" input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
...@@ -365,167 +129,3 @@ class TestTheRest(TestCasePlus): ...@@ -365,167 +129,3 @@ class TestTheRest(TestCasePlus):
assert w not in cs.out assert w not in cs.out
assert Path(output_file_name).exists() assert Path(output_file_name).exists()
os.remove(Path(output_file_name)) os.remove(Path(output_file_name))
@parameterized.expand(
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
)
def test_finetune(self, model):
args_d: dict = CHEAP_ARGS.copy()
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
output_dir = self.get_auto_remove_tmp_dir()
args_d.update(
data_dir=tmp_dir,
model_name_or_path=model,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
output_dir=output_dir,
do_predict=True,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
module = main(args)
input_embeds = module.model.get_input_embeddings()
assert not input_embeds.weight.requires_grad
if model == T5_TINY:
lm_head = module.model.lm_head
assert not lm_head.weight.requires_grad
assert (lm_head.weight == input_embeds.weight).all().item()
elif model == FSMT_TINY:
fsmt = module.model.model
embed_pos = fsmt.decoder.embed_positions
assert not embed_pos.weight.requires_grad
assert not fsmt.decoder.embed_tokens.weight.requires_grad
# check that embeds are not the same
assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
else:
bart = module.model.model
embed_pos = bart.decoder.embed_positions
assert not embed_pos.weight.requires_grad
assert not bart.shared.weight.requires_grad
# check that embeds are the same
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
assert bart.decoder.embed_tokens == bart.shared
example_batch = load_json(module.output_dir / "text_batch.json")
assert isinstance(example_batch, dict)
assert len(example_batch) >= 4
def test_finetune_extra_model_args(self):
args_d: dict = CHEAP_ARGS.copy()
task = "summarization"
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
args_d.update(
data_dir=tmp_dir,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
do_predict=False,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
# test models whose config includes the extra_model_args
model = BART_TINY
output_dir = self.get_auto_remove_tmp_dir()
args_d1 = args_d.copy()
args_d1.update(
model_name_or_path=model,
output_dir=output_dir,
)
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
args_d1[p] = 0.5
args = argparse.Namespace(**args_d1)
model = main(args)
for p in extra_model_params:
assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
# test models whose config doesn't include the extra_model_args
model = T5_TINY
output_dir = self.get_auto_remove_tmp_dir()
args_d2 = args_d.copy()
args_d2.update(
model_name_or_path=model,
output_dir=output_dir,
)
unsupported_param = "encoder_layerdrop"
args_d2[unsupported_param] = 0.5
args = argparse.Namespace(**args_d2)
with pytest.raises(Exception) as excinfo:
model = main(args)
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
def test_finetune_lr_schedulers(self):
args_d: dict = CHEAP_ARGS.copy()
task = "summarization"
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
model = BART_TINY
output_dir = self.get_auto_remove_tmp_dir()
args_d.update(
data_dir=tmp_dir,
model_name_or_path=model,
output_dir=output_dir,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
do_predict=False,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
# emulate finetune.py
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = {"--help": True}
# --help test
with pytest.raises(SystemExit) as excinfo:
with CaptureStdout() as cs:
args = parser.parse_args(args)
assert False, "--help is expected to sys.exit"
assert excinfo.type == SystemExit
expected = lightning_base.arg_to_scheduler_metavar
assert expected in cs.out, "--help is expected to list the supported schedulers"
# --lr_scheduler=non_existing_scheduler test
unsupported_param = "non_existing_scheduler"
args = {f"--lr_scheduler={unsupported_param}"}
with pytest.raises(SystemExit) as excinfo:
with CaptureStderr() as cs:
args = parser.parse_args(args)
assert False, "invalid argument is expected to sys.exit"
assert excinfo.type == SystemExit
expected = f"invalid choice: '{unsupported_param}'"
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
# --lr_scheduler=existing_scheduler test
supported_param = "cosine"
args_d1 = args_d.copy()
args_d1["lr_scheduler"] = supported_param
args = argparse.Namespace(**args_d1)
model = main(args)
assert (
getattr(model.hparams, "lr_scheduler") == supported_param
), f"lr_scheduler={supported_param} shouldn't fail"
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# as due to their complexity multi-gpu tests could impact other tests, and to aid debug we have those in a separate module. # as due to their complexity multi-gpu tests could impact other tests, and to aid debug we have those in a separate module.
import os import os
import sys import sys
from transformers.testing_utils import ( from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, require_torch_gpu, slow
TestCasePlus,
execute_subprocess_async,
get_gpu_count,
require_torch_gpu,
require_torch_multi_gpu,
slow,
)
from .test_seq2seq_examples import CHEAP_ARGS, make_test_data_dir
from .utils import load_json from .utils import load_json
...@@ -21,73 +27,6 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus): ...@@ -21,73 +27,6 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
def setUpClass(cls): def setUpClass(cls):
return cls return cls
@require_torch_multi_gpu
def test_multi_gpu(self):
updates = dict(
no_teacher=True,
freeze_encoder=True,
gpus=2,
overwrite_output_dir=True,
sortish_sampler=True,
)
self._test_distiller_cli_fork(updates, check_contents=False)
def _test_distiller_cli_fork(self, updates, check_contents=True):
default_updates = dict(
label_smoothing=0.0,
early_stopping_patience=-1,
train_batch_size=1,
eval_batch_size=2,
max_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
output_dir = self.get_auto_remove_tmp_dir()
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
def convert(k, v):
if k in ["tgt_suffix", "server_ip", "server_port", "out", "n_tpu_cores"]:
return ""
if v is False or v is None:
return ""
if v is True: # or len(str(v))==0:
return f"--{k}"
return f"--{k}={v}"
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
cmd = [sys.executable, f"{self.test_file_dir}/distillation.py"] + cli_args
execute_subprocess_async(cmd, env=self.get_env())
contents = os.listdir(output_dir)
contents = {os.path.basename(p) for p in contents}
ckpt_files = [p for p in contents if p.endswith("ckpt")]
assert len(ckpt_files) > 0
self.assertIn("test_generations.txt", contents)
self.assertIn("test_results.txt", contents)
# get the following from the module, (we don't have access to `model` here)
metrics_save_path = os.path.join(output_dir, "metrics.json")
val_metric = "rouge2"
metrics = load_json(metrics_save_path)
# {'test': [{'test_avg_loss': 10.63731575012207, 'test_avg_rouge1': 0.0, 'test_avg_rouge2': 0.0, 'test_avg_rougeL': 0.0, 'test_avg_gen_time': 0.1822289228439331, 'test_avg_gen_len': 142.0, 'step_count': 1}]}
print(metrics)
last_step_stats = metrics["val"][-1]
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
self.assertEqual(len(metrics["test"]), 1)
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
@slow @slow
@require_torch_gpu @require_torch_gpu
def test_distributed_eval(self): def test_distributed_eval(self):
......
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import tempfile import tempfile
import unittest import unittest
......
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export WANDB_PROJECT=distil-marian export WANDB_PROJECT=distil-marian
export BS=64 export BS=64
export GAS=1 export GAS=1
......
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export WANDB_PROJECT=distil-marian export WANDB_PROJECT=distil-marian
export BS=64 export BS=64
export m=sshleifer/student_marian_en_ro_6_3 export m=sshleifer/student_marian_en_ro_6_3
......
#!/usr/bin/env bash # Copyright 2020 The HuggingFace Team. All rights reserved.
export PYTHONPATH="../":"${PYTHONPATH}" #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export WANDB_PROJECT=distilbart-trainer
export BS=32 export BS=32
export GAS=1 export m=sshleifer/student_cnn_12_6
export tok=facebook/bart-large
export MAX_TGT_LEN=142
python finetune.py \ python finetune_trainer.py \
--model_name_or_path $m --tokenizer_name $tok \
--data_dir cnn_dm \
--output_dir distilbart-cnn-12-6 --overwrite_output_dir \
--learning_rate=3e-5 \ --learning_rate=3e-5 \
--warmup_steps 500 --sortish_sampler \
--fp16 \ --fp16 \
--gpus 1 \
--do_train \
--do_predict \
--val_check_interval 0.25 \
--n_val 500 \ --n_val 500 \
--num_train_epochs 2 \ --gradient_accumulation_steps=1 \
--freeze_encoder --freeze_embeds --data_dir cnn_dm \ --per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
--max_target_length 142 --val_max_target_length=142 \ --freeze_encoder --freeze_embeds \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \ --num_train_epochs=2 \
--model_name_or_path sshleifer/student_cnn_12_6 \ --save_steps 3000 --eval_steps 3000 \
--tokenizer_name facebook/bart-large \ --logging_first_step \
--warmup_steps 500 \ --max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
--output_dir distilbart-cnn-12-6 \ --do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --sortish_sampler \
"$@" "$@"
#!/usr/bin/env bash # Copyright 2020 The HuggingFace Team. All rights reserved.
export PYTHONPATH="../":"${PYTHONPATH}" #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python finetune.py \ python finetune_trainer.py \
--model_name_or_path=facebook/mbart-large-cc25 \
--data_dir $ENRO_DIR \
--output_dir mbart_cc25_enro --overwrite_output_dir \
--learning_rate=3e-5 \ --learning_rate=3e-5 \
--warmup_steps 500 \
--fp16 \ --fp16 \
--do_train \ --label_smoothing 0.1 \
--val_check_interval=0.25 \
--adam_eps 1e-06 \ --adam_eps 1e-06 \
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \ --src_lang en_XX --tgt_lang ro_RO \
--data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--train_batch_size=$BS --eval_batch_size=$BS \
--task translation \
--warmup_steps 500 \
--freeze_embeds \ --freeze_embeds \
--model_name_or_path=facebook/mbart-large-cc25 \ --per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
--max_source_length 128 --max_target_length 128 \
--val_max_target_length 128 --test_max_target_length 128 \
--sortish_sampler \
--num_train_epochs 6 \
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
--do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --logging_first_step \
--task translation \
"$@" "$@"
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools import itertools
import json import json
import linecache import linecache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment