Unverified Commit 6eae3f78 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Add `BackboneMixin` (#20660)



* add BackboneBaseModel

* add BackboneBaseModel

* Rename to BackboneMixin

* remove nn.Module
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent be3d6c84
......@@ -15,6 +15,7 @@
# limitations under the License.
import collections
import gc
import inspect
import json
import os
import re
......@@ -932,6 +933,15 @@ class ModuleUtilsMixin:
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
class BackboneMixin:
def forward_with_filtered_kwargs(self, *args, **kwargs):
signature = dict(inspect.signature(self.forward).parameters)
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
return self(*args, **filtered_kwargs)
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r"""
Base class for all models.
......
......@@ -31,7 +31,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
......@@ -844,7 +844,7 @@ class BitForImageClassification(BitPreTrainedModel):
""",
BIT_START_DOCSTRING,
)
class BitBackbone(BitPreTrainedModel):
class BitBackbone(BitPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
......
......@@ -27,7 +27,7 @@ from torch import Tensor, nn
from ...activations import ACT2FN
from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from .configuration_maskformer_swin import MaskFormerSwinConfig
......@@ -837,7 +837,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
)
class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel):
class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
"""
MaskFormerSwin backbone, designed especially for the MaskFormer framework.
......
......@@ -28,7 +28,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
......@@ -431,7 +431,7 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
""",
RESNET_START_DOCSTRING,
)
class ResNetBackbone(ResNetPreTrainedModel):
class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment