Unverified Commit 31fa2b6c authored by Herumb Shandilya's avatar Herumb Shandilya Committed by GitHub
Browse files

[GPTJ] Fix gradient checkpointing bug (#21794)



* If applied, this commit fixes generate bug in gptj

* Remove extra same code block

* formatting and test fix

* Conflict fix and declaration error fix

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent eec76042
......@@ -633,6 +633,13 @@ class GPTJModel(GPTJPreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
......@@ -652,11 +659,6 @@ class GPTJModel(GPTJPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment