Unverified Commit d3bd9ac7 authored by Suraj Patil, committed by GitHub

[Flax] improve large model init and loading (#16148)



* begin do_init

* add params_shape_tree

* raise error if params are accessed when do_init is False

* don't allow do_init=False when keys are missing

* make shape tree a property

* assign self._params at the end

* add test for do_init

* add do_init arg to all flax models

* fix param setting

* disable do_init for composite models

* update test

* add do_init in FlaxBigBirdForMultipleChoice

* better names and errors

* improve test

* style

* add a warning when do_init=False

* remove extra if

* set params after _required_params

* add test for from_pretrained

* do_init => _do_init

* change warning to info

* fix typo

* add params in init_weights

* add params to gpt neo init

* add params to init_weights

* update do_init test

* Trigger CI

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update template

* trigger CI

* style

* style

* fix template
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 6de4ee61
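The net effect of the commits above: `from_pretrained` accepts a private `_do_init` flag, and with `_do_init=False` it no longer materializes random weights, instead returning the model together with the loaded parameters. A minimal usage sketch of that flow, assuming the post-merge API and an illustrative checkpoint name:

import jax
import jax.numpy as jnp
from transformers import FlaxBertForMaskedLM

# With _do_init=False, from_pretrained skips random initialization, keeps the
# loaded weights on CPU, and returns them separately from the model object.
model, params = FlaxBertForMaskedLM.from_pretrained("bert-base-uncased", _do_init=False)

# Parameter shapes are still inspectable without allocating any weights.
shapes = jax.tree_util.tree_map(lambda x: x.shape, model.params_shape_tree)

# model.params now raises a ValueError; params must be passed explicitly.
input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids, params=params)

# If the checkpoint was missing keys, init_weights can fill them in randomly.
params = model.init_weights(jax.random.PRNGKey(0), model.input_shape, params)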
@@ -140,7 +140,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensor
         input_ids = jnp.zeros(input_shape[0], dtype="i4")
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
...
@@ -90,6 +90,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
     base_model_prefix = ""
     main_input_name = "input_ids"
     _auto_class = None
+    _missing_keys = set()
 
     def __init__(
         self,
@@ -98,6 +99,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         input_shape: Tuple = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
     ):
         if config is None:
             raise ValueError("config cannot be None")
@@ -112,15 +114,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         # Those are public as their type is generic to every derived classes.
         self.key = PRNGKey(seed)
         self.dtype = dtype
+        self.input_shape = input_shape
 
-        # randomly initialized parameters
-        random_params = self.init_weights(self.key, input_shape)
+        # To check if the model was intialized automatically.
+        self._is_initialized = _do_init
+
+        if _do_init:
+            # randomly initialized parameters
+            random_params = self.init_weights(self.key, input_shape)
+            params_shape_tree = jax.eval_shape(lambda params: params, random_params)
+        else:
+            init_fn = partial(self.init_weights, input_shape=input_shape)
+            params_shape_tree = jax.eval_shape(init_fn, self.key)
+
+            logger.info(
+                "Model weights are not initialized as `_do_init` is set to `False`. "
+                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
+            )
+
+        # get the shape of the parameters
+        self._params_shape_tree = params_shape_tree
 
         # save required_params as set
-        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
-        self.params = random_params
+        self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+        # initialize the parameters
+        if _do_init:
+            self.params = random_params
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
         raise NotImplementedError(f"init method has to be implemented for {self}")
 
     @classmethod
@@ -147,14 +169,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
     @property
     def params(self) -> Union[Dict, FrozenDict]:
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
+                "You must call `init_weights` manually and store the params outside of the model and "
+                "pass it explicitly where needed."
+            )
         return self._params
 
     @property
     def required_params(self) -> Set:
         return self._required_params
 
+    @property
+    def params_shape_tree(self) -> Dict:
+        return self._params_shape_tree
+
     @params.setter
     def params(self, params: Union[Dict, FrozenDict]):
+        # don't set params if the model is not initialized
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be set from model when the model is created with `_do_init=False`. "
+                "You store the params outside of the model."
+            )
         if isinstance(params, FrozenDict):
             params = unfreeze(params)
         param_keys = set(flatten_dict(params).keys())
@@ -417,6 +456,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         revision = kwargs.pop("revision", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        _do_init = kwargs.pop("_do_init", True)
 
         user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
         if from_pipeline is not None:
@@ -553,7 +593,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             resolved_archive_file = None
 
         # init random models
-        model = cls(config, *model_args, **model_kwargs)
+        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
 
         if from_pt:
             state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
@@ -577,25 +617,36 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             # make sure all arrays are stored as jnp.arrays
             # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
             # https://github.com/google/flax/issues/1261
-            state = jax.tree_util.tree_map(jnp.array, state)
+            if _do_init:
+                state = jax.tree_util.tree_map(jnp.array, state)
+            else:
+                # keep the params on CPU if we don't want to initialize
+                state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
 
         # if model is base model only use model_prefix key
-        if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
+        if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
             state = state[cls.base_model_prefix]
 
         # if model is head model and we are loading weights from base model
         # we initialize new params dict with base_model_prefix
-        if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
+        if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
             state = {cls.base_model_prefix: state}
 
         # flatten dicts
         state = flatten_dict(state)
-        random_state = flatten_dict(unfreeze(model.params))
+        random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))
 
         missing_keys = model.required_params - set(state.keys())
         unexpected_keys = set(state.keys()) - model.required_params
 
+        if missing_keys and not _do_init:
+            logger.warn(
+                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+                f"Make sure to call model.init_weights to initialize the missing weights."
+            )
+            cls._missing_keys = missing_keys
+
         # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
         # matching the weights in the model.
         mismatched_keys = []
@@ -612,9 +663,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 "model."
             )
 
-        # add missing keys as random parameters
-        for missing_key in missing_keys:
-            state[missing_key] = random_state[missing_key]
+        # add missing keys as random parameters if we are initializing
+        if missing_keys and _do_init:
+            for missing_key in missing_keys:
+                state[missing_key] = random_state[missing_key]
 
         # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
@@ -680,10 +732,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
             )
 
-        # set correct parameters
-        model.params = unflatten_dict(state)
-
-        return model
+        if _do_init:
+            # set correct parameters
+            model.params = unflatten_dict(state)
+            return model
+        else:
+            return model, unflatten_dict(state)
 
     def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
         """
...
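A short sketch of the constructor path added in the hunk above: when `_do_init=False`, the base class records only an abstract shape tree via `jax.eval_shape` and leaves weight creation to the caller (the model class and config values below are illustrative):

import jax
from transformers import BertConfig, FlaxBertModel

config = BertConfig(num_hidden_layers=2, hidden_size=128, num_attention_heads=2, intermediate_size=256)

# No parameters are materialized here; only ShapeDtypeStructs are stored.
model = FlaxBertModel(config, _do_init=False)
print(jax.tree_util.tree_map(lambda x: x.shape, model.params_shape_tree))

# Weights are created only when the caller explicitly asks for them.
params = model.init_weights(jax.random.PRNGKey(0), model.input_shape)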
@@ -21,8 +21,9 @@ import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 
 from ...modeling_flax_outputs import (
@@ -522,12 +523,13 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         token_type_ids = jnp.zeros_like(input_ids)
@@ -537,9 +539,19 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
-            "params"
-        ]
+        random_params = self.module.init(
+            rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
+        )["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
 
     @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
...
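The `params is not None` branch added above is repeated verbatim in every model file below; a condensed sketch of what it does, using hypothetical toy parameter trees in place of real model weights:

from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

# Hypothetical checkpoint params missing one leaf, plus freshly initialized ones.
loaded_params = freeze({"encoder": {"kernel": 1.0}})
random_params = freeze({"encoder": {"kernel": 2.0}, "pooler": {"kernel": 3.0}})
missing_keys = {("pooler", "kernel")}

loaded_flat = flatten_dict(unfreeze(loaded_params))
random_flat = flatten_dict(unfreeze(random_params))
for missing_key in missing_keys:
    # copy only the randomly initialized leaves that the checkpoint lacked
    loaded_flat[missing_key] = random_flat[missing_key]

# merged now holds encoder/kernel from the checkpoint and pooler/kernel from random init
merged = freeze(unflatten_dict(loaded_flat))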
@@ -24,9 +24,10 @@ import numpy as np
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 from jax.random import PRNGKey
@@ -912,12 +913,13 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple[int] = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
@@ -933,7 +935,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(
+        random_params = self.module.init(
             rngs,
             input_ids,
             attention_mask,
@@ -943,6 +945,16 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
             decoder_position_ids,
         )["params"]
 
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
     def init_cache(self, batch_size, max_length, encoder_outputs):
         r"""
         Args:
@@ -1737,14 +1749,15 @@ class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple[int] = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         config.is_decoder = True
         config.is_encoder_decoder = False
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         attention_mask = jnp.ones_like(input_ids)
...
@@ -22,8 +22,9 @@ import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 
 from ...modeling_flax_outputs import (
     FlaxBaseModelOutput,
@@ -591,13 +592,21 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
     main_input_name = "pixel_values"
     module_class: nn.Module = None
 
-    def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
+    def __init__(
+        self,
+        config: BeitConfig,
+        input_shape=None,
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
+        **kwargs
+    ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         if input_shape is None:
             input_shape = (1, config.image_size, config.image_size, 3)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
@@ -605,7 +614,17 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
         dropout_rng, droppath_rng = jax.random.split(dropout_rng)
         rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}
 
-        return self.module.init(rngs, pixel_values, return_dict=False)["params"]
+        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
 
     @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
...
@@ -21,8 +21,9 @@ import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 
 from ...modeling_flax_outputs import (
@@ -616,12 +617,18 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
     module_class: nn.Module = None
 
     def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
+        self,
+        config: BertConfig,
+        input_shape: Tuple = (1, 1),
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
+        **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         token_type_ids = jnp.zeros_like(input_ids)
@@ -632,10 +639,20 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(
+        random_params = self.module.init(
             rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
         )["params"]
 
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
         self,
...
@@ -21,8 +21,9 @@ import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 
 from ...modeling_flax_outputs import (
@@ -1420,6 +1421,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Optional[tuple] = None,
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -1428,9 +1430,9 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
         elif input_shape is None:
             input_shape = (1, 1)
 
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         token_type_ids = jnp.zeros_like(input_ids)
@@ -1441,10 +1443,20 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(
+        random_params = self.module.init(
             rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
         )["params"]
 
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
         self,
@@ -1897,13 +1909,14 @@ class FlaxBigBirdForMultipleChoice(FlaxBigBirdPreTrainedModel):
         input_shape: Optional[tuple] = None,
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         if config.attention_type == "block_sparse" and input_shape is None:
             input_shape = (1, 1, 12 * config.block_size)
         elif input_shape is None:
             input_shape = (1, 1)
-        super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
         overwrite_call_docstring(
...
@@ -24,9 +24,10 @@ import numpy as np
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 from jax.random import PRNGKey
@@ -887,12 +888,13 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple[int] = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         # make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule
@@ -908,7 +910,7 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(
+        random_params = self.module.init(
             rngs,
             input_ids,
             attention_mask,
@@ -918,6 +920,16 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
             decoder_position_ids,
         )["params"]
 
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
     def init_cache(self, batch_size, max_length, encoder_outputs):
         r"""
         Args:
...
@@ -25,9 +25,10 @@ import numpy as np
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 from jax.random import PRNGKey
@@ -885,12 +886,13 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple[int] = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         # make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule
@@ -906,7 +908,7 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(
+        random_params = self.module.init(
             rngs,
             input_ids,
             attention_mask,
@@ -916,6 +918,16 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
             decoder_position_ids,
         )["params"]
 
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
     def init_cache(self, batch_size, max_length, encoder_outputs):
         r"""
         Args:
...
@@ -19,9 +19,10 @@ import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 
 from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
@@ -585,12 +586,18 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
     module_class: nn.Module = None
 
     def __init__(
-        self, config: CLIPTextConfig, input_shape=(1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
+        self,
+        config: CLIPTextConfig,
+        input_shape=(1, 1),
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
+        **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensor
         input_ids = jnp.zeros(input_shape, dtype="i4")
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
@@ -599,7 +606,17 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
+        random_params = self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
 
     def __call__(
         self,
@@ -654,21 +671,32 @@ class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Optional[Tuple] = None,
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         if input_shape is None:
             input_shape = (1, config.image_size, config.image_size, 3)
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensor
         pixel_values = jax.random.normal(rng, input_shape)
 
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, pixel_values)["params"]
+        random_params = self.module.init(rngs, pixel_values)["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
 
     def __call__(
         self,
@@ -714,14 +742,15 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Optional[Tuple] = None,
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         if input_shape is None:
             input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensor
         input_ids = jnp.zeros(input_shape[0], dtype="i4")
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
@@ -732,7 +761,17 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
+        random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
 
     def __call__(
         self,
...
@@ -21,7 +21,8 @@ import numpy as np
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 
 from ...modeling_flax_outputs import (
@@ -428,12 +429,13 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         attention_mask = jnp.ones_like(input_ids)
@@ -441,7 +443,17 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
+        random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
 
     @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
...
@@ -21,8 +21,9 @@ import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 from jax.random import PRNGKey
@@ -541,12 +542,13 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         token_type_ids = jnp.zeros_like(input_ids)
@@ -557,10 +559,20 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(
+        random_params = self.module.init(
             rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
         )["params"]
 
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
     @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
         self,
...
@@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 from jax.random import PRNGKey
@@ -315,11 +316,17 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
         input_shape: Optional[Tuple] = None,
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         if input_shape is None:
             input_shape = ((1, 1), (1, 1))
 
+        if not _do_init:
+            raise ValueError(
+                "`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
+            )
+
         if config.decoder.cross_attention_hidden_size is not None:
             if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                 raise ValueError(
@@ -330,9 +337,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
                 )
 
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         encoder_input_shape, decoder_input_shape = input_shape
 
         # init input tensors
@@ -356,7 +363,7 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(
+        random_params = self.module.init(
             rngs,
             input_ids,
             attention_mask,
@@ -366,6 +373,16 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
             decoder_position_ids,
         )["params"]
 
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
     def init_cache(self, batch_size, max_length, encoder_outputs):
         r"""
         Args:
...
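Composite models are the exception: as the hunk above shows, `FlaxEncoderDecoderModel` rejects `_do_init=False` outright. A quick illustration of that guard (the configs below are illustrative defaults):

from transformers import BertConfig, EncoderDecoderConfig, FlaxEncoderDecoderModel, GPT2Config

config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), GPT2Config())
try:
    FlaxEncoderDecoderModel(config, _do_init=False)
except ValueError as err:
    # raised before any module or weights are created
    print(err)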
@@ -18,9 +18,10 @@ from typing import Any, Optional, Tuple
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 
 from ...modeling_flax_outputs import (
@@ -394,12 +395,13 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         attention_mask = jnp.ones_like(input_ids)
@@ -422,7 +424,17 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         else:
             module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
 
-        return module_init_outputs["params"]
+        random_params = module_init_outputs["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
 
     def init_cache(self, batch_size, max_length):
         r"""
...
@@ -19,9 +19,10 @@ from typing import Optional, Tuple
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 
 from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
@@ -353,12 +354,13 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         attention_mask = jnp.ones_like(input_ids)
@@ -366,7 +368,17 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
+        random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
 
     def init_cache(self, batch_size, max_length):
         r"""
...
@@ -21,9 +21,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput

@@ -373,12 +374,13 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

@@ -401,7 +403,17 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
        else:
            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)

-        return module_init_outputs["params"]
+        random_params = module_init_outputs["params"]
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

    def init_cache(self, batch_size, max_length):
        r"""
...
@@ -24,9 +24,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

@@ -882,12 +883,13 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # make sure initialization pass will work for FlaxMarianForSequenceClassificationModule
@@ -903,7 +905,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
@@ -913,6 +915,16 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
            decoder_position_ids,
        )["params"]
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
...
@@ -24,9 +24,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

@@ -951,12 +952,13 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # make sure initialization pass will work for FlaxMBartForSequenceClassificationModule
@@ -972,7 +974,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
@@ -982,6 +984,16 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
            decoder_position_ids,
        )["params"]
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->MBart
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
...
@@ -25,9 +25,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

@@ -901,12 +902,13 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

@@ -920,7 +922,7 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
@@ -930,6 +932,16 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
            decoder_position_ids,
        )["params"]
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
...
@@ -19,8 +19,9 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

@@ -585,12 +586,13 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.ones_like(input_ids)

@@ -601,10 +603,20 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(
+        random_params = self.module.init(
            rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
        )["params"]
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
...
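Every model touched above repeats the identical merge, so a natural test-side check is that the merged tree ends up with exactly the leaf layout of the random init, meaning every expected leaf was filled either from the checkpoint or from the random weights. A small sketch of such a check; the helper name is made up for illustration:

```python
# Sketch: verify a merged parameter tree has the same key layout as the
# randomly initialized one (leaf values are ignored, only structure matters).
import jax


def has_same_layout(random_params, merged_params) -> bool:
    # tree_structure compares the pytree layout and ignores leaf values.
    return jax.tree_util.tree_structure(random_params) == jax.tree_util.tree_structure(merged_params)
```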