"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "8363ef967d9ff9305729a6bdb901f403aaa5a417"
Unverified commit 5aa361f3, authored by Daniel Khashabi, committed by GitHub
Browse files

finetune.py: specifying generation min_length (#8478)

parent 30e7f7e5
...@@ -113,6 +113,10 @@ class SummarizationModule(BaseTransformer): ...@@ -113,6 +113,10 @@ class SummarizationModule(BaseTransformer):
self.eval_max_length = self.hparams.eval_max_gen_length self.eval_max_length = self.hparams.eval_max_gen_length
else: else:
self.eval_max_length = self.model.config.max_length self.eval_max_length = self.model.config.max_length
if self.hparams.eval_min_gen_length is not None:
self.eval_min_length = self.hparams.eval_min_gen_length
else:
self.eval_min_length = self.model.config.min_length
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]: def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
...@@ -219,6 +223,7 @@ class SummarizationModule(BaseTransformer): ...@@ -219,6 +223,7 @@ class SummarizationModule(BaseTransformer):
decoder_start_token_id=self.decoder_start_token_id, decoder_start_token_id=self.decoder_start_token_id,
num_beams=self.eval_beams, num_beams=self.eval_beams,
max_length=self.eval_max_length, max_length=self.eval_max_length,
min_length=self.eval_min_length,
) )
gen_time = (time.time() - t0) / batch["input_ids"].shape[0] gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
preds: List[str] = self.ids_to_clean_text(generated_ids) preds: List[str] = self.ids_to_clean_text(generated_ids)
...@@ -346,6 +351,7 @@ class SummarizationModule(BaseTransformer): ...@@ -346,6 +351,7 @@ class SummarizationModule(BaseTransformer):
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None] "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
) )
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens") parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
parser.add_argument("--eval_min_gen_length", type=int, default=None, help="never generate shorter than n tokens")
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save") parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
parser.add_argument( parser.add_argument(
"--early_stopping_patience", "--early_stopping_patience",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment