Unverified Commit fb5468a6, authored by Mishig Davaadorj, committed by GitHub

Add `init_weights` method to `FlaxMixin` (#513)



* Add `init_weights` method to `FlaxMixin`

* Rename `random_state` -> `shape_state`

* `PRNGKey(0)` for `jax.eval_shape`

* Do not allow mismatched sizes

* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Docstring fixes in diffusers
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent d144c46a
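For illustration only (this block is not part of the commit's diff): a minimal sketch of the new contract, assuming a hypothetical `TinyModel` Flax module. Subclasses of `FlaxModelMixin` are expected to implement `init_weights(rng)`, and `jax.eval_shape` with a constant `PRNGKey(0)` turns that method into a shape-only parameter tree without allocating any device memory.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a diffusers Flax model; not part of this commit."""

    features: int = 8

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)

    def init_weights(self, rng: jax.random.PRNGKey):
        # A real model would build dummy inputs matching its expected sample
        # shape; (1, 4) is an arbitrary choice for this sketch.
        sample = jnp.zeros((1, 4), dtype=jnp.float32)
        return self.init(rng, sample)["params"]


model = TinyModel()

# jax.eval_shape traces init_weights abstractly: no parameters are materialized
# and no randomness is consumed, so a constant PRNGKey(0) is sufficient.
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, params_shape_tree))
# e.g. {'Dense_0': {'bias': (8,), 'kernel': (4, 8)}}
```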
@@ -20,7 +20,7 @@ from typing import Any, Dict, Union
import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import hf_hub_download
@@ -183,6 +183,9 @@ class FlaxModelMixin:
```"""
return self._cast_floating_to(params, jnp.float16, mask)

def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
    raise NotImplementedError(f"init_weights method has to be implemented for {self}")

@classmethod
def from_pretrained(
    cls,
@@ -227,10 +230,6 @@ class FlaxModelMixin:
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -394,6 +393,72 @@ class FlaxModelMixin:
# flatten dicts
state = flatten_dict(state)
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

shape_state = flatten_dict(unfreeze(params_shape_tree))

missing_keys = required_params - set(state.keys())
unexpected_keys = set(state.keys()) - required_params

if missing_keys:
    logger.warning(
        f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
        "Make sure to call model.init_weights to initialize the missing weights."
    )
    cls._missing_keys = missing_keys

# Mismatched keys contain (key, shape1, shape2) tuples for checkpoint weights whose shape
# does not match the corresponding model weight.
mismatched_keys = []
for key in state.keys():
    if key in shape_state and state[key].shape != shape_state[key].shape:
        raise ValueError(
            f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
            f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
        )

# remove unexpected keys so they are not saved again
for unexpected_key in unexpected_keys:
    del state[unexpected_key]

if len(unexpected_keys) > 0:
    logger.warning(
        f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
        f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
        f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
        " with another architecture."
    )
else:
    logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

if len(missing_keys) > 0:
    logger.warning(
        f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
        f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
        " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
    )
elif len(mismatched_keys) == 0:
    logger.info(
        f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
        f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
        f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
        " training."
    )

if len(mismatched_keys) > 0:
    mismatched_warning = "\n".join(
        [
            f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
            for key, shape1, shape2 in mismatched_keys
        ]
    )
    logger.warning(
        f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
        f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
        f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
        " to use it for predictions and inference."
    )
# dictionary of key: dtypes for the model params
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
# extract keys of parameters not in jnp.float32
...
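For illustration only: a self-contained sketch of the missing/unexpected/mismatched-key bookkeeping added above, using small hand-built parameter trees (the `dense` and `extra` key names are made up) in place of the `jax.eval_shape` result and the deserialized checkpoint.

```python
import jax.numpy as jnp
from flax.traverse_util import flatten_dict

# Hypothetical stand-ins: `shape_state` mirrors flatten_dict(unfreeze(params_shape_tree)),
# `state` mirrors the flattened checkpoint weights.
shape_state = flatten_dict({"dense": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))}})
state = flatten_dict({"dense": {"kernel": jnp.zeros((4, 8))}, "extra": {"w": jnp.zeros((2,))}})

required_params = set(shape_state.keys())
missing_keys = required_params - set(state.keys())     # {('dense', 'bias')} -> warned, needs init_weights
unexpected_keys = set(state.keys()) - required_params   # {('extra', 'w')}   -> dropped below

# Any shape mismatch between checkpoint and model raises immediately;
# this commit does not allow mismatched sizes.
for key in state:
    if key in shape_state and state[key].shape != shape_state[key].shape:
        raise ValueError(
            f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
            f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}."
        )

# Unexpected keys are deleted so they are not written back out when the model is saved.
for key in unexpected_keys:
    del state[key]

print(sorted(state.keys()))  # [('dense', 'kernel')]
```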