"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "671b278e252050efa6f5049ebc783d51670b3b4e"
Unverified commit 76116f47 authored by CeShine Lee, committed by GitHub

T5 Gradient Checkpointing (#11353)

* Implement gradient checkpointing for T5Stack

* A bit more robust type checking

* Add `gradient_checkpointing` to T5Config

* Formatting

* Set requires_grad only when training

* None return value will only cause problems when training

* Change the output tuple according to `use_cache`

* Enable gradient checkpointing for the decoder

Squashed commit of the following:

commit 658bdd0bd1215353a8770f558bda2ea69a0ad0c7
Author: Ceshine Lee <shuanck@gmail.com>
Date:   Sat Apr 24 14:08:17 2021 +0800

    Only set `requires_grad` for gradient checkpointing

commit acaeee6b2e675045fb28ce2176444c1d63e908bd
Author: Ceshine Lee <shuanck@gmail.com>
Date:   Sat Apr 24 13:59:35 2021 +0800

    Make gradient checkpointing work with the decoder

* Formatting
parent 58c789e3
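Usage note (not part of the commit): after this change, gradient checkpointing for T5 is controlled by the new `gradient_checkpointing` flag on `T5Config` and only takes effect in training mode. A minimal sketch of how it could be enabled, assuming the public `t5-small` checkpoint:

from transformers import T5Config, T5ForConditionalGeneration

# Enable the flag introduced by this commit on the loaded config.
config = T5Config.from_pretrained("t5-small")
config.gradient_checkpointing = True
config.use_cache = False  # the decoder cache is incompatible with checkpointing during training

model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config)
model.train()  # checkpointing is only applied when self.training is True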
@@ -71,6 +71,8 @@ class T5Config(PretrainedConfig):
             the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
     """
     model_type = "t5"
     keys_to_ignore_at_inference = ["past_key_values"]
@@ -93,6 +95,7 @@ class T5Config(PretrainedConfig):
         use_cache=True,
         pad_token_id=0,
         eos_token_id=1,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(
@@ -116,6 +119,7 @@ class T5Config(PretrainedConfig):
         self.initializer_factor = initializer_factor
         self.feed_forward_proj = feed_forward_proj
         self.use_cache = use_cache
+        self.gradient_checkpointing = gradient_checkpointing
 
     @property
     def hidden_size(self):
@@ -24,6 +24,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -323,6 +324,7 @@ class T5Attention(nn.Module):
         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.pruned_heads = set()
+        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
 
     def prune_heads(self, heads):
         if len(heads) == 0:
@@ -485,6 +487,8 @@ class T5Attention(nn.Module):
                 position_bias = torch.zeros(
                     (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                 )
+                if self.training and self.gradient_checkpointing:
+                    position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
 
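Side note (not part of the diff): with the re-entrant `torch.utils.checkpoint` implementation that was the default at the time of this commit, a checkpointed call only tracks gradients if at least one of its explicit tensor inputs requires them; parameters captured by closure inside the module do not count. That is why the freshly zero-initialized `position_bias` is marked with `requires_grad = True` during training. A toy illustration of the effect, using an `nn.Linear` as a stand-in for the attention block:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)  # its weights require grad, but only via closure

bias = torch.zeros(2, 8)  # analogous to the zero-initialized position_bias
out = checkpoint(lambda b: layer(b), bias)
print(out.requires_grad)  # False: no explicit input requires grad, so the graph is cut

bias.requires_grad_(True)  # mirrors `position_bias.requires_grad = True` above
out = checkpoint(lambda b: layer(b), bias)
print(out.requires_grad)  # True: gradients can now flow back through the checkpointed call

# (Recent PyTorch releases may need `use_reentrant=True` passed explicitly to
# reproduce the first result; the non-reentrant variant also handles closures.)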
@@ -691,7 +695,11 @@ class T5Block(nn.Module):
         outputs = (hidden_states,)
 
-        outputs = outputs + (present_key_value_state,) + attention_outputs
+        if use_cache:
+            outputs = outputs + (present_key_value_state,) + attention_outputs
+        else:
+            outputs = outputs + attention_outputs
 
         return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
@@ -947,21 +955,51 @@ class T5Stack(T5PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(
-                hidden_states,
-                attention_mask=extended_attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_extended_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                layer_head_mask=layer_head_mask,
-                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
-                past_key_value=past_key_value,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return tuple(module(*inputs, use_cache, output_attentions))
+
+                    return custom_forward
+
+                layer_outputs = checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    extended_attention_mask,
+                    position_bias,
+                    encoder_hidden_states,
+                    encoder_extended_attention_mask,
+                    encoder_decoder_position_bias,
+                    layer_head_mask,
+                    cross_attn_layer_head_mask,
+                    None,  # past_key_value is always None with gradient checkpointing
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    position_bias=position_bias,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_extended_attention_mask,
+                    encoder_decoder_position_bias=encoder_decoder_position_bias,
+                    layer_head_mask=layer_head_mask,
+                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                    past_key_value=past_key_value,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+            if use_cache is False:
+                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
             hidden_states, present_key_value_state = layer_outputs[:2]
 
             # We share the position biases between the layers - the first layer store them
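For readers unfamiliar with the wrapper pattern used above (not part of the diff): `torch.utils.checkpoint.checkpoint` passes its extra arguments positionally and, at the time of this commit, did not accept keyword arguments, so non-tensor flags such as `use_cache` and `output_attentions` are baked in via the `create_custom_forward` closure while the tensors (and a `None` placeholder for `past_key_value`) are passed through positionally. A stripped-down sketch of the same idea, with a hypothetical `block` standing in for `layer_module`:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(16, 16)  # hypothetical stand-in for a T5 block

def create_custom_forward(module, output_scaled):
    # Non-tensor options are closed over; `checkpoint` only sees the tensor inputs.
    def custom_forward(*inputs):
        out = module(*inputs)
        return out * 2 if output_scaled else out

    return custom_forward

hidden = torch.randn(4, 16, requires_grad=True)
out = checkpoint(create_custom_forward(block, output_scaled=True), hidden)
out.sum().backward()  # the forward pass of `block` is recomputed here instead of storing its activations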