Unverified Commit 76116f47 authored by CeShine Lee, committed by GitHub

T5 Gradient Checkpointing (#11353)

* Implement gradient checkpointing for T5Stack

* A bit more robust type checking

* Add `gradient_checkpointing` to T5Config

* Formatting

* Set requires_grad only when training

* None return value will only cause problems when training

* Change the output tuple according to `use_cache`

* Enable gradient checkpointing for the decoder

Squashed commit of the following:

commit 658bdd0bd1215353a8770f558bda2ea69a0ad0c7
Author: Ceshine Lee <shuanck@gmail.com>
Date:   Sat Apr 24 14:08:17 2021 +0800

    Only set `requires_grad` for gradient checkpointing

commit acaeee6b2e675045fb28ce2176444c1d63e908bd
Author: Ceshine Lee <shuanck@gmail.com>
Date:   Sat Apr 24 13:59:35 2021 +0800

    Make gradient checkpointing work with the decoder

* Formatting
parent 58c789e3
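Before the diff itself, a minimal usage sketch of the feature this commit adds, assuming a matching transformers/torch install. The checkpoint name, example sentences, and single unoptimized training step are illustrative placeholders, not part of the commit:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# "t5-small" is a placeholder checkpoint; any T5 checkpoint would do.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained(
    "t5-small",
    gradient_checkpointing=True,  # new T5Config attribute added by this commit
    use_cache=False,              # caching is disabled anyway while checkpointing during training
)
model.train()  # the checkpointing path is only taken in training mode

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
labels = tokenizer("Hallo", return_tensors="pt").input_ids
loss = model(**inputs, labels=labels).loss
loss.backward()  # block activations are recomputed in backward instead of being stored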
@@ -71,6 +71,8 @@ class T5Config(PretrainedConfig):
             the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
     """
     model_type = "t5"
     keys_to_ignore_at_inference = ["past_key_values"]
@@ -93,6 +95,7 @@ class T5Config(PretrainedConfig):
         use_cache=True,
         pad_token_id=0,
         eos_token_id=1,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(
@@ -116,6 +119,7 @@ class T5Config(PretrainedConfig):
         self.initializer_factor = initializer_factor
         self.feed_forward_proj = feed_forward_proj
         self.use_cache = use_cache
+        self.gradient_checkpointing = gradient_checkpointing
 
     @property
     def hidden_size(self):
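The configuration change above boils down to one extra boolean field. A short sketch of setting it, with a placeholder checkpoint name:

from transformers import T5Config

# Set at construction time ...
config = T5Config(gradient_checkpointing=True)
assert config.gradient_checkpointing

# ... or toggled on a loaded config before instantiating the model.
config = T5Config.from_pretrained("t5-small")  # placeholder checkpoint
config.gradient_checkpointing = True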
@@ -24,6 +24,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -323,6 +324,7 @@ class T5Attention(nn.Module):
         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.pruned_heads = set()
+        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
 
     def prune_heads(self, heads):
         if len(heads) == 0:
@@ -485,6 +487,8 @@
                 position_bias = torch.zeros(
                     (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                 )
+                if self.training and self.gradient_checkpointing:
+                    position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
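The two added lines mark the freshly created all-zero position bias as requiring gradients when training with checkpointing enabled. The snippet below is not T5 code; it is a small sketch of the general torch.utils.checkpoint behaviour this interacts with, namely that gradients only flow back through a checkpointed segment when at least one tensor entering it requires grad:

import torch
from torch.utils.checkpoint import checkpoint

weight = torch.randn(4, 4, requires_grad=True)

def segment(x):
    # stand-in for a transformer block wrapped in checkpoint()
    return torch.relu(x @ weight)

# An all-zero input, loosely analogous to the zero-initialized position bias, with requires_grad set:
bias_like = torch.zeros(2, 4, requires_grad=True)
out = checkpoint(segment, bias_like)
out.sum().backward()
print(weight.grad is not None, bias_like.grad is not None)  # True True

# If no input to the segment required grad, PyTorch would warn that gradients
# will be None and nothing would flow back through the recomputed segment.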
@@ -691,7 +695,11 @@ class T5Block(nn.Module):
         outputs = (hidden_states,)
 
-        outputs = outputs + (present_key_value_state,) + attention_outputs
+        if use_cache:
+            outputs = outputs + (present_key_value_state,) + attention_outputs
+        else:
+            outputs = outputs + attention_outputs
 
         return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
@@ -947,21 +955,51 @@ class T5Stack(T5PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(
-                hidden_states,
-                attention_mask=extended_attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_extended_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                layer_head_mask=layer_head_mask,
-                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
-                past_key_value=past_key_value,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return tuple(module(*inputs, use_cache, output_attentions))
+
+                    return custom_forward
+
+                layer_outputs = checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    extended_attention_mask,
+                    position_bias,
+                    encoder_hidden_states,
+                    encoder_extended_attention_mask,
+                    encoder_decoder_position_bias,
+                    layer_head_mask,
+                    cross_attn_layer_head_mask,
+                    None,  # past_key_value is always None with gradient checkpointing
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    position_bias=position_bias,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_extended_attention_mask,
+                    encoder_decoder_position_bias=encoder_decoder_position_bias,
+                    layer_head_mask=layer_head_mask,
+                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                    past_key_value=past_key_value,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+            if use_cache is False:
+                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
             hidden_states, present_key_value_state = layer_outputs[:2]
             # We share the position biases between the layers - the first layer store them
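Finally, a sanity-check sketch for the change as a whole (placeholder checkpoint name, dropout disabled so the two runs are comparable): checkpointing should trade compute for memory without changing the gradients themselves.

import torch
from transformers import T5ForConditionalGeneration

def grads(gradient_checkpointing):
    model = T5ForConditionalGeneration.from_pretrained(
        "t5-small",                 # placeholder checkpoint
        dropout_rate=0.0,           # remove randomness so both runs see the same forward pass
        use_cache=False,
        gradient_checkpointing=gradient_checkpointing,
    )
    model.train()
    input_ids = torch.tensor([[37, 423, 215, 1]])  # arbitrary token ids ending in </s>
    labels = torch.tensor([[37, 423, 1]])
    model(input_ids=input_ids, labels=labels).loss.backward()
    return [p.grad.clone() for p in model.parameters() if p.grad is not None]

baseline, checkpointed = grads(False), grads(True)
print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(baseline, checkpointed)))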