Unverified Commit b33ab4eb authored by John Giorgi, committed by GitHub

Add global_attention_mask to gen_kwargs (#16485)

If global_attention_mask is found in the model's inputs (it is used by certain
models, like LED) in the prediction_step method of Seq2SeqTrainer,
it is added to gen_kwargs, which are passed on to model.generate().
This allows the global attention to be set properly when generating predictions.
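
To make the motivation concrete, here is a minimal standalone sketch (not part of this commit; the checkpoint name and input text are only illustrative) of how a model like LED consumes a global_attention_mask during generation, i.e. the same keyword that Seq2SeqTrainer can now forward through gen_kwargs:

```python
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("A very long document ...", return_tensors="pt")

# LED expects global attention on at least the first (<s>) token; this is the
# kind of mask that prediction_step can now pass along to generate().
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1

summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    global_attention_mask=global_attention_mask,
    max_length=64,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```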
parent 9fd5e6bb
@@ -163,9 +163,11 @@ class Seq2SeqTrainer(Trainer):
         if "attention_mask" in inputs:
             gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
+        if "global_attention_mask" in inputs:
+            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
         # prepare generation inputs
-        # some encoder-decoder models can have varying encder's and thus
+        # some encoder-decoder models can have varying encoder's and thus
         # varying model input names
         if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
             generation_inputs = inputs[self.model.encoder.main_input_name]
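
For this change to have any effect at evaluation time, the eval batches must contain a global_attention_mask key in the first place. A hedged sketch of the dataset-side preparation (the checkpoint, column names, and lengths below are assumptions for illustration, not taken from this commit):

```python
from transformers import LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")

def preprocess(example):
    # Tokenize the source document; padding to a fixed length keeps the extra
    # global_attention_mask column the same shape as input_ids when batching.
    model_inputs = tokenizer(
        example["document"],
        max_length=4096,
        padding="max_length",
        truncation=True,
    )
    # Global attention on the first token only (LED's usual default). Because
    # this key is present in the batch, prediction_step now forwards it to
    # generate() via gen_kwargs.
    model_inputs["global_attention_mask"] = [1] + [0] * (
        len(model_inputs["input_ids"]) - 1
    )
    model_inputs["labels"] = tokenizer(
        example["summary"], max_length=128, truncation=True
    )["input_ids"]
    return model_inputs

# e.g. eval_dataset = raw_dataset.map(preprocess), then pass it to a
# Seq2SeqTrainer configured with predict_with_generate=True so that
# prediction_step actually calls generate().
```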