Unverified Commit 59c1d5b9 authored by Yi Heng Lim's avatar Yi Heng Lim Committed by GitHub
Browse files

[GPT2, ProphetNet] Fix gradient checkpointing bug (#21772)

* fix gradient checkpointing bug

* fix gradient checkpointing bug

* ran make fix-copies

* fixed bug

* fixed bug
parent ba0e370d
......@@ -607,6 +607,13 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
......@@ -627,11 +634,6 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -851,6 +851,13 @@ class GPT2Model(GPT2PreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
......@@ -871,11 +878,6 @@ class GPT2Model(GPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -1569,6 +1569,14 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
all_main_stream_attns = () if output_attentions else None
all_ngram_stream_attns = () if output_attentions else None
all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
present_key_values = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
......@@ -1588,11 +1596,6 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -1592,6 +1592,14 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
all_main_stream_attns = () if output_attentions else None
all_ngram_stream_attns = () if output_attentions else None
all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
present_key_values = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
......@@ -1611,11 +1619,6 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment