"examples/vscode:/vscode.git/clone" did not exist on "466115b2797b6e01cce5c979bd8e20b3a1e04746"
Unverified commit f5c2a122 authored by Sam Shleifer, committed by GitHub

Upgrade examples to pl=0.8.1 (#5146)

parent 06b60c8b
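
The changes below track the API differences between pytorch-lightning 0.7.6 and 0.8.1: validation_end/test_end are renamed to validation_epoch_end/test_epoch_end, the manual is_logger()/proc_rank guards are replaced by the rank_zero_only decorator and rank_zero_info helper, and the custom optimizer_step/get_tqdm_dict overrides are dropped. As a rough orientation (not part of the commit; the class names are invented for illustration), a minimal module and callback written against the 0.8.x API look like this:

# Illustration only -- not part of this commit. A minimal sketch of the
# pytorch-lightning 0.8.x conventions the diff below adopts; class names
# are hypothetical.
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only


class ExampleModule(pl.LightningModule):
    def validation_epoch_end(self, outputs):
        # replaces the 0.7.x validation_end hook
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        return {"val_loss": avg_loss, "log": {"val_loss": avg_loss}}

    def test_epoch_end(self, outputs):
        # replaces the 0.7.x test_end hook
        return self.validation_epoch_end(outputs)


class ExampleLoggingCallback(pl.Callback):
    @rank_zero_only  # runs on the main process only, replacing is_logger() checks
    def on_validation_end(self, trainer, pl_module):
        rank_zero_info("***** Validation results *****")
        for key in sorted(trainer.callback_metrics):
            rank_zero_info("{} = {}".format(key, trainer.callback_metrics[key]))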
@@ -8,6 +8,7 @@ from typing import Any, Dict
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 
 from transformers import (
     AdamW,
@@ -60,10 +61,9 @@ class BaseTransformer(pl.LightningModule):
         model=None,
         **config_kwargs
     ):
-        "Initialize a model."
+        """Initialize a model, tokenizer and config."""
         super().__init__()
-        self.hparams = hparams
+        self.hparams = hparams  # TODO: move to self.save_hyperparameters()
         self.step_count = 0
         self.tfmr_ckpts = {}
         self.output_dir = Path(self.hparams.output_dir)
@@ -84,8 +84,8 @@ class BaseTransformer(pl.LightningModule):
             )
         else:
             self.tokenizer: PreTrainedTokenizer = tokenizer
-        if model is None:
         self.model_type = MODEL_MODES[mode]
+        if model is None:
             self.model = self.model_type.from_pretrained(
                 self.hparams.model_name_or_path,
                 from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
@@ -93,18 +93,13 @@ class BaseTransformer(pl.LightningModule):
                 cache_dir=cache_dir,
             )
         else:
-            self.model_type = None
             self.model = model
 
     def load_hf_checkpoint(self, *args, **kwargs):
         self.model = self.model_type.from_pretrained(*args, **kwargs)
 
-    def is_logger(self):
-        return self.trainer.proc_rank <= 0
-
     def configure_optimizers(self):
         "Prepare optimizer and schedule (linear warmup and decay)"
         model = self.model
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
@@ -121,23 +116,10 @@ class BaseTransformer(pl.LightningModule):
         self.opt = optimizer
         return [optimizer]
 
-    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
-        if self.trainer.use_tpu:
-            xm.optimizer_step(optimizer)
-        else:
-            optimizer.step()
-        optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-    def get_tqdm_dict(self):
-        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
-        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
-        return tqdm_dict
-
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)
 
-    def test_end(self, outputs):
+    def test_epoch_end(self, outputs):
         return self.validation_end(outputs)
 
     def train_dataloader(self):
@@ -208,6 +190,7 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
         parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
+        parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
         parser.add_argument(
             "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
         )
@@ -217,21 +200,19 @@ class BaseTransformer(pl.LightningModule):
 class LoggingCallback(pl.Callback):
+    @rank_zero_only
     def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
-        logger.info("***** Validation results *****")
-        if pl_module.is_logger():
+        rank_zero_info("***** Validation results *****")
         metrics = trainer.callback_metrics
         # Log results
         for key in sorted(metrics):
             if key not in ["log", "progress_bar"]:
-                logger.info("{} = {}\n".format(key, str(metrics[key])))
+                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
 
+    @rank_zero_only
     def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         logger.info("***** Test results *****")
-        if pl_module.is_logger():
         metrics = trainer.callback_metrics
         # Log and save results to file
         output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
         with open(output_test_results_file, "w") as writer:
...
@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==0.7.6
+pytorch-lightning==0.8.1
 matplotlib
 git-python==1.0.3
 faiss
...
@@ -19,12 +19,11 @@ logger = logging.getLogger(__name__)
 class Seq2SeqLoggingCallback(pl.Callback):
+    @rank_zero_only
     def _write_logs(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
     ) -> None:
         logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
-        if not pl_module.is_logger():
-            return
         metrics = trainer.callback_metrics
         trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
         # Log results
...
@@ -271,6 +271,7 @@ class SummarizationDistiller(SummarizationModule):
 class T5SummarizationDistiller(SummarizationDistiller):
     def pre_init(self, hparams):
+        raise NotImplementedError("T5 Distillation does not work yet")
         teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
         n_layer = hparams.student_decoder_layers
         assert n_layer == hparams.student_encoder_layers  # TODO(SS): relax this
...
@@ -85,7 +85,7 @@ class SummarizationModule(BaseTransformer):
         if self.hparams.freeze_encoder:
             freeze_params(self.model.model.encoder)  # TODO: this will break for t5
         self.hparams.git_sha = get_git_info()["repo_sha"]
-        self.num_workers = 4 if self.hparams.gpus <= 1 else None  # passing num_workers breaks lightning for multigpu
+        self.num_workers = hparams.num_workers
 
     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -126,7 +126,7 @@ class SummarizationModule(BaseTransformer):
     def validation_step(self, batch, batch_idx) -> Dict:
         return self._generative_step(batch)
 
-    def validation_end(self, outputs, prefix="val") -> Dict:
+    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
@@ -144,14 +144,12 @@ class SummarizationModule(BaseTransformer):
         self.metrics[prefix].append(metrics)
         pickle_save(self.metrics, self.metrics_save_path)
 
-    def _generative_step(self, batch):
+    def _generative_step(self, batch: dict) -> dict:
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
-        # TODO(SS): task specific params
         t0 = time.time()
         generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
-        gen_time = time.time() - t0
+        gen_time = time.time() - t0 / source_ids.shape[0]
         preds = self.ids_to_clean_text(generated_ids)
         target = self.ids_to_clean_text(y)
         loss_tensors = self._step(batch)
@@ -164,24 +162,8 @@ class SummarizationModule(BaseTransformer):
     def test_step(self, batch, batch_idx):
         return self._generative_step(batch)
 
-    def test_end(self, outputs):
-        return self.validation_end(outputs, prefix="test")
-
     def test_epoch_end(self, outputs):
-        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
-        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
-        # write predictions and targets for later rouge evaluation.
-        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
-            for output_batch in outputs:
-                p_writer.writelines(s + "\n" for s in output_batch["preds"])
-                t_writer.writelines(s + "\n" for s in output_batch["target"])
-            p_writer.close()
-            t_writer.close()
-        return self.test_end(outputs)
-
-    def validation_epoch_end(self, outputs):
-        self.validation_end(outputs, "val")
+        return self.validation_epoch_end(outputs, prefix="test")
 
     def get_dataset(self, type_path) -> SummarizationDataset:
         n_obs = self.n_obs[type_path]
@@ -310,6 +292,7 @@ def main(args, model=None) -> SummarizationModule:
         logger=logger,
         # TODO: early stopping callback seems messed up
     )
+    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
     if not args.do_predict:
         return model
...
@@ -7,6 +7,5 @@ python distillation.py \
     --learning_rate=3e-4 \
     --do_train \
     --do_predict \
-    --fp16 \
     --val_check_interval 0.1 \
     $@
@@ -26,6 +26,7 @@ def generate_summaries(
     examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
 ) -> None:
     fout = Path(out_file).open("w", encoding="utf-8")
+    model_name = str(model_name)
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
     if fp16:
         model = model.half()
...
@@ -24,6 +24,7 @@ logger = logging.getLogger()
 FP16_EVER = False
 CHEAP_ARGS = {
     "logger": "default",
+    "num_workers": 2,
     "alpha_hid": 0,
     "freeze_embeds": True,
     "enc_only": False,
@@ -79,7 +80,8 @@ def _dump_articles(path: Path, articles: list):
         f.write("\n".join(articles))
 
-BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
+MSG = "T5 is broken at the moment"
+T5_TINY = "patrickvonplaten/t5-tiny-random"
 
 def make_test_data_dir():
@@ -92,7 +94,6 @@ def make_test_data_dir():
     return tmp_dir
 
-@unittest.skip("These wont' pass until hidden_states kwarg is merged.")
 class TestSummarizationDistiller(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -108,47 +109,22 @@ class TestSummarizationDistiller(unittest.TestCase):
             freeze_encoder=True,
             gpus=2,
             sortish_sampler=False,
-        )
-        self._bart_distiller_cli(updates)
-
-    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
-    def test_bdc_fp16(self):
-        updates = dict(
-            student_encoder_layers=2,
-            student_decoder_layers=1,
-            alpha_hid=3.0,
-            freeze_encoder=True,
-            gpus=1,
-            fp16=FP16_EVER,
             fp16_opt_level="O1",
-        )
-        self._bart_distiller_cli(updates)
-
-    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
-    def test_bdc_t5_eval_fp16(self):
-        updates = dict(
             fp16=FP16_EVER,
-            gpus=1,
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            do_train=False,
-            do_predict=True,
-            tokenizer_name=None,
-            no_teacher=True,
         )
-        self._bart_distiller_cli(updates, check_contents=False)
+        self._bart_distiller_cli(updates)
 
-    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
-    def test_bdc_t5_train_fp16(self):
+    def test_bdc_t5_train(self):
         updates = dict(
             fp16=FP16_EVER,
-            gpus=1,
+            gpus=1 if torch.cuda.is_available() else 0,
             model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            model_name_or_path=T5_TINY,
             do_train=True,
             do_predict=True,
-            tokenizer_name="patrickvonplaten/t5-tiny-random",
+            tokenizer_name=T5_TINY,
             no_teacher=True,
-            alpha_hid=2.0,
         )
         self._bart_distiller_cli(updates)
@@ -161,7 +137,6 @@ class TestSummarizationDistiller(unittest.TestCase):
         self._bart_distiller_cli(updates)
 
     def test_bdc_checkpointing(self):
         updates = dict(
             student_encoder_layers=2,
             student_decoder_layers=1,
@@ -184,32 +159,8 @@ class TestSummarizationDistiller(unittest.TestCase):
         evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
 
-    def test_bdc_t5(self):
-        updates = dict(
-            student_encoder_layers=1,
-            student_decoder_layers=1,
-            alpha_hid=2.0,
-            teacher="patrickvonplaten/t5-tiny-random",
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            tokenizer_name="patrickvonplaten/t5-tiny-random",
-        )
-        self._bart_distiller_cli(updates)
-
-    def test_bdc_t5_eval(self):
-        updates = dict(
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            do_train=False,
-            do_predict=True,
-            tokenizer_name="patrickvonplaten/t5-tiny-random",
-            no_teacher=True,
-        )
-        self._bart_distiller_cli(updates, check_contents=False)
-
     def _bart_distiller_cli(self, updates, check_contents=True):
         default_updates = dict(
-            model_type="bart",
             train_batch_size=1,
             eval_batch_size=2,
             num_train_epochs=2,
@@ -237,21 +188,14 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertIn(ckpt_name, contents)
         self.assertIn("metrics.pkl", contents)
         self.assertIn("test_generations.txt", contents)
-        self.assertIn("val_generations_1.txt", contents)
-        self.assertIn("val_1_results.txt", contents)
+        self.assertIn("val_generations_00001.txt", contents)
+        self.assertIn("val_results_00001.txt", contents)
         self.assertIn("test_results.txt", contents)
-        # self.assertEqual(len(contents), 15)
         metrics = pickle_load(Path(output_dir) / "metrics.pkl")
-        import pandas as pd
-
-        val_df = pd.DataFrame(metrics["val"])
-        train_df = pd.DataFrame(metrics["train"])
-        test_df = pd.DataFrame(metrics["test"])
-        desired_n_evals = args_d["num_train_epochs"] * 2 + 1
-        self.assertEqual(val_df.shape[0], desired_n_evals)
-        self.assertEqual(test_df.shape[1], val_df.shape[1])
-        self.assertEqual(train_df.shape[0], 0)
+        desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
+        self.assertEqual(len(metrics["val"]), desired_n_evals)
+        self.assertEqual(len(metrics["train"]), 0)  # doesn't get logged here
         return model
@@ -281,9 +225,8 @@ class TestBartExamples(unittest.TestCase):
         output_dir = tempfile.mkdtemp(prefix="output_")
         args_d.update(
             data_dir=tmp_dir,
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            tokenizer_name=None,  # "patrickvonplaten/t5-tiny-random",
+            model_name_or_path=T5_TINY,
+            tokenizer_name=None,  # T5_TINY,
             train_batch_size=2,
             eval_batch_size=2,
             gpus=0,
...
@@ -45,8 +45,10 @@ def encode_file(
             max_length=max_length,
             pad_to_max_length=pad_to_max_length,
             add_prefix_space=True,
+            truncation=True,
             return_tensors=return_tensors,
         )
+        assert tokenized.input_ids.shape[1] == max_length
         examples.append(tokenized)
     torch.save(lmap(dict, examples), cache_path.open("wb"))
     return examples
...
@@ -108,7 +108,7 @@ class GLUETransformer(BaseTransformer):
         return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}
 
-    def _eval_end(self, outputs):
+    def _eval_end(self, outputs) -> tuple:
         val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
         preds = np.concatenate([x["pred"] for x in outputs], axis=0)
@@ -132,20 +132,14 @@ class GLUETransformer(BaseTransformer):
         logs = ret["log"]
         return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
 
-    def test_epoch_end(self, outputs):
-        # updating to test_epoch_end instead of deprecated test_end
+    def test_epoch_end(self, outputs) -> dict:
         ret, predictions, targets = self._eval_end(outputs)
-
-        # Converting to the dic required by pl
-        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
-        # pytorch_lightning/trainer/logging.py#L139
         logs = ret["log"]
         # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
         return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
 
     @staticmethod
     def add_model_specific_args(parser, root_dir):
-        # Add NER specific options
         BaseTransformer.add_model_specific_args(parser, root_dir)
         parser.add_argument(
             "--max_seq_length",
...
@@ -205,7 +205,7 @@ class AutoTokenizer:
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if "bert-base-japanese" in pretrained_model_name_or_path:
+        if "bert-base-japanese" in str(pretrained_model_name_or_path):
             return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
 
         use_fast = kwargs.pop("use_fast", False)
...