Unverified Commit 2a7402cb authored by Pradhy729, committed by GitHub

Feed forward chunking others (#6365)



* Feed forward chunking for Distilbert & Albert

* Added ff chunking for many other models

* Change model signature

* Added chunking for XLM

* Cleaned up by removing some variables.

* remove test_chunking flag
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
parent fe0b85e7
File mode changed from 100644 to 100755
@@ -191,6 +191,7 @@ class PretrainedConfig(object):
         self.pad_token_id = kwargs.pop("pad_token_id", None)
         self.eos_token_id = kwargs.pop("eos_token_id", None)
         self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
+        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

         # task specific arguments
         self.task_specific_params = kwargs.pop("task_specific_params", None)

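With this change every `PretrainedConfig` accepts a `chunk_size_feed_forward` argument (default `0`, i.e. chunking disabled). A minimal usage sketch, assuming a chunk size that evenly divides the input sequence length:

```python
# Hypothetical example, not part of this diff: enable feed forward chunking
# on a model whose layers read config.chunk_size_feed_forward.
from transformers import BertConfig, BertModel

config = BertConfig(chunk_size_feed_forward=64)  # run the FFN on 64 positions at a time
model = BertModel(config)  # trades a bit of speed for lower peak activation memory
```
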
@@ -43,7 +43,7 @@ from .modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices
+from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices


 logger = logging.getLogger(__name__)

@@ -69,6 +69,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
     """ Load tf checkpoints in a pytorch model."""
     try:
         import re
+
         import numpy as np
         import tensorflow as tf
     except ImportError:

@@ -286,6 +287,8 @@ class AlbertLayer(nn.Module):
         super().__init__()

         self.config = config
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
         self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attention = AlbertAttention(config)
         self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)

@@ -297,14 +300,20 @@ class AlbertLayer(nn.Module):
         self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
     ):
         attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
-        ffn_output = self.ffn(attention_output[0])
-        ffn_output = self.activation(ffn_output)
-        ffn_output = self.ffn_output(ffn_output)
-        ffn_output = self.dropout(ffn_output)
+        ffn_output = apply_chunking_to_forward(
+            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[0],
+        )
         hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

         return (hidden_states,) + attention_output[1:]  # add attentions if we output them

+    def ff_chunk(self, attention_output):
+        ffn_output = self.ffn(attention_output)
+        ffn_output = self.activation(ffn_output)
+        ffn_output = self.ffn_output(ffn_output)
+        return ffn_output
+

 class AlbertLayerGroup(nn.Module):
     def __init__(self, config):

@@ -424,7 +424,7 @@ class BertLayer(nn.Module):
             outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

         layer_output = apply_chunking_to_forward(
-            self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
         )
         outputs = (layer_output,) + outputs
         return outputs

@@ -44,7 +44,12 @@ from .modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from .modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)


 logger = logging.getLogger(__name__)

@@ -208,6 +213,8 @@ class FFN(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dropout = nn.Dropout(p=config.dropout)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
         self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
         self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
         assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(

@@ -216,6 +223,9 @@ class FFN(nn.Module):
         self.activation = gelu if config.activation == "gelu" else nn.ReLU()

     def forward(self, input):
+        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
+
+    def ff_chunk(self, input):
         x = self.lin1(input)
         x = self.activation(x)
         x = self.lin2(x)

@@ -41,7 +41,12 @@ from .modeling_outputs import (
     TokenClassifierOutput,
 )
 from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
-from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from .modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)


 logger = logging.getLogger(__name__)

@@ -685,6 +690,8 @@ class LongformerLayer(nn.Module):
         self.attention = LongformerAttention(config, layer_id)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1

     def forward(
         self, hidden_states, attention_mask=None, output_attentions=False,

@@ -693,11 +700,17 @@ class LongformerLayer(nn.Module):
         attn_output = self_attn_outputs[0]
         outputs = self_attn_outputs[1:]  # add self attentions if we output attention weights

-        intermediate_output = self.intermediate(attn_output)
-        layer_output = self.output(intermediate_output, attn_output)
+        layer_output = apply_chunking_to_forward(
+            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_output
+        )
         outputs = (layer_output,) + outputs
         return outputs

+    def ff_chunk(self, attn_output):
+        intermediate_output = self.intermediate(attn_output)
+        layer_output = self.output(intermediate_output, attn_output)
+        return layer_output
+

 class LongformerEncoder(nn.Module):
     def __init__(self, config):

@@ -1369,7 +1369,7 @@ class ChunkReformerFeedForward(nn.Module):
     def forward(self, attention_output):
         return apply_chunking_to_forward(
-            self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output,
+            self.forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output,
         )

     def forward_chunk(self, hidden_states):

@@ -1730,7 +1730,7 @@ class ReformerOnlyLMHead(nn.Module):
         self.decoder.bias = self.bias

     def forward(self, hidden_states):
-        return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
+        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)

     def forward_chunk(self, hidden_states):
         hidden_states = self.decoder(hidden_states)

@@ -1519,7 +1519,7 @@ def prune_layer(
 def apply_chunking_to_forward(
-    chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
+    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
 ) -> torch.Tensor:
     """
     This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the

@@ -1529,12 +1529,12 @@ def apply_chunking_to_forward(
     directly applying :obj:`forward_fn` to :obj:`input_tensors`.

     Args:
+        forward_fn (:obj:`Callable[..., torch.Tensor]`):
+            The forward function of the model.
         chunk_size (:obj:`int`):
             The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
         chunk_dim (:obj:`int`):
             The dimension over which the :obj:`input_tensors` should be chunked.
-        forward_fn (:obj:`Callable[..., torch.Tensor]`):
-            The forward function of the model.
         input_tensors (:obj:`Tuple[torch.Tensor]`):
             The input tensors of ``forward_fn`` which will be chunked.
     Returns:

@@ -1550,7 +1550,7 @@ def apply_chunking_to_forward(
         # implement a chunked forward function
         def forward(self, hidden_states):
-            return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
+            return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
     """

     assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)

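The signature change above moves `forward_fn` into the first position, and the docstring is reordered to match. For intuition, here is a rough self-contained sketch of what the helper does (an illustration of the idea, not the library's exact implementation): the inputs are split into equally sized chunks along `chunk_dim`, `forward_fn` is applied chunk by chunk, and the partial outputs are concatenated, so the feed forward sub-layer never materializes its large intermediate activation for the whole sequence at once.

```python
import torch


def chunked_forward_sketch(forward_fn, chunk_size, chunk_dim, *input_tensors):
    """Illustrative re-implementation of the chunking idea; not the exact library code."""
    if chunk_size == 0:
        # chunking disabled: apply the forward function to the full tensors
        return forward_fn(*input_tensors)

    seq_len = input_tensors[0].shape[chunk_dim]
    assert seq_len % chunk_size == 0, "chunk_size must evenly divide the chunked dimension"
    num_chunks = seq_len // chunk_size

    # split every input along chunk_dim, run forward_fn chunk by chunk,
    # then stitch the partial outputs back together along the same dimension
    input_chunks = zip(*(t.chunk(num_chunks, dim=chunk_dim) for t in input_tensors))
    return torch.cat([forward_fn(*chunk) for chunk in input_chunks], dim=chunk_dim)


# tiny check: the chunked and unchunked paths agree
ffn = torch.nn.Linear(16, 16)
x = torch.randn(2, 128, 16)
assert torch.allclose(ffn(x), chunked_forward_sketch(ffn, 32, 1, x), atol=1e-6)
```
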
@@ -50,6 +50,7 @@ from .modeling_utils import (
     PreTrainedModel,
     SequenceSummary,
     SQuADHead,
+    apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )

@@ -212,8 +213,13 @@ class TransformerFFN(nn.Module):
         self.lin1 = nn.Linear(in_dim, dim_hidden)
         self.lin2 = nn.Linear(dim_hidden, out_dim)
         self.act = gelu if config.gelu_activation else F.relu
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1

     def forward(self, input):
+        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
+
+    def ff_chunk(self, input):
         x = self.lin1(input)
         x = self.act(x)
         x = self.lin2(x)

@@ -35,7 +35,14 @@ from .file_utils import (
     add_start_docstrings_to_callable,
     replace_return_docstrings,
 )
-from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
+from .modeling_utils import (
+    PoolerAnswerClass,
+    PoolerEndLogits,
+    PoolerStartLogits,
+    PreTrainedModel,
+    SequenceSummary,
+    apply_chunking_to_forward,
+)


 logger = logging.getLogger(__name__)

@@ -495,6 +502,8 @@ class XLNetLayer(nn.Module):
         self.rel_attn = XLNetRelativeAttention(config)
         self.ff = XLNetFeedForward(config)
         self.dropout = nn.Dropout(config.dropout)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1

     def forward(
         self,

@@ -524,12 +533,18 @@ class XLNetLayer(nn.Module):
         output_h, output_g = outputs[:2]

         if output_g is not None:
-            output_g = self.ff(output_g)
-        output_h = self.ff(output_h)
+            output_g = apply_chunking_to_forward(
+                self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_g
+            )
+        output_h = apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_h)

         outputs = (output_h, output_g) + outputs[2:]  # Add again attentions if there are there
         return outputs

+    def ff_chunk(self, output_x):
+        output_x = self.ff(output_x)
+        return output_x
+

 class XLNetPreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and

@@ -26,15 +26,15 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

 if is_torch_available():
     from transformers import (
         BertConfig,
-        BertModel,
-        BertLMHeadModel,
         BertForMaskedLM,
+        BertForMultipleChoice,
         BertForNextSentencePrediction,
         BertForPreTraining,
         BertForQuestionAnswering,
         BertForSequenceClassification,
         BertForTokenClassification,
-        BertForMultipleChoice,
+        BertLMHeadModel,
+        BertModel,
     )
     from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST

@@ -370,7 +370,6 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    test_chunking = True

     def setUp(self):
         self.model_tester = BertModelTester(self)

@@ -25,15 +25,15 @@ from transformers.testing_utils import require_multigpu, require_torch, slow, to

 if is_torch_available():
-    import torch
     import numpy as np
+    import torch

     from transformers import (
         AdaptiveEmbedding,
         PretrainedConfig,
         PreTrainedModel,
-        BertModel,
         BertConfig,
+        BertModel,
         BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,

@@ -65,7 +65,6 @@ class ModelTesterMixin:
     test_resize_embeddings = True
     test_head_masking = True
     test_missing_keys = True
-    test_chunking = False
     is_encoder_decoder = False

     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

@@ -552,9 +551,6 @@ class ModelTesterMixin:
     def test_feed_forward_chunking(self):
         (original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()

-        if not self.test_chunking:
-            return
-
         for model_class in self.all_model_classes:
             torch.manual_seed(0)
             config = copy.deepcopy(original_config)

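With the `test_chunking` flag removed, `test_feed_forward_chunking` in the common tester now runs for every model class. The core check, condensed below as a hedged sketch rather than the exact test body, is that a model built with `chunk_size_feed_forward` set produces the same output as the same model without chunking:

```python
# Hypothetical condensed version of the equivalence being tested.
import copy

import torch


def check_feed_forward_chunking(model_class, config, inputs_dict):
    torch.manual_seed(0)  # identical initialization for both models
    baseline = model_class(copy.deepcopy(config)).eval()(**inputs_dict)[0]

    chunked_config = copy.deepcopy(config)
    chunked_config.chunk_size_feed_forward = 1  # chunk the FFN one position at a time
    torch.manual_seed(0)
    chunked = model_class(chunked_config).eval()(**inputs_dict)[0]

    assert torch.allclose(baseline, chunked, atol=1e-3)
```
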
@@ -555,7 +555,6 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
-    test_chunking = True

     def prepare_kwargs(self):
         return {

@@ -616,7 +615,6 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
-    test_chunking = True

     def prepare_kwargs(self):
         return {