# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect

import paddle
from paddle.nn import Layer


def fn_args_to_dict(func, *args, **kwargs):
    """
    Inspect function `func` and the arguments it would be run with, and
    build a dict mapping argument names to their values.

    Values are resolved with the following precedence (highest first):
    keyword arguments in `kwargs`, positional arguments in `args`, then the
    default values declared by `func`. Positional values beyond the named
    parameters of `func` are dropped; keyword-only defaults and the
    `*args`/`**kwargs` parameters of `func` are not inspected. Keys in
    `kwargs` that are not parameters of `func` are kept as-is.

    Args:
        func (callable): The function whose signature is inspected.
        *args: Positional argument values, bound in order to the named
            parameters reported by `inspect.getfullargspec(func).args`.
        **kwargs: Keyword argument values.

    Returns:
        dict: Mapping from argument name to value.
    """
    spec = inspect.getfullargspec(func)
    spec_args, spec_defaults = spec.args, spec.defaults
    # add positional argument values
    init_dict = dict(zip(spec_args, args))
    # add default argument values; a value passed positionally must not be
    # replaced by the default of the same parameter
    if spec_defaults:
        for arg_name, default in zip(spec_args[-len(spec_defaults):], spec_defaults):
            init_dict.setdefault(arg_name, default)
    # explicit keyword arguments take precedence over everything else
    init_dict.update(kwargs)
    return init_dict


class InitTrackerMeta(type(Layer)):
    """
    This metaclass wraps the `__init__` method of a class to add an
    `init_config` attribute to instances of that class; `init_config` is a
    dict that tracks the initial configuration.

    If the class defines or inherits a `_wrap_init` method and defines its own
    `__init__`, `_wrap_init` is hooked after `__init__` and called as
    `_wrap_init(self, init_fn, *init_args, **init_kwargs)`.

    Since InitTrackerMeta would be used as metaclass for pretrained model
    classes, which always are Layer and `type(Layer)` is not `type`,
    `type(Layer)` rather than `type` is used as its base class to avoid
    metaclass conflicts in inheritance.
    """

    def __init__(cls, name, bases, attrs):
        init_func = cls.__init__
        # If `attrs` has its own `__init__`, wrap it together with the
        # accessible `_wrap_init` hook. Otherwise `__init__` is inherited;
        # the super class `__init__` is assumed to be wrapped (with its hook)
        # already, so it is re-wrapped here only for tracking, without a hook.
        # TODO: avoid the redundant tracker layer when `__init__` is inherited
        # from an already-tracked super class.
        help_func = getattr(cls, '_wrap_init', None) if '__init__' in attrs else None
        cls.__init__ = InitTrackerMeta.init_and_track_conf(init_func, help_func)
        super(InitTrackerMeta, cls).__init__(name, bases, attrs)

    @staticmethod
    def init_and_track_conf(init_func, help_func=None):
        """
        Wraps `init_func`, which is the `__init__` method of a class, so that
        instances of that class get an `init_config` attribute.

        After the wrapped call, `init_config` holds the keyword arguments
        passed to `__init__`, plus `'init_args'` (the tuple of positional
        arguments, only present when any were given) and `'init_class'` (the
        name of the instance's class).

        Args:
            init_func (callable): It should be the `__init__` method of a class.
            help_func (callable, optional): If provided, it would be hooked
                after `init_func` and called as
                `help_func(self, init_func, *init_args, **init_kwargs)`.
                Default None.

        Returns:
            function: the wrapped function
        """

        @functools.wraps(init_func)
        def __impl__(self, *args, **kwargs):
            # run the original initialization first
            init_func(self, *args, **kwargs)
            # run the helper registered by `_wrap_init`
            if help_func:
                help_func(self, init_func, *args, **kwargs)
            # keep full configuration; `kwargs` is a fresh dict created for
            # this call, so it can be stored and extended in place
            self.init_config = kwargs
            if args:
                kwargs['init_args'] = args
            kwargs['init_class'] = self.__class__.__name__

        return __impl__