Unverified commit 64e3d966, authored by Sylvain Gugger, committed by GitHub

Add support for past states (#5399)

* Add support for past states

* Style and forgotten self

* You mean, documenting is not enough? I have to actually add it too?

* Add memory support during evaluation

* Fix tests in eval and add TF support

* No need to change this line anymore
parent 4ade7491
@@ -493,6 +493,10 @@ class Trainer:
             else:
                 epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
 
+            # Reset the past mems state at the beginning of each epoch if necessary.
+            if self.args.past_index >= 0:
+                self._past = None
+
             for step, inputs in enumerate(epoch_iterator):
 
                 # Skip past any already trained steps if resuming training
@@ -575,6 +579,9 @@ class Trainer:
         if self.tb_writer:
             self.tb_writer.close()
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
 
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
         return TrainOutput(self.global_step, tr_loss / self.global_step)
@@ -617,9 +624,15 @@ class Trainer:
             if isinstance(v, torch.Tensor):
                 inputs[k] = v.to(self.args.device)
 
+        if self.args.past_index >= 0 and self._past is not None:
+            inputs["mems"] = self._past
+
         outputs = model(**inputs)
         loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
 
+        if self.args.past_index >= 0:
+            self._past = outputs[self.args.past_index]
+
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
         if self.args.gradient_accumulation_steps > 1:
@@ -802,12 +815,17 @@ class Trainer:
         if is_torch_tpu_available():
             dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
 
+        if self.args.past_index >= 0:
+            past = None
+
         for inputs in tqdm(dataloader, desc=description):
             has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
 
             for k, v in inputs.items():
                 if isinstance(v, torch.Tensor):
                     inputs[k] = v.to(self.args.device)
+            if self.args.past_index >= 0:
+                inputs["mems"] = past
 
             with torch.no_grad():
                 outputs = model(**inputs)
@@ -816,6 +834,8 @@ class Trainer:
                     eval_losses += [step_eval_loss.mean().item()]
                 else:
                     logits = outputs[0]
+                if self.args.past_index >= 0:
+                    past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
 
             if not prediction_loss_only:
                 if preds is None:
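Taken together, the PyTorch hunks above implement a simple feed-back pattern: reset the cached state at the start of each epoch, pass it to the model under the keyword argument mems, and replace it with the matching entry of the output tuple after every forward pass. At prediction time the index is shifted down by one when no labels are supplied, because the loss then no longer occupies position 0 of the outputs. The sketch below restates that pattern outside the Trainer; it is illustrative only, with model, dataloader, optimizer and num_epochs as placeholders, and it assumes the model returns a plain tuple of the form (loss, logits, mems), so that past_index = 2.

# Minimal sketch of the feed-back pattern (not the Trainer code itself).
# Assumption: the model's memories sit at index 2 of its output tuple.
past_index = 2

for epoch in range(num_epochs):
    past = None                          # reset the cached state at the start of each epoch
    for inputs in dataloader:            # inputs: a dict of tensors, as the Trainer expects
        if past_index >= 0 and past is not None:
            inputs["mems"] = past        # feed the previous step's memories back to the model
        outputs = model(**inputs)
        loss = outputs[0]
        if past_index >= 0:
            past = outputs[past_index]   # cache the memories for the next step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()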
@@ -240,6 +240,10 @@ class TFTrainer:
         step: int = 1
 
+        # Reset the past mems state at the beginning of the evaluation if necessary.
+        if self.args.past_index >= 0:
+            self._past = None
+
         for features, labels in dataset:
             step = tf.convert_to_tensor(step, dtype=tf.int64)
             loss, logits = self._evaluate_steps(features, labels)
@@ -288,6 +292,10 @@ class TFTrainer:
             if not key.startswith("eval_"):
                 metrics[f"eval_{key}"] = metrics.pop(key)
 
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
+
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
 
     def _log(self, logs: Dict[str, float]) -> None:
@@ -405,6 +413,9 @@ class TFTrainer:
         logger.info("  Total optimization steps = %d", t_total)
 
         for epoch_iter in range(epochs_trained, int(epochs + 1)):
+            # Reset the past mems state at the beginning of each epoch if necessary.
+            if self.args.past_index >= 0:
+                self._past = None
             for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
                 self.global_step = iterations.numpy()
                 self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
@@ -444,6 +455,10 @@ class TFTrainer:
                 if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
                     break
 
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
+
     def _training_steps(self, ds, optimizer):
         """
        Returns a generator over training steps (i.e. parameters update).
@@ -518,10 +533,15 @@ class TFTrainer:
             labels: the batched labels.
             training: run the model in training mode or not
         """
+        if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
+            features["mems"] = self._past
+
         if isinstance(labels, (dict)):
-            loss, logits = self.model(features, training=training, **labels)[:2]
+            outputs = self.model(features, training=training, **labels)[:2]
         else:
-            loss, logits = self.model(features, labels=labels, training=training)[:2]
+            outputs = self.model(features, labels=labels, training=training)[:2]
+        loss, logits = outputs[:2]
+        if self.args.past_index >= 0:
+            self._past = outputs[self.args.past_index]
 
         loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
 
         return loss, logits
@@ -102,6 +102,11 @@ class TrainingArguments:
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
+        past_index (:obj:`int`, `optional`, defaults to -1):
+            Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
+            make use of the past hidden states for their predictions. If this argument is set to a positive int, the
+            ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
+            at the next training step under the keyword argument ``mems``.
     """
 
     output_dir: str = field(
@@ -203,6 +208,11 @@ class TrainingArguments:
         default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
     )
+    past_index: int = field(
+        default=-1,
+        metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
+    )
 
     @property
     def train_batch_size(self) -> int:
         """
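As a usage note (not part of the diff): with the new field in place, enabling the behaviour only requires passing past_index when building the training arguments. The snippet below is a hypothetical sketch; model and train_dataset are placeholders, and past_index=2 assumes the model's memories are the third element of its output tuple, which the docstring above describes as the usual case.

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./xlnet-finetuned",   # placeholder output path
    past_index=2,                     # index of the mems in the model's output tuple
)

trainer = Trainer(
    model=model,                      # placeholder: a Transformer-XL / XLNet style model
    args=training_args,
    train_dataset=train_dataset,      # placeholder dataset yielding dict-style features
)
trainer.train()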
@@ -85,6 +85,11 @@ class TFTrainingArguments(TrainingArguments):
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
+        past_index (:obj:`int`, `optional`, defaults to -1):
+            Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
+            make use of the past hidden states for their predictions. If this argument is set to a positive int, the
+            ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
+            at the next training step under the keyword argument ``mems``.
         tpu_name (:obj:`str`, `optional`):
             The name of the TPU the process is running on.
         eval_steps (:obj:`int`, `optional`, defaults to 1000):