Unverified Commit d9050dc7 authored by Cesare Campagnano's avatar Cesare Campagnano Committed by GitHub
Browse files

[LED] fix global_attention_mask not being passed for generation and docs...


[LED] fix global_attention_mask not being passed for generation and docs clarification about grad checkpointing (#17112)

* [LED] fixed global_attention_mask not passed for generation + docs clarification for gradient checkpointing

* LED docs clarification
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* [LED] gradient_checkpointing=True should be passed to TrainingArguments
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* [LED] docs: remove wrong word
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* [LED] docs fix typo
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent bad35839
......@@ -44,8 +44,10 @@ Tips:
- LED makes use of *global attention* by means of the `global_attention_mask` (see
[`LongformerModel`]). For summarization, it is advised to put *global attention* only on the first
`<s>` token. For question answering, it is advised to put *global attention* on all tokens of the question.
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing
`model.gradient_checkpointing_enable()`.
- To fine-tune LED on all 16384, *gradient checkpointing* can be enabled in case training leads to out-of-memory (OOM)
errors. This can be done by executing `model.gradient_checkpointing_enable()`.
Moreover, the `use_cache=False`
flag can be used to disable the caching mechanism to save memory.
- A notebook showing how to evaluate LED, can be accessed [here](https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing).
- A notebook showing how to fine-tune LED, can be accessed [here](https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing).
......
......@@ -2441,6 +2441,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
decoder_input_ids,
past=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
......@@ -2458,6 +2459,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"global_attention_mask": global_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment