"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8a618e0af5140a91e237f7758657a78319b03a2e"
Unverified Commit 1a77a1a8 authored by Nipun Jindal, committed by GitHub

[21737][T5]: Fix gradient checkpoint bug (#22036)



* [21737][T5]: Fix gradient checkpoint bug

* [21737][T5]: Fix gradient checkpoint bug

* [21737][T5]: Fix gradient checkpoint bug

* Update src/transformers/models/mt5/modeling_mt5.py

* Update src/transformers/models/t5/modeling_t5.py

---------
Co-authored-by: njindal <njindal@adobe.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent 2055d737
src/transformers/models/mt5/modeling_mt5.py
@@ -978,6 +978,13 @@ class MT5Stack(MT5PreTrainedModel):
         else:
             encoder_extended_attention_mask = None
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
         cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
@@ -1015,11 +1022,6 @@ class MT5Stack(MT5PreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning_once(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
src/transformers/models/t5/modeling_t5.py
@@ -1007,6 +1007,13 @@ class T5Stack(T5PreTrainedModel):
         else:
             encoder_extended_attention_mask = None
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
         cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
@@ -1044,11 +1051,6 @@ class T5Stack(T5PreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
            if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning_once(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
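The change is the same in both files: the `use_cache` compatibility check is moved out of the per-layer loop and run once before the loop, so the flag is already settled before any cache-dependent state is built or consulted. The following is a minimal, self-contained sketch of that pattern only, not the Transformers implementation; `TinyStack`, `present_key_values`, and the use of `warnings.warn` in place of `logger.warning_once` are illustrative assumptions.

# Sketch of the pattern applied by this fix: decide the final value of
# `use_cache` once, up front, before the layer loop runs.
import warnings

import torch
from torch.utils.checkpoint import checkpoint


class TinyStack(torch.nn.Module):
    # Hypothetical stand-in for T5Stack / MT5Stack.
    def __init__(self, num_layers=2, hidden_size=8):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
        )
        self.gradient_checkpointing = True

    def forward(self, hidden_states, use_cache=True):
        # The moved check: everything below sees the corrected flag.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                warnings.warn(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Cache container is created after the flag is settled, so it stays
        # consistent with what the loop below actually does.
        present_key_values = () if use_cache else None

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Recompute activations in the backward pass instead of storing them.
                hidden_states = checkpoint(layer, hidden_states)
            else:
                hidden_states = layer(hidden_states)
            if use_cache:
                present_key_values = present_key_values + (hidden_states,)

        return hidden_states, present_key_values


# Usage: with gradient checkpointing active during training, the cache is
# silently disabled and no per-layer cache entries are collected.
model = TinyStack().train()
x = torch.randn(1, 8, requires_grad=True)
out, cache = model(x, use_cache=True)
assert cache is None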