Unverified Commit 485bbe79 authored by Sanchit Gandhi, committed by GitHub

[Flax] Add remat (gradient checkpointing) (#17843)

* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* make fix-copies

* fix big-bird, electra, roberta

* cookie-cutter

* fix flax big-bird

* move test to common
parent 664688b9
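
As a quick orientation (not part of the diff): the API added in this commit can be exercised roughly as below. This is a minimal sketch based on the constructor argument and method introduced here; the config and model class are illustrative choices.

    from transformers import BertConfig, FlaxBertForMaskedLM

    config = BertConfig()
    model = FlaxBertForMaskedLM(config)  # or FlaxBertForMaskedLM(config, gradient_checkpointing=True)

    # Rebuilds the underlying Flax module with gradient_checkpointing=True, so each
    # encoder layer is wrapped in flax.linen's remat and its activations are
    # recomputed during the backward pass instead of being stored.
    model.enable_gradient_checkpointing()
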
@@ -235,6 +235,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        raise NotImplementedError(f"init method has to be implemented for {self}")

+    def enable_gradient_checkpointing(self):
+        raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")
+
    @classmethod
    def _from_config(cls, config, **kwargs):
        """
...
@@ -23,6 +23,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

@@ -56,6 +57,8 @@ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

+remat = nn_partitioning.remat
+

@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):

@@ -544,11 +547,19 @@ class FlaxBertLayer(nn.Module):
class FlaxBertLayerCollection(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layers = [
-            FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
+            self.layers = [
+                FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+            ]

    def __call__(
        self,

@@ -582,12 +593,12 @@ class FlaxBertLayerCollection(nn.Module):
            layer_outputs = layer(
                hidden_states,
                attention_mask,
-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
            )

            hidden_states = layer_outputs[0]

@@ -617,9 +628,14 @@ class FlaxBertLayerCollection(nn.Module):
class FlaxBertEncoder(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
+        self.layer = FlaxBertLayerCollection(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )

    def __call__(
        self,

@@ -756,11 +772,24 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
+        gradient_checkpointing: bool = False,
        **kwargs
    ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(
+            config=config,
+            dtype=dtype,
+            gradient_checkpointing=gradient_checkpointing,
+            **kwargs,
+        )
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")

@@ -925,10 +954,15 @@ class FlaxBertModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True
+    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
-        self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
+        self.encoder = FlaxBertEncoder(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

    def __call__(

@@ -1003,9 +1037,14 @@ append_call_sample_docstring(
class FlaxBertForPreTrainingModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
+        self.bert = FlaxBertModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1099,9 +1138,15 @@ append_replace_return_docstrings(
class FlaxBertForMaskedLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.bert = FlaxBertModule(
+            config=self.config,
+            add_pooling_layer=False,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1161,9 +1206,14 @@ append_call_sample_docstring(
class FlaxBertForNextSentencePredictionModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
+        self.bert = FlaxBertModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)

    def __call__(

@@ -1248,9 +1298,14 @@ append_replace_return_docstrings(
class FlaxBertForSequenceClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
+        self.bert = FlaxBertModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None

@@ -1324,9 +1379,14 @@ append_call_sample_docstring(
class FlaxBertForMultipleChoiceModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
+        self.bert = FlaxBertModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

@@ -1399,9 +1459,15 @@ append_call_sample_docstring(
class FlaxBertForTokenClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.bert = FlaxBertModule(
+            config=self.config,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None

@@ -1468,9 +1534,15 @@ append_call_sample_docstring(
class FlaxBertForQuestionAnsweringModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.bert = FlaxBertModule(
+            config=self.config,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(

@@ -1539,9 +1611,15 @@ append_call_sample_docstring(
class FlaxBertForCausalLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.bert = FlaxBertModule(
+            config=self.config,
+            add_pooling_layer=False,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
...
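
Aside (an illustrative sketch, not code from this repository): flax.linen.partitioning.remat wraps a Module class so that the wrapped layer's intermediate activations are discarded after the forward pass and recomputed during backpropagation, trading compute for memory; static_argnums marks argument positions that carry plain Python values (for example boolean flags) rather than arrays, so remat treats them as static instead of tracing them. The toy stack below mirrors the wrapping pattern used in the FlaxBertLayerCollection hunk above; all names are made up for the example.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    from flax.linen import partitioning as nn_partitioning

    remat = nn_partitioning.remat  # same alias as in the diff


    class Block(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            return nn.relu(nn.Dense(self.features)(x))


    class Stack(nn.Module):
        num_layers: int = 4
        features: int = 16
        gradient_checkpointing: bool = False

        def setup(self):
            # pick the plain layer class or its remat-wrapped variant, then
            # build the layer list from it, as the diff does for FlaxBertLayer
            block_cls = remat(Block) if self.gradient_checkpointing else Block
            self.blocks = [block_cls(self.features, name=str(i)) for i in range(self.num_layers)]

        def __call__(self, x):
            for block in self.blocks:
                x = block(x)
            return x


    x = jnp.ones((2, 16))
    model = Stack(gradient_checkpointing=True)
    params = model.init(jax.random.PRNGKey(0), x)
    # block activations are recomputed here rather than kept from the forward pass
    grads = jax.grad(lambda p: jnp.sum(model.apply(p, x) ** 2))(params)
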
@@ -23,6 +23,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

@@ -54,6 +55,8 @@ _CHECKPOINT_FOR_DOC = "google/bigbird-roberta-base"
_CONFIG_FOR_DOC = "BigBirdConfig"
_TOKENIZER_FOR_DOC = "BigBirdTokenizer"

+remat = nn_partitioning.remat
+

@flax.struct.dataclass
class FlaxBigBirdForPreTrainingOutput(ModelOutput):

@@ -1368,12 +1371,20 @@ class FlaxBigBirdLayer(nn.Module):
class FlaxBigBirdLayerCollection(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layers = [
-            FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
-            for i in range(self.config.num_hidden_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxBigBirdCheckpointLayer = remat(FlaxBigBirdLayer, static_argnums=(5, 6, 7))
+            self.layers = [
+                FlaxBigBirdCheckpointLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection.__call__ with Bert->BigBird
    def __call__(

@@ -1408,12 +1419,12 @@ class FlaxBigBirdLayerCollection(nn.Module):
            layer_outputs = layer(
                hidden_states,
                attention_mask,
-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
            )

            hidden_states = layer_outputs[0]

@@ -1444,9 +1455,14 @@ class FlaxBigBirdLayerCollection(nn.Module):
class FlaxBigBirdEncoder(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layer = FlaxBigBirdLayerCollection(self.config, dtype=self.dtype)
+        self.layer = FlaxBigBirdLayerCollection(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )

    def __call__(
        self,

@@ -1559,9 +1575,10 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
+        gradient_checkpointing: bool = False,
        **kwargs
    ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
        if config.attention_type == "block_sparse" and input_shape is None:
            input_shape = (1, 12 * config.block_size)
        elif input_shape is None:

@@ -1569,6 +1586,14 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

+    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors

@@ -1735,10 +1760,13 @@ class FlaxBigBirdModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True
+    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxBigBirdEmbeddings(self.config, dtype=self.dtype)
-        self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype)
+        self.encoder = FlaxBigBirdEncoder(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
        self.pooler = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),

@@ -1812,9 +1840,14 @@ append_call_sample_docstring(
class FlaxBigBirdForPreTrainingModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.cls = FlaxBigBirdPreTrainingHeads(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1910,9 +1943,15 @@ append_replace_return_docstrings(
class FlaxBigBirdForMaskedLMModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            add_pooling_layer=False,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1999,9 +2038,12 @@ class FlaxBigBirdClassificationHead(nn.Module):
class FlaxBigBirdForSequenceClassificationModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
+        self.bert = FlaxBigBirdModule(
+            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
        self.classifier = FlaxBigBirdClassificationHead(self.config, dtype=self.dtype)

    def __call__(

@@ -2067,9 +2109,14 @@ append_call_sample_docstring(
class FlaxBigBirdForMultipleChoiceModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

@@ -2162,9 +2209,15 @@ append_call_sample_docstring(
class FlaxBigBirdForTokenClassificationModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None

@@ -2255,10 +2308,16 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    add_pooling_layer: bool = False
+    gradient_checkpointing: bool = False

    def setup(self):
        self.config.num_labels = 2
-        self.bert = FlaxBigBirdModule(self.config, dtype=self.dtype, add_pooling_layer=self.add_pooling_layer)
+        self.bert = FlaxBigBirdModule(
+            self.config,
+            dtype=self.dtype,
+            add_pooling_layer=self.add_pooling_layer,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.qa_classifier = FlaxBigBirdForQuestionAnsweringHead(self.config, dtype=self.dtype)

    def __call__(

@@ -2414,9 +2473,15 @@ append_call_sample_docstring(
class FlaxBigBirdForCausalLMModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.bert = FlaxBigBirdModule(
+            config=self.config,
+            add_pooling_layer=False,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
...
@@ -23,6 +23,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

@@ -54,6 +55,8 @@ _CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
_CONFIG_FOR_DOC = "ElectraConfig"
_TOKENIZER_FOR_DOC = "ElectraTokenizer"

+remat = nn_partitioning.remat
+

@flax.struct.dataclass
class FlaxElectraForPreTrainingOutput(ModelOutput):

@@ -521,11 +524,20 @@ class FlaxElectraLayer(nn.Module):
class FlaxElectraLayerCollection(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layers = [
-            FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7))
+            self.layers = [
+                FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]

    def __call__(
        self,

@@ -559,12 +571,12 @@ class FlaxElectraLayerCollection(nn.Module):
            layer_outputs = layer(
                hidden_states,
                attention_mask,
-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
            )

            hidden_states = layer_outputs[0]

@@ -595,9 +607,14 @@ class FlaxElectraLayerCollection(nn.Module):
class FlaxElectraEncoder(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layer = FlaxElectraLayerCollection(self.config, dtype=self.dtype)
+        self.layer = FlaxElectraLayerCollection(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )

    def __call__(
        self,

@@ -675,11 +692,20 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
+        gradient_checkpointing: bool = False,
        **kwargs
    ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

+    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors

@@ -845,12 +871,15 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
class FlaxElectraModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
        if self.config.embedding_size != self.config.hidden_size:
            self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
-        self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype)
+        self.encoder = FlaxElectraEncoder(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )

    def __call__(
        self,

@@ -925,9 +954,12 @@ class FlaxElectraTiedDense(nn.Module):
class FlaxElectraForMaskedLMModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
+        self.electra = FlaxElectraModule(
+            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        if self.config.tie_word_embeddings:
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)

@@ -989,9 +1021,12 @@ append_call_sample_docstring(
class FlaxElectraForPreTrainingModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
+        self.electra = FlaxElectraModule(
+            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
        self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1074,9 +1109,12 @@ append_replace_return_docstrings(
class FlaxElectraForTokenClassificationModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
+        self.electra = FlaxElectraModule(
+            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None

@@ -1218,9 +1256,12 @@ class FlaxElectraSequenceSummary(nn.Module):
class FlaxElectraForMultipleChoiceModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
+        self.electra = FlaxElectraModule(
+            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
        self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype)
        self.classifier = nn.Dense(1, dtype=self.dtype)

@@ -1297,9 +1338,12 @@ append_call_sample_docstring(
class FlaxElectraForQuestionAnsweringModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
+        self.electra = FlaxElectraModule(
+            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(

@@ -1392,9 +1436,12 @@ class FlaxElectraClassificationHead(nn.Module):
class FlaxElectraForSequenceClassificationModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
+        self.electra = FlaxElectraModule(
+            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
        self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1457,9 +1504,12 @@ append_call_sample_docstring(
class FlaxElectraForCausalLMModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
+        self.electra = FlaxElectraModule(
+            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        if self.config.tie_word_embeddings:
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
...
@@ -21,6 +21,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

@@ -47,6 +48,8 @@ _CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"

+remat = nn_partitioning.remat
+

def create_position_ids_from_input_ids(input_ids, padding_idx):
    """

@@ -511,11 +514,20 @@ class FlaxRobertaLayer(nn.Module):
class FlaxRobertaLayerCollection(nn.Module):
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layers = [
-            FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
+            self.layers = [
+                FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]
+        else:
+            self.layers = [
+                FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]

    def __call__(
        self,

@@ -549,12 +561,12 @@ class FlaxRobertaLayerCollection(nn.Module):
            layer_outputs = layer(
                hidden_states,
                attention_mask,
-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
            )

            hidden_states = layer_outputs[0]

@@ -585,9 +597,14 @@ class FlaxRobertaLayerCollection(nn.Module):
class FlaxRobertaEncoder(nn.Module):
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
+        self.layer = FlaxRobertaLayerCollection(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )

    def __call__(
        self,

@@ -719,11 +736,20 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
+        gradient_checkpointing: bool = False,
        **kwargs
    ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

+    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")

@@ -889,10 +915,15 @@ class FlaxRobertaModule(nn.Module):
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True
+    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
-        self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
+        self.encoder = FlaxRobertaEncoder(
+            self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)

    def __call__(

@@ -967,9 +998,15 @@ append_call_sample_docstring(
class FlaxRobertaForMaskedLMModule(nn.Module):
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.roberta = FlaxRobertaModule(
+            config=self.config,
+            add_pooling_layer=False,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1034,9 +1071,15 @@ append_call_sample_docstring(
class FlaxRobertaForSequenceClassificationModule(nn.Module):
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.roberta = FlaxRobertaModule(
+            config=self.config,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1101,9 +1144,14 @@ append_call_sample_docstring(
class FlaxRobertaForMultipleChoiceModule(nn.Module):
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype)
+        self.roberta = FlaxRobertaModule(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

@@ -1181,9 +1229,15 @@ append_call_sample_docstring(
class FlaxRobertaForTokenClassificationModule(nn.Module):
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.roberta = FlaxRobertaModule(
+            config=self.config,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None

@@ -1255,9 +1309,15 @@ append_call_sample_docstring(
class FlaxRobertaForQuestionAnsweringModule(nn.Module):
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.roberta = FlaxRobertaModule(
+            config=self.config,
+            dtype=self.dtype,
+            add_pooling_layer=False,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(

@@ -1326,9 +1386,15 @@ append_call_sample_docstring(
class FlaxRobertaForCausalLMModule(nn.Module):
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.roberta = FlaxRobertaModule(
+            config=self.config,
+            add_pooling_layer=False,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
        self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)

    def __call__(
...
@@ -25,6 +25,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen.attention import dot_product_attention_weights
from jax import lax

@@ -126,6 +127,8 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
"""

+remat = nn_partitioning.remat
+

# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}

@@ -507,11 +510,19 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
    config: {{cookiecutter.camelcase_modelname}}Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layers = [
-            Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
-        ]
+        if self.gradient_checkpointing:
+            Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer = remat(Flax{{cookiecutter.camelcase_modelname}}Layer, static_argnums=(5, 6, 7))
+            self.layers = [
+                Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+                for i in range(self.config.num_hidden_layers)
+            ]
+        else:
+            self.layers = [
+                Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+            ]

    def __call__(
        self,

@@ -545,12 +556,12 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
            layer_outputs = layer(
                hidden_states,
                attention_mask,
-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
            )

            hidden_states = layer_outputs[0]

@@ -581,9 +592,10 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
    config: {{cookiecutter.camelcase_modelname}}Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype)
+        self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)

    def __call__(
        self,

@@ -725,11 +737,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedModel):
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
+        gradient_checkpointing: bool = False,
        **kwargs
    ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

+    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors

@@ -897,10 +918,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
    config: {{cookiecutter.camelcase_modelname}}Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True
+    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = Flax{{cookiecutter.camelcase_modelname}}Embeddings(self.config, dtype=self.dtype)
-        self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype)
+        self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
        self.pooler = Flax{{cookiecutter.camelcase_modelname}}Pooler(self.config, dtype=self.dtype)

    def __call__(

@@ -969,9 +991,10 @@ class Flax{{cookiecutter.camelcase_modelname}}Model(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
    config: {{cookiecutter.camelcase_modelname}}Config
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
        self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1030,9 +1053,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
    config: {{cookiecutter.camelcase_modelname}}Config
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
        self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(

@@ -1092,9 +1116,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn.Module):
    config: {{cookiecutter.camelcase_modelname}}Config
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
+        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(
            self.config.num_labels,

@@ -1163,9 +1188,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module):
    config: {{cookiecutter.camelcase_modelname}}Config
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
+        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

@@ -1238,9 +1264,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Module):
    config: {{cookiecutter.camelcase_modelname}}Config
    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False

    def setup(self):
-        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
...@@ -1302,9 +1329,10 @@ append_call_sample_docstring( ...@@ -1302,9 +1329,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module): class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self): def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False) self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__( def __call__(
...@@ -1373,9 +1401,10 @@ append_call_sample_docstring( ...@@ -1373,9 +1401,10 @@ append_call_sample_docstring(
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module): class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self): def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype) self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype) self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__( def __call__(
......
...@@ -1099,6 +1099,33 @@ class FlaxModelTesterMixin: ...@@ -1099,6 +1099,33 @@ class FlaxModelTesterMixin:
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()): for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
self.assertTrue(np.allclose(np.array(p1), np.array(p2))) self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
def test_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
remat_model = model_class(config)
try:
remat_model.enable_gradient_checkpointing()
except NotImplementedError:
continue
outputs = model(**prepared_inputs_dict)
remat_outputs = remat_model(**prepared_inputs_dict)
# ensure that the dicts of outputs contain the same keys
self.assertEqual(outputs.keys(), remat_outputs.keys())
outputs = outputs.to_tuple()
remat_outputs = remat_outputs.to_tuple()
# ensure that the outputs remain precisely equal
for output, remat_output in zip(outputs, remat_outputs):
self.assertTrue((output == remat_output).all())
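The exact-equality assertion above holds because rematerialization only changes when activations are recomputed, not what is computed. A minimal standalone illustration with plain jax.checkpoint (toy function, not from the test suite):
import jax
import jax.numpy as jnp

def f(x):
    # Two-step toy "network"; under checkpointing its intermediate is recomputed during the backward pass.
    return jnp.tanh(jnp.sin(x) ** 2).sum()

x = jnp.linspace(-1.0, 1.0, 8)
grad_plain = jax.grad(f)(x)
grad_remat = jax.grad(jax.checkpoint(f))(x)
# Same ops in the same precision on the same backend, so the results are expected to match exactly.
print(bool(jnp.array_equal(grad_plain, grad_remat)))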
@require_flax @require_flax
@is_staging_test @is_staging_test
......