Unverified Commit 485bbe79 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Flax] Add remat (gradient checkpointing) (#17843)

* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* make fix-copies

* fix big-bird, electra, roberta

* cookie-cutter

* fix flax big-bird

* move test to common
parent 664688b9
......@@ -235,6 +235,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
raise NotImplementedError(f"init method has to be implemented for {self}")
def enable_gradient_checkpointing(self):
raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")
@classmethod
def _from_config(cls, config, **kwargs):
"""
......
......@@ -23,6 +23,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
......@@ -56,6 +57,8 @@ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
remat = nn_partitioning.remat
@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
......@@ -544,8 +547,16 @@ class FlaxBertLayer(nn.Module):
class FlaxBertLayerCollection(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
if self.gradient_checkpointing:
FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
self.layers = [
FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
else:
self.layers = [
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
......@@ -582,12 +593,12 @@ class FlaxBertLayerCollection(nn.Module):
layer_outputs = layer(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
head_mask[i] if head_mask is not None else None,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
deterministic,
output_attentions,
)
hidden_states = layer_outputs[0]
......@@ -617,9 +628,14 @@ class FlaxBertLayerCollection(nn.Module):
class FlaxBertEncoder(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
self.layer = FlaxBertLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
def __call__(
self,
......@@ -756,11 +772,24 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(
config=config,
dtype=dtype,
gradient_checkpointing=gradient_checkpointing,
**kwargs,
)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
......@@ -925,10 +954,15 @@ class FlaxBertModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
gradient_checkpointing: bool = False
def setup(self):
self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
self.encoder = FlaxBertEncoder(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
def __call__(
......@@ -1003,9 +1037,14 @@ append_call_sample_docstring(
class FlaxBertForPreTrainingModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1099,9 +1138,15 @@ append_replace_return_docstrings(
class FlaxBertForMaskedLMModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1161,9 +1206,14 @@ append_call_sample_docstring(
class FlaxBertForNextSentencePredictionModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)
def __call__(
......@@ -1248,9 +1298,14 @@ append_replace_return_docstrings(
class FlaxBertForSequenceClassificationModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
......@@ -1324,9 +1379,14 @@ append_call_sample_docstring(
class FlaxBertForMultipleChoiceModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
......@@ -1399,9 +1459,15 @@ append_call_sample_docstring(
class FlaxBertForTokenClassificationModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
......@@ -1468,9 +1534,15 @@ append_call_sample_docstring(
class FlaxBertForQuestionAnsweringModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
......@@ -1539,9 +1611,15 @@ append_call_sample_docstring(
class FlaxBertForCausalLMModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
......
......@@ -23,6 +23,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
......@@ -54,6 +55,8 @@ _CHECKPOINT_FOR_DOC = "google/bigbird-roberta-base"
_CONFIG_FOR_DOC = "BigBirdConfig"
_TOKENIZER_FOR_DOC = "BigBirdTokenizer"
remat = nn_partitioning.remat
@flax.struct.dataclass
class FlaxBigBirdForPreTrainingOutput(ModelOutput):
......@@ -1368,8 +1371,16 @@ class FlaxBigBirdLayer(nn.Module):
class FlaxBigBirdLayerCollection(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
if self.gradient_checkpointing:
FlaxBigBirdCheckpointLayer = remat(FlaxBigBirdLayer, static_argnums=(5, 6, 7))
self.layers = [
FlaxBigBirdCheckpointLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
else:
self.layers = [
FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
......@@ -1408,12 +1419,12 @@ class FlaxBigBirdLayerCollection(nn.Module):
layer_outputs = layer(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
head_mask[i] if head_mask is not None else None,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
deterministic,
output_attentions,
)
hidden_states = layer_outputs[0]
......@@ -1444,9 +1455,14 @@ class FlaxBigBirdLayerCollection(nn.Module):
class FlaxBigBirdEncoder(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.layer = FlaxBigBirdLayerCollection(self.config, dtype=self.dtype)
self.layer = FlaxBigBirdLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
def __call__(
self,
......@@ -1559,9 +1575,10 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
if config.attention_type == "block_sparse" and input_shape is None:
input_shape = (1, 12 * config.block_size)
elif input_shape is None:
......@@ -1569,6 +1586,14 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
......@@ -1735,10 +1760,13 @@ class FlaxBigBirdModule(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
gradient_checkpointing: bool = False
def setup(self):
self.embeddings = FlaxBigBirdEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype)
self.encoder = FlaxBigBirdEncoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.pooler = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
......@@ -1812,9 +1840,14 @@ append_call_sample_docstring(
class FlaxBigBirdForPreTrainingModule(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBigBirdModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBigBirdPreTrainingHeads(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1910,9 +1943,15 @@ append_replace_return_docstrings(
class FlaxBigBirdForMaskedLMModule(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.bert = FlaxBigBirdModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1999,9 +2038,12 @@ class FlaxBigBirdClassificationHead(nn.Module):
class FlaxBigBirdForSequenceClassificationModule(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBigBirdModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.classifier = FlaxBigBirdClassificationHead(self.config, dtype=self.dtype)
def __call__(
......@@ -2067,9 +2109,14 @@ append_call_sample_docstring(
class FlaxBigBirdForMultipleChoiceModule(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBigBirdModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
......@@ -2162,9 +2209,15 @@ append_call_sample_docstring(
class FlaxBigBirdForTokenClassificationModule(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.bert = FlaxBigBirdModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
......@@ -2255,10 +2308,16 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32
add_pooling_layer: bool = False
gradient_checkpointing: bool = False
def setup(self):
self.config.num_labels = 2
self.bert = FlaxBigBirdModule(self.config, dtype=self.dtype, add_pooling_layer=self.add_pooling_layer)
self.bert = FlaxBigBirdModule(
self.config,
dtype=self.dtype,
add_pooling_layer=self.add_pooling_layer,
gradient_checkpointing=self.gradient_checkpointing,
)
self.qa_classifier = FlaxBigBirdForQuestionAnsweringHead(self.config, dtype=self.dtype)
def __call__(
......@@ -2414,9 +2473,15 @@ append_call_sample_docstring(
class FlaxBigBirdForCausalLMModule(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.bert = FlaxBigBirdModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
......
......@@ -23,6 +23,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
......@@ -54,6 +55,8 @@ _CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
_CONFIG_FOR_DOC = "ElectraConfig"
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
remat = nn_partitioning.remat
@flax.struct.dataclass
class FlaxElectraForPreTrainingOutput(ModelOutput):
......@@ -521,10 +524,19 @@ class FlaxElectraLayer(nn.Module):
class FlaxElectraLayerCollection(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
if self.gradient_checkpointing:
FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7))
self.layers = [
FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
else:
self.layers = [
FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
def __call__(
......@@ -559,12 +571,12 @@ class FlaxElectraLayerCollection(nn.Module):
layer_outputs = layer(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
head_mask[i] if head_mask is not None else None,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
deterministic,
output_attentions,
)
hidden_states = layer_outputs[0]
......@@ -595,9 +607,14 @@ class FlaxElectraLayerCollection(nn.Module):
class FlaxElectraEncoder(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.layer = FlaxElectraLayerCollection(self.config, dtype=self.dtype)
self.layer = FlaxElectraLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
def __call__(
self,
......@@ -675,11 +692,20 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
......@@ -845,12 +871,15 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
class FlaxElectraModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
if self.config.embedding_size != self.config.hidden_size:
self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype)
self.encoder = FlaxElectraEncoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
def __call__(
self,
......@@ -925,9 +954,12 @@ class FlaxElectraTiedDense(nn.Module):
class FlaxElectraForMaskedLMModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
if self.config.tie_word_embeddings:
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
......@@ -989,9 +1021,12 @@ append_call_sample_docstring(
class FlaxElectraForPreTrainingModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1074,9 +1109,12 @@ append_replace_return_docstrings(
class FlaxElectraForTokenClassificationModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
......@@ -1218,9 +1256,12 @@ class FlaxElectraSequenceSummary(nn.Module):
class FlaxElectraForMultipleChoiceModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype)
self.classifier = nn.Dense(1, dtype=self.dtype)
......@@ -1297,9 +1338,12 @@ append_call_sample_docstring(
class FlaxElectraForQuestionAnsweringModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
......@@ -1392,9 +1436,12 @@ class FlaxElectraClassificationHead(nn.Module):
class FlaxElectraForSequenceClassificationModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1457,9 +1504,12 @@ append_call_sample_docstring(
class FlaxElectraForCausalLMModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
if self.config.tie_word_embeddings:
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
......
......@@ -21,6 +21,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
......@@ -47,6 +48,8 @@ _CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
remat = nn_partitioning.remat
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
......@@ -511,10 +514,19 @@ class FlaxRobertaLayer(nn.Module):
class FlaxRobertaLayerCollection(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
if self.gradient_checkpointing:
FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
self.layers = [
FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
else:
self.layers = [
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
def __call__(
......@@ -549,12 +561,12 @@ class FlaxRobertaLayerCollection(nn.Module):
layer_outputs = layer(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
head_mask[i] if head_mask is not None else None,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
deterministic,
output_attentions,
)
hidden_states = layer_outputs[0]
......@@ -585,9 +597,14 @@ class FlaxRobertaLayerCollection(nn.Module):
class FlaxRobertaEncoder(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
self.layer = FlaxRobertaLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
def __call__(
self,
......@@ -719,11 +736,20 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
......@@ -889,10 +915,15 @@ class FlaxRobertaModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
gradient_checkpointing: bool = False
def setup(self):
self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
self.encoder = FlaxRobertaEncoder(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
def __call__(
......@@ -967,9 +998,15 @@ append_call_sample_docstring(
class FlaxRobertaForMaskedLMModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.roberta = FlaxRobertaModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1034,9 +1071,15 @@ append_call_sample_docstring(
class FlaxRobertaForSequenceClassificationModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.roberta = FlaxRobertaModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1101,9 +1144,14 @@ append_call_sample_docstring(
class FlaxRobertaForMultipleChoiceModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype)
self.roberta = FlaxRobertaModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
......@@ -1181,9 +1229,15 @@ append_call_sample_docstring(
class FlaxRobertaForTokenClassificationModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.roberta = FlaxRobertaModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
......@@ -1255,9 +1309,15 @@ append_call_sample_docstring(
class FlaxRobertaForQuestionAnsweringModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.roberta = FlaxRobertaModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
......@@ -1326,9 +1386,15 @@ append_call_sample_docstring(
class FlaxRobertaForCausalLMModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.roberta = FlaxRobertaModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
def __call__(
......
......@@ -25,6 +25,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
......@@ -126,6 +127,8 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
"""
remat = nn_partitioning.remat
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
......@@ -507,8 +510,16 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
if self.gradient_checkpointing:
Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer = remat(Flax{{cookiecutter.camelcase_modelname}}Layer, static_argnums=(5, 6, 7))
self.layers = [
Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
else:
self.layers = [
Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
......@@ -545,12 +556,12 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
layer_outputs = layer(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
head_mask[i] if head_mask is not None else None,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
deterministic,
output_attentions,
)
hidden_states = layer_outputs[0]
......@@ -581,9 +592,10 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype)
self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
def __call__(
self,
......@@ -725,11 +737,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
......@@ -897,10 +918,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
gradient_checkpointing: bool = False
def setup(self):
self.embeddings = Flax{{cookiecutter.camelcase_modelname}}Embeddings(self.config, dtype=self.dtype)
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype)
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.pooler = Flax{{cookiecutter.camelcase_modelname}}Pooler(self.config, dtype=self.dtype)
def __call__(
......@@ -969,9 +991,10 @@ class Flax{{cookiecutter.camelcase_modelname}}Model(Flax{{cookiecutter.camelcase
class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1030,9 +1053,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
......@@ -1092,9 +1116,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(
self.config.num_labels,
......@@ -1163,9 +1188,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
......@@ -1238,9 +1264,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
......@@ -1302,9 +1329,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
......@@ -1373,9 +1401,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
......
......@@ -1099,6 +1099,33 @@ class FlaxModelTesterMixin:
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
def test_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
remat_model = model_class(config)
try:
remat_model.enable_gradient_checkpointing()
except NotImplementedError:
continue
outputs = model(**prepared_inputs_dict)
remat_outputs = remat_model(**prepared_inputs_dict)
# ensure that the dicts of outputs contain the same keys
self.assertEqual(outputs.keys(), remat_outputs.keys())
outputs = outputs.to_tuple()
remat_outputs = remat_outputs.to_tuple()
# ensure that the outputs remain precisely equal
for output, remat_output in zip(outputs, remat_outputs):
self.assertTrue((output == remat_output).all())
@require_flax
@is_staging_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment