Unverified Commit 13570381 authored by Eran Hirsch, committed by GitHub

Add logits_processor parameter, used by `generate`, to `Seq2SeqTrainer` methods `evaluate` and `predict` (#17805)

* Add logits_processor parameter, used by `generate`, to `Seq2SeqTrainer` methods `evaluate` and `predict`

* Add all generate parameters to `Seq2SeqTrainer`, and also to `QuestionAnsweringSeq2SeqTrainer` which overrides it

* Remove `self._num_beams` from trainer classes

* Run fixup; fix "Constraint" not exposed; fix `synced_gpus` to actually read from the param

* Use kwargs

* Copy kwargs before making changes to them

* Fix style issues and unused imports
parent 16c6eb7c
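
For context on the new call surface, here is a minimal usage sketch. Assumed setup: `trainer` is an existing `Seq2SeqTrainer`, with `tokenizer` and `test_dataset` in scope; none of these names come from the commit itself.

```python
from transformers import LogitsProcessorList, MinLengthLogitsProcessor

# Any `generate` kwarg can now be forwarded through `evaluate`/`predict`;
# previously only `max_length` and `num_beams` were accepted.
logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(10, eos_token_id=tokenizer.eos_token_id)]
)

metrics = trainer.evaluate(max_length=64, num_beams=4, logits_processor=logits_processor)
predictions = trainer.predict(test_dataset, logits_processor=logits_processor)
```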
```diff
@@ -41,11 +41,16 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         eval_examples=None,
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "eval",
-        max_length: Optional[int] = None,
-        num_beams: Optional[int] = None,
+        **gen_kwargs,
     ) -> Dict[str, float]:
-        self._max_length = max_length if max_length is not None else self.args.generation_max_length
-        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
+        gen_kwargs = gen_kwargs.copy()
+        gen_kwargs["max_length"] = (
+            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
+        )
+        gen_kwargs["num_beams"] = (
+            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+        )
+        self._gen_kwargs = gen_kwargs

         eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
@@ -87,7 +92,11 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
             self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
         return metrics

-    def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
+    def predict(
+        self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test", **gen_kwargs
+    ):
+        self._gen_kwargs = gen_kwargs.copy()
+
         predict_dataloader = self.get_test_dataloader(predict_dataset)

         # Temporarily disable metric computation, we will do it in the loop here.
```
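
Note the asymmetry between the two hunks above: `evaluate` resolves `max_length` and `num_beams` against `self.args.generation_*` before stashing `self._gen_kwargs`, while `predict` stashes a raw copy, so unset values fall through to `model.config` later in `prediction_step`. Below is a plain-dict sketch of the per-key idiom the commit repeats; it is a standalone illustration, not code from the diff. Treating `None` as "not provided" is safe here because a conditional expression only evaluates the winning branch.

```python
# Keep a caller-supplied value, otherwise substitute a default; the subscript
# left of `if` only runs when the key is present and non-None, so no KeyError.
gen_kwargs = {"max_length": 64}  # caller set max_length but not num_beams
generation_num_beams = 4         # stand-in for self.args.generation_num_beams

gen_kwargs["num_beams"] = (
    gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else generation_num_beams
)
assert gen_kwargs == {"max_length": 64, "num_beams": 4}
```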
```diff
@@ -33,8 +33,7 @@ class Seq2SeqTrainer(Trainer):
         eval_dataset: Optional[Dataset] = None,
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "eval",
-        max_length: Optional[int] = None,
-        num_beams: Optional[int] = None,
+        **gen_kwargs
     ) -> Dict[str, float]:
         """
         Run evaluation and returns metrics.
@@ -60,13 +59,23 @@ class Seq2SeqTrainer(Trainer):
             num_beams (`int`, *optional*):
                 Number of beams for beam search that will be used when predicting with the generate method. 1 means no
                 beam search.
+            gen_kwargs:
+                Additional `generate` specific kwargs.
+
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
             dictionary also contains the epoch number which comes from the training state.
         """
-        self._max_length = max_length if max_length is not None else self.args.generation_max_length
-        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
+
+        gen_kwargs = gen_kwargs.copy()
+        gen_kwargs["max_length"] = (
+            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
+        )
+        gen_kwargs["num_beams"] = (
+            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+        )
+        self._gen_kwargs = gen_kwargs

         return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

     def predict(
@@ -74,8 +83,7 @@ class Seq2SeqTrainer(Trainer):
         test_dataset: Dataset,
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "test",
-        max_length: Optional[int] = None,
-        num_beams: Optional[int] = None,
+        **gen_kwargs
     ) -> PredictionOutput:
         """
         Run prediction and returns predictions and potential metrics.
@@ -98,6 +106,8 @@ class Seq2SeqTrainer(Trainer):
             num_beams (`int`, *optional*):
                 Number of beams for beam search that will be used when predicting with the generate method. 1 means no
                 beam search.
+            gen_kwargs:
+                Additional `generate` specific kwargs.

         <Tip>
@@ -114,8 +124,16 @@ class Seq2SeqTrainer(Trainer):
             - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
               labels).
         """
-        self._max_length = max_length if max_length is not None else self.args.generation_max_length
-        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
+
+        gen_kwargs = gen_kwargs.copy()
+        gen_kwargs["max_length"] = (
+            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
+        )
+        gen_kwargs["num_beams"] = (
+            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+        )
+        self._gen_kwargs = gen_kwargs

         return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

     def prediction_step(
@@ -155,11 +173,17 @@ class Seq2SeqTrainer(Trainer):
         inputs = self._prepare_inputs(inputs)

         # XXX: adapt synced_gpus for fairscale as well
-        gen_kwargs = {
-            "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
-            "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
-            "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
-        }
+        gen_kwargs = self._gen_kwargs.copy()
+        gen_kwargs["max_length"] = (
+            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length
+        )
+        gen_kwargs["num_beams"] = (
+            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
+        )
+        default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
+        gen_kwargs["synced_gpus"] = (
+            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
+        )

         if "attention_mask" in inputs:
             gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
```
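
Taken together, the effective precedence for each generation setting becomes: an explicit kwarg to `evaluate`/`predict`, then `TrainingArguments.generation_*` (applied in the base trainer's entry points), then `model.config`; `synced_gpus` likewise now honors a caller override instead of being hard-coded from `is_deepspeed_zero3_enabled()`. A standalone sketch of that chain follows; `resolve_generation_setting` is a hypothetical helper for illustration, not part of the commit.

```python
from typing import Any, Dict, Optional

def resolve_generation_setting(
    name: str,
    gen_kwargs: Dict[str, Any],
    args_value: Optional[Any],
    config_value: Any,
) -> Any:
    # Mirrors the chain the diff implements across evaluate/predict and
    # prediction_step: explicit kwarg, then TrainingArguments, then config.
    if gen_kwargs.get(name) is not None:
        return gen_kwargs[name]
    if args_value is not None:
        return args_value
    return config_value

# No kwarg and no args.generation_num_beams: the model config value wins.
assert resolve_generation_setting("num_beams", {}, None, 6) == 6
# An explicit kwarg always wins over both defaults.
assert resolve_generation_setting("num_beams", {"num_beams": 2}, 4, 6) == 2
```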