"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c468a87a69754b865575f115a2c8efa024c82cb5"
Unverified Commit 59c1d5b9 authored by Yi Heng Lim, committed by GitHub

[GPT2, ProphetNet] Fix gradient checkpointing bug (#21772)

* fix gradient checkpointing bug

* fix gradient checkpointing bug

* ran make fix-copies

* fixed bug

* fixed bug
parent ba0e370d
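The same one-line-shaped fix lands in four places (DecisionTransformerGPT2Model via `make fix-copies`, GPT2Model, ProphetNetDecoder, and XLMProphetNetDecoder): the `use_cache`/gradient-checkpointing guard is hoisted out of the per-layer loop to just before the cache tuple (`presents` / `present_key_values`) is initialized. Inside the loop the guard fired too late: `presents = () if use_cache else None` had already run with the stale `use_cache=True`, so the model returned an empty tuple instead of `None`, and the warning could repeat on every layer. A minimal, self-contained sketch of the failure mode and the fix; `TinyDecoder` and everything in it is a hypothetical stand-in, not the actual transformers code:

```python
import logging

logger = logging.getLogger("sketch")


class TinyDecoder:
    """Toy stand-in for the patched decoders; not the real transformers API."""

    def __init__(self, gradient_checkpointing=True, training=True):
        self.gradient_checkpointing = gradient_checkpointing
        self.training = training

    def forward(self, use_cache=True):
        # The fix: resolve use_cache BEFORE the cache tuple is initialized.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. "
                    "Setting `use_cache=False`..."
                )
                use_cache = False

        # Before the fix, the override above sat inside the layer loop, so this
        # line still saw use_cache=True and `presents` became () even though the
        # checkpointed branch never appended anything to it.
        presents = () if use_cache else None

        for _ in range(2):  # stand-in for the per-layer loop
            if use_cache:
                presents = presents + ("kv",)  # stand-in for a past_key_value

        return presents


print(TinyDecoder().forward(use_cache=True))  # -> None: cache cleanly disabled
```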
@@ -607,6 +607,13 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
 
@@ -627,11 +634,6 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -851,6 +851,13 @@ class GPT2Model(GPT2PreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
 
@@ -871,11 +878,6 @@ class GPT2Model(GPT2PreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1569,6 +1569,14 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         all_main_stream_attns = () if output_attentions else None
         all_ngram_stream_attns = () if output_attentions else None
         all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         present_key_values = () if use_cache else None
 
         # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
@@ -1588,11 +1596,6 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1592,6 +1592,14 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
         all_main_stream_attns = () if output_attentions else None
         all_ngram_stream_attns = () if output_attentions else None
         all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         present_key_values = () if use_cache else None
 
         # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
@@ -1611,11 +1619,6 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
            past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
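With the patch applied, the guard runs at most once per forward pass and before any cache state is built. A quick sanity check one could run against the GPT-2 path (assuming `torch` and a `transformers` build containing this fix; the tiny config sizes are arbitrary and the weights random):

```python
import torch
from transformers import GPT2Config, GPT2Model

# Tiny random-weight model so the example runs quickly.
model = GPT2Model(GPT2Config(n_layer=2, n_head=2, n_embd=64))
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
outputs = model(input_ids, use_cache=True)  # warns, then runs with use_cache=False
print(outputs.past_key_values)  # None rather than an empty tuple
```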