Unverified commit 61b7ba93, authored by Sam Shleifer, committed by GitHub

Marian distill scripts + integration test (#6799)

parent 02d09c8f
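This commit adds two bash scripts for distilling Marian en-ro models, plus integration-test coverage. The first script distills against a live teacher (Helsinki-NLP/opus-mt-en-ro); the second, which the new test opens as examples/seq2seq/distil_marian_no_teacher.sh, fine-tunes an already-shrunk student with --no_teacher.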
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
# export MAX_LEN=128
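# The caller must export ENRO_DIR (path to the WMT en-ro data), MAX_LEN, and BS.
# Note: --model_name_or_path is a placeholder here; with --teacher set, the student
# is built from the teacher's layers (6 encoder / 3 decoder below), so the value is unused.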
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.25 \
--teacher Helsinki-NLP/opus-mt-en-ro --data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--student_decoder_layers 3 --student_encoder_layers 6 \
--freeze_encoder --freeze_embeds \
--model_name_or_path IGNORED \
--alpha_hid=3. \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
--warmup_steps 500 --sortish_sampler --logger_name wandb \
--gpus 1 --fp16_opt_level O1 --task translation \
"$@"
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
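# The caller must export ENRO_DIR, MAX_LEN, BS, and m (the student model to
# fine-tune, e.g. sshleifer/student_marian_en_ro_6_1, as substituted by the test).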
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 --no_teacher \
--val_check_interval 0.25 \
--data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--freeze_encoder --freeze_embeds \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name $m --model_name_or_path $m \
--warmup_steps 500 --sortish_sampler --logger_name wandb \
--gpus 1 --fp16_opt_level=O1 --task translation \
"$@"
@@ -10,9 +10,10 @@ import pytorch_lightning as pl
 import timeout_decorator
 import torch
 
-from transformers import BartForConditionalGeneration
+from transformers import BartForConditionalGeneration, MarianMTModel
 from transformers.testing_utils import slow
 
+from .distillation import BartSummarizationDistiller, distill_main
 from .finetune import SummarizationModule, main
 from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from .utils import load_json
@@ -20,6 +21,7 @@ from .utils import load_json
 
 MODEL_NAME = MBART_TINY
 # TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
+MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 
 
 @slow
@@ -27,6 +29,7 @@ MODEL_NAME = MBART_TINY
 def test_model_download():
     """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
     BartForConditionalGeneration.from_pretrained(MODEL_NAME)
+    MarianMTModel.from_pretrained(MARIAN_MODEL)
 
 
 @timeout_decorator.timeout(120)
@@ -35,34 +38,30 @@ def test_model_download():
 def test_train_mbart_cc25_enro_script():
     data_dir = "examples/seq2seq/test_data/wmt_en_ro"
     env_vars_to_replace = {
-        "$MAX_LEN": 200,
+        "--fp16_opt_level=O1": "",
+        "$MAX_LEN": 128,
         "$BS": 4,
         "$GAS": 1,
         "$ENRO_DIR": data_dir,
         "facebook/mbart-large-cc25": MODEL_NAME,
-        # 1 encoder and 1 decoder layer from finetuned mbart en-ro. Should be able to start >0 and improve quickly.
-        # Download is 600MB in previous test.
+        # Download is 120MB in previous test.
        "val_check_interval=0.25": "val_check_interval=1.0",
     }
 
     # Clean up bash script
     bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
-    bash_script = bash_script.replace("\\\n", "").strip().replace("$@", "")
+    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
     for k, v in env_vars_to_replace.items():
         bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="output")
-    if CUDA_AVAILABLE:
-        gpus = 1  # torch.cuda.device_count()
-    else:
-        gpus = 0
-    bash_script = bash_script.replace("--fp16", "")
+    output_dir = tempfile.mkdtemp(prefix="output_mbart")
+    bash_script = bash_script.replace("--fp16 ", "")
     testargs = (
         ["finetune.py"]
         + bash_script.split()
         + [
             f"--output_dir={output_dir}",
-            f"--gpus={gpus}",
+            "--gpus=1",
             "--learning_rate=3e-1",
             "--warmup_steps=0",
             "--val_check_interval=1.0",
@@ -82,7 +81,86 @@ def test_train_mbart_cc25_enro_script():
     metrics = load_json(model.metrics_save_path)
     first_step_stats = metrics["val"][0]
     last_step_stats = metrics["val"][-1]
-    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check
+    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
+
+    assert last_step_stats["val_avg_gen_time"] >= 0.01
+
+    assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
+    assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
+    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+
+    # check lightning ckpt can be loaded and has a reasonable statedict
+    contents = os.listdir(output_dir)
+    ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
+    full_path = os.path.join(args.output_dir, ckpt_path)
+    ckpt = torch.load(full_path, map_location="cpu")
+    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
+    assert expected_key in ckpt["state_dict"]
+    assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
+
+    # TODO(SS): turn on args.do_predict when PL bug fixed.
+    if args.do_predict:
+        contents = {os.path.basename(p) for p in contents}
+        assert "test_generations.txt" in contents
+        assert "test_results.txt" in contents
+        # assert len(metrics["val"]) == desired_n_evals
+        assert len(metrics["test"]) == 1
+
+
+@timeout_decorator.timeout(600)
+@slow
+@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+def test_opus_mt_distill_script():
+    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+    env_vars_to_replace = {
+        "--fp16_opt_level=O1": "",
+        "$MAX_LEN": 128,
+        "$BS": 16,
+        "$GAS": 1,
+        "$ENRO_DIR": data_dir,
+        "$m": "sshleifer/student_marian_en_ro_6_1",
+        "val_check_interval=0.25": "val_check_interval=1.0",
+    }
+
+    # Clean up bash script
+    bash_script = (
+        Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
+    )
+    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
+    bash_script = bash_script.replace("--fp16 ", " ")
+    for k, v in env_vars_to_replace.items():
+        bash_script = bash_script.replace(k, str(v))
+    output_dir = tempfile.mkdtemp(prefix="marian_output")
+    bash_script = bash_script.replace("--fp16", "")
+    epochs = 6
+    testargs = (
+        ["distillation.py"]
+        + bash_script.split()
+        + [
+            f"--output_dir={output_dir}",
+            "--gpus=1",
+            "--learning_rate=1e-3",
+            f"--num_train_epochs={epochs}",
+            "--warmup_steps=10",
+            "--val_check_interval=1.0",
+        ]
+    )
+
+    with patch.object(sys, "argv", testargs):
+        parser = argparse.ArgumentParser()
+        parser = pl.Trainer.add_argparse_args(parser)
+        parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
+        args = parser.parse_args()
+        args.do_predict = False
+        # assert args.gpus == gpus THIS BREAKS for multigpu
+        model = distill_main(args)
+
+    # Check metrics
+    metrics = load_json(model.metrics_save_path)
+    first_step_stats = metrics["val"][0]
+    last_step_stats = metrics["val"][-1]
+    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
+
     assert last_step_stats["val_avg_gen_time"] >= 0.01
...
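To exercise just the new Marian test locally (the test module diffed above is, presumably, examples/seq2seq/test_bash_script.py in the transformers repo; @slow tests are skipped unless the RUN_SLOW environment variable is set):

RUN_SLOW=1 pytest examples/seq2seq/test_bash_script.py -k opus_mt_distill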
@@ -114,7 +114,9 @@ class ExamplesTests(TestCasePlus):
             --max_seq_length=128
             """.split()
         if torch.cuda.is_available():
-            testargs += ["--fp16", "--gpus=1"]
+            testargs += ["--gpus=1"]
+        if is_cuda_and_apex_avaliable():
+            testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
             result = run_pl_glue.main()
...