Unverified Commit 83a7bb2a authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Implement `FlaxModelMixin` (#493)



* Implement `FlaxModelMixin`

* Rm unused method `framework`

* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* some more changes

* make style

* Add comment

* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Rm unneeded comment

* Update docstrings

* correct ignore kwargs

* make style

* Update docstring examples

* Make style

* Update src/diffusers/modeling_flax_utils.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Rm incorrect docstring

* Add FlaxModelMixin to __init__.py

* make fix-copies
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent 8b450969
...@@ -63,6 +63,7 @@ else: ...@@ -63,6 +63,7 @@ else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
if is_flax_available(): if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .schedulers import FlaxPNDMScheduler from .schedulers import FlaxPNDMScheduler
else: else:
from .utils.dummy_flax_objects import * # noqa F403 from .utils.dummy_flax_objects import * # noqa F403
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" ConfigMixinuration base class and utilities.""" """ ConfigMixinuration base class and utilities."""
import dataclasses
import functools import functools
import inspect import inspect
import json import json
...@@ -271,6 +272,11 @@ class ConfigMixin: ...@@ -271,6 +272,11 @@ class ConfigMixin:
# remove general kwargs if present in dict # remove general kwargs if present in dict
if "kwargs" in expected_keys: if "kwargs" in expected_keys:
expected_keys.remove("kwargs") expected_keys.remove("kwargs")
# remove flax interal keys
if hasattr(cls, "_flax_internal_args"):
for arg in cls._flax_internal_args:
expected_keys.remove(arg)
# remove keys to be ignored # remove keys to be ignored
if len(cls.ignore_for_config) > 0: if len(cls.ignore_for_config) > 0:
expected_keys = expected_keys - set(cls.ignore_for_config) expected_keys = expected_keys - set(cls.ignore_for_config)
...@@ -401,3 +407,44 @@ def register_to_config(init): ...@@ -401,3 +407,44 @@ def register_to_config(init):
getattr(self, "register_to_config")(**new_kwargs) getattr(self, "register_to_config")(**new_kwargs)
return inner_init return inner_init
def flax_register_to_config(cls):
original_init = cls.__init__
@functools.wraps(original_init)
def init(self, *args, **kwargs):
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
"not inherit from `ConfigMixin`."
)
# Ignore private kwargs in the init. Retrieve all passed attributes
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
# Retrieve default values
fields = dataclasses.fields(self)
default_kwargs = {}
for field in fields:
# ignore flax specific attributes
if field.name in self._flax_internal_args:
continue
if type(field.default) == dataclasses._MISSING_TYPE:
default_kwargs[field.name] = None
else:
default_kwargs[field.name] = getattr(self, field.name)
# Make sure init_kwargs override default kwargs
new_kwargs = {**default_kwargs, **init_kwargs}
# Get positional arguments aligned with kwargs
for i, arg in enumerate(args):
name = fields[i].name
new_kwargs[name] = arg
getattr(self, "register_to_config")(**new_kwargs)
original_init(self, *args, **kwargs)
cls.__init__ = init
return cls
This diff is collapsed.
...@@ -4,6 +4,13 @@ ...@@ -4,6 +4,13 @@
from ..utils import DummyObject, requires_backends from ..utils import DummyObject, requires_backends
class FlaxModelMixin(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxPNDMScheduler(metaclass=DummyObject): class FlaxPNDMScheduler(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment