Unverified Commit 83a7bb2a authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Implement `FlaxModelMixin` (#493)



* Implement `FlaxModelMixin`

* Rm unused method `framework`

* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* some more changes

* make style

* Add comment

* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Rm unneeded comment

* Update docstrings

* correct ignore kwargs

* make style

* Update docstring examples

* Make style

* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Rm incorrect docstring

* Add FlaxModelMixin to __init__.py

* make fix-copies
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent 8b450969
......@@ -63,6 +63,7 @@ else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .schedulers import FlaxPNDMScheduler
else:
from .utils.dummy_flax_objects import * # noqa F403
......@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
import dataclasses
import functools
import inspect
import json
......@@ -271,6 +272,11 @@ class ConfigMixin:
# remove general kwargs if present in dict
if "kwargs" in expected_keys:
expected_keys.remove("kwargs")
# remove flax interal keys
if hasattr(cls, "_flax_internal_args"):
for arg in cls._flax_internal_args:
expected_keys.remove(arg)
# remove keys to be ignored
if len(cls.ignore_for_config) > 0:
expected_keys = expected_keys - set(cls.ignore_for_config)
......@@ -401,3 +407,44 @@ def register_to_config(init):
getattr(self, "register_to_config")(**new_kwargs)
return inner_init
def flax_register_to_config(cls):
original_init = cls.__init__
@functools.wraps(original_init)
def init(self, *args, **kwargs):
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
"not inherit from `ConfigMixin`."
)
# Ignore private kwargs in the init. Retrieve all passed attributes
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
# Retrieve default values
fields = dataclasses.fields(self)
default_kwargs = {}
for field in fields:
# ignore flax specific attributes
if field.name in self._flax_internal_args:
continue
if type(field.default) == dataclasses._MISSING_TYPE:
default_kwargs[field.name] = None
else:
default_kwargs[field.name] = getattr(self, field.name)
# Make sure init_kwargs override default kwargs
new_kwargs = {**default_kwargs, **init_kwargs}
# Get positional arguments aligned with kwargs
for i, arg in enumerate(args):
name = fields[i].name
new_kwargs[name] = arg
getattr(self, "register_to_config")(**new_kwargs)
original_init(self, *args, **kwargs)
cls.__init__ = init
return cls
This diff is collapsed.
......@@ -4,6 +4,13 @@
from ..utils import DummyObject, requires_backends
class FlaxModelMixin(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxPNDMScheduler(metaclass=DummyObject):
_backends = ["flax"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment