Unverified Commit d3bd9ac7 authored by Suraj Patil, committed by GitHub

[Flax] improve large model init and loading (#16148)



* begin do_init

* add params_shape_tree

* raise error if params are accessed when do_init is False

* don't allow do_init=False when keys are missing

* make shape tree a property

* assign self._params at the end

* add test for do_init

* add do_init arg to all flax models

* fix param setting

* disable do_init for composite models

* update test

* add do_init in FlaxBigBirdForMultipleChoice

* better names and errors

* improve test

* style

* add a warning when do_init=False

* remove extra if

* set params after _required_params

* add test for from_pretrained

* do_init => _do_init

* change warning to info

* fix typo

* add params in init_weights

* add params to gpt neo init

* add params to init_weights

* update do_init test

* Trigger CI

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update template

* trigger CI

* style

* style

* fix template
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 6de4ee61
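
A minimal usage sketch of the new flag (not part of the diff): the model class and checkpoint below are only illustrative stand-ins, and the calls mirror what the tests added at the end of this commit exercise.

import jax.numpy as jnp
from transformers import FlaxBertModel

# With _do_init=False, from_pretrained no longer initializes or stores the weights
# on the model; it returns the model and the loaded params separately.
model, params = FlaxBertModel.from_pretrained("bert-base-uncased", _do_init=False)

# Accessing model.params (or assigning to it) now raises a ValueError,
# so parameter placement stays entirely under the caller's control.

# If the checkpoint is missing some weights, init_weights fills in only the
# missing keys with freshly initialized values and leaves the rest untouched.
params = model.init_weights(model.key, model.input_shape, params=params)

# The forward pass takes the parameters explicitly.
outputs = model(jnp.ones((1, 8), dtype="i4"), params=params)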
...
@@ -21,8 +21,9 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
@@ -621,12 +622,13 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
@@ -636,9 +638,19 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False)[
-            "params"
-        ]
+        random_params = self.module.init(
+            rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False
+        )["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

    @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
...
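
The same merge pattern appears in every init_weights touched by this commit. A standalone sketch of what it does, using toy parameter trees with made-up key names:

from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

loaded = freeze({"encoder": {"kernel": 1.0}})  # params restored from a checkpoint
random_params = freeze({"encoder": {"kernel": 0.0}, "head": {"kernel": 2.0}})  # freshly initialized params
missing_keys = {("head", "kernel")}  # keys absent from the checkpoint

# Flatten both trees into {path-tuple: leaf} dicts, copy only the missing
# leaves from the random tree, then rebuild the nested FrozenDict.
flat_loaded = flatten_dict(unfreeze(loaded))
flat_random = flatten_dict(unfreeze(random_params))
for missing_key in missing_keys:
    flat_loaded[missing_key] = flat_random[missing_key]
merged = freeze(unflatten_dict(flat_loaded))  # {"encoder": {"kernel": 1.0}, "head": {"kernel": 2.0}}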
...
@@ -20,7 +20,8 @@ from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@@ -343,8 +344,15 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
+        if not _do_init:
+            raise ValueError(
+                "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
+            )
+
        if config.decoder.cross_attention_hidden_size is not None:
            # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
@@ -365,9 +373,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
            decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
            input_shape = ((1, encoder_input_length), (1, decoder_input_length))

-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        encoder_input_shape, decoder_input_shape = input_shape

        # init input DeviceArrays
@@ -390,7 +398,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs,
            inputs,
            attention_mask,
@@ -399,6 +407,16 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
            decoder_position_ids,
        )["params"]

+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
...
...
@@ -23,9 +23,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from ...modeling_flax_outputs import (
@@ -919,12 +920,13 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
@@ -935,7 +937,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
@@ -943,6 +945,16 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
            decoder_attention_mask,
        )["params"]

+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    def __call__(
        self,
...
...
@@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@@ -282,8 +283,14 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
+        if not _do_init:
+            raise ValueError(
+                "`FlaxVisionEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
+            )
+
        if input_shape is None:
            num_channels = getattr(config.encoder, "num_channels", 3)
            input_shape = (
@@ -301,9 +308,9 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
            )
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        encoder_input_shape, decoder_input_shape = input_shape

        # init input tensors
@@ -325,7 +332,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs,
            pixel_values,
            decoder_input_ids,
@@ -333,6 +340,16 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
            decoder_position_ids,
        )["params"]

+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
...
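
The commit message's "disable do_init for composite models" shows up here and in the other composite classes in this diff (the speech encoder-decoder above, the vision-text dual encoder below): these models simply refuse lazy construction. A small illustrative sketch of that behavior (the config classes are chosen only for the example):

from transformers import FlaxVisionEncoderDecoderModel, VisionEncoderDecoderConfig, ViTConfig, GPT2Config

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(ViTConfig(), GPT2Config())

# Composite models opt out of lazy init in this PR: passing _do_init=False
# raises the ValueError added in the constructor above.
try:
    model = FlaxVisionEncoderDecoderModel(config, _do_init=False)
except ValueError as err:
    print(err)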
...
@@ -20,7 +20,8 @@ from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.traverse_util import flatten_dict, unflatten_dict

from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring
from ...utils import add_start_docstrings, logging
@@ -225,15 +226,22 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
+        if not _do_init:
+            raise ValueError(
+                "`FlaxVisionTextDualEncoderModel` cannot be created without initializing, `_do_init` must be `True`."
+            )
+
        if input_shape is None:
            input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensor
        input_ids = jnp.zeros(input_shape[0], dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
@@ -245,7 +253,19 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
+        random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)[
+            "params"
+        ]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
    def __call__(
        self,
...
...
@@ -18,8 +18,9 @@ from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict

from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
from ...modeling_flax_utils import (
@@ -407,20 +408,38 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
    main_input_name = "pixel_values"
    module_class: nn.Module = None

-    def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
+    def __init__(
+        self,
+        config: ViTConfig,
+        input_shape=None,
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
+        **kwargs
+    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, 3)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(rngs, pixel_values, return_dict=False)["params"]
+        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
...
...
@@ -23,8 +23,9 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
@@ -858,19 +859,30 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple = (1, 1024),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_values = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_values)
        params_rng, dropout_rng = jax.random.split(rng, 2)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
+        random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    def __call__(
...
...
@@ -25,9 +25,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@@ -561,12 +562,13 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
@@ -589,7 +591,17 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
        else:
            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)

-        return module_init_outputs["params"]
+        random_params = module_init_outputs["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

    def init_cache(self, batch_size, max_length):
        r"""
...
...
@@ -23,7 +23,8 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
+from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
@@ -586,12 +587,18 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
    module_class: nn.Module = None

    def __init__(
-        self, config: {{cookiecutter.camelcase_modelname}}Config, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
+        self,
+        config: {{cookiecutter.camelcase_modelname}}Config,
+        input_shape: Tuple = (1, 1),
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
+        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
@@ -602,10 +609,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
        )["params"]

+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
    @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
...
@@ -1130,9 +1147,10 @@ from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@@ -2031,12 +2049,13 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
@@ -2052,7 +2071,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
@@ -2062,6 +2081,16 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
            decoder_position_ids,
        )["params"]

+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
...
...
@@ -43,7 +43,7 @@ if is_flax_available():
    import jax
    import jax.numpy as jnp
-    from flax.core.frozen_dict import unfreeze
+    from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
    from flax.traverse_util import flatten_dict, unflatten_dict

    from transformers import (
        FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -904,6 +904,93 @@ class FlaxModelTesterMixin:
        else:
            _check_attentions_validity(outputs.attentions)

+    def test_no_automatic_init(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        for model_class in self.all_model_classes:
+            model = model_class(config, _do_init=False)
+
+            # Check that accessing params raises a ValueError when _do_init is False
+            with self.assertRaises(ValueError):
+                params = model.params
+
+            # Check that params can be properly initialized when calling init_weights
+            params = model.init_weights(model.key, model.input_shape)
+            self.assertIsInstance(params, FrozenDict)
+            # Check if all required params are initialized
+            keys = set(flatten_dict(unfreeze(params)).keys())
+            self.assertTrue(all(k in keys for k in model.required_params))
+            # Check if the shapes match
+            flat_params = flatten_dict(unfreeze(params))
+            for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
+                self.assertEqual(
+                    v.shape,
+                    flat_params[k].shape,
+                    "Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
+                )
+
+            # Check that setting params raises a ValueError when _do_init is False
+            with self.assertRaises(ValueError):
+                model.params = params
+
+            # Check if we can do a forward pass
+            inputs_dict["output_hidden_states"] = True
+            inputs = self._prepare_for_class(inputs_dict, model_class).copy()
+            model(**inputs, params=params)
+
+    def test_from_pretrained_with_no_automatic_init(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        def _assert_all_params_initialised(model, params):
+            # Check if all required params are loaded
+            keys = set(flatten_dict(unfreeze(params)).keys())
+            self.assertTrue(all(k in keys for k in model.required_params))
+            # Check if the shapes match
+            flat_params = flatten_dict(unfreeze(params))
+            for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
+                self.assertEqual(
+                    v.shape,
+                    flat_params[k].shape,
+                    "Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
+                )
+
+        for model_class in self.all_model_classes:
+            # init the model
+            model = model_class(config)
+
+            # save the model in a temporary directory,
+            # then load the saved model with _do_init=False
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
+
+            # Check that accessing params raises a ValueError when _do_init is False
+            with self.assertRaises(ValueError):
+                params = model.params
+
+            # Check if all required params are loaded
+            _assert_all_params_initialised(model, params)
+
+            # Check that setting params raises a ValueError when _do_init is False
+            with self.assertRaises(ValueError):
+                model.params = params
+
+            # Check if init_weights initializes missing keys from from_pretrained
+            flat_params = flatten_dict(unfreeze(params))
+            random_key = random.choice(list(flat_params.keys()))
+            flat_params.pop(random_key)
+            params = freeze(unflatten_dict(flat_params))
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname, params=params)
+                model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
+
+                params = model.init_weights(model.key, model.input_shape, params=params)
+                # Check if all required params are loaded
+                _assert_all_params_initialised(model, params)
+
@require_flax
@is_staging_test
...
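
Both tests above compare the initialized parameters against model.params_shape_tree, the shape-only tree this PR adds. A small sketch (the model class and default config are illustrative) of inspecting the expected shapes without materializing any weights:

from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import BertConfig, FlaxBertModel

# Build the model lazily; only shape information is computed.
model = FlaxBertModel(BertConfig(), _do_init=False)

# params_shape_tree mirrors the parameter tree but holds shape structs
# instead of arrays, which is what the tests above check shapes against.
for path, leaf in flatten_dict(unfreeze(model.params_shape_tree)).items():
    print("/".join(path), leaf.shape)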