"tests/models/vscode:/vscode.git/clone" did not exist on "476890e9aeb695c746efe093c6ab7440322c9077"
Unverified Commit 6eae3f78 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Add `BackboneMixin` (#20660)



* add BackboneBaseModel

* add BackboneBaseModel

* Rename to BackboneMixin

* remove nn.Module
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent be3d6c84
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import collections import collections
import gc import gc
import inspect
import json import json
import os import os
import re import re
...@@ -932,6 +933,15 @@ class ModuleUtilsMixin: ...@@ -932,6 +933,15 @@ class ModuleUtilsMixin:
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
class BackboneMixin:
def forward_with_filtered_kwargs(self, *args, **kwargs):
signature = dict(inspect.signature(self.forward).parameters)
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
return self(*args, **filtered_kwargs)
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin): class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r""" r"""
Base class for all models. Base class for all models.
......
...@@ -31,7 +31,7 @@ from ...modeling_outputs import ( ...@@ -31,7 +31,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention, BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention, ImageClassifierOutputWithNoAttention,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import ( from ...utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
...@@ -844,7 +844,7 @@ class BitForImageClassification(BitPreTrainedModel): ...@@ -844,7 +844,7 @@ class BitForImageClassification(BitPreTrainedModel):
""", """,
BIT_START_DOCSTRING, BIT_START_DOCSTRING,
) )
class BitBackbone(BitPreTrainedModel): class BitBackbone(BitPreTrainedModel, BackboneMixin):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
......
...@@ -27,7 +27,7 @@ from torch import Tensor, nn ...@@ -27,7 +27,7 @@ from torch import Tensor, nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import ModelOutput from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from .configuration_maskformer_swin import MaskFormerSwinConfig from .configuration_maskformer_swin import MaskFormerSwinConfig
...@@ -837,7 +837,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): ...@@ -837,7 +837,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
) )
class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel): class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
""" """
MaskFormerSwin backbone, designed especially for the MaskFormer framework. MaskFormerSwin backbone, designed especially for the MaskFormer framework.
......
...@@ -28,7 +28,7 @@ from ...modeling_outputs import ( ...@@ -28,7 +28,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention, BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention, ImageClassifierOutputWithNoAttention,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import ( from ...utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
...@@ -431,7 +431,7 @@ class ResNetForImageClassification(ResNetPreTrainedModel): ...@@ -431,7 +431,7 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
""", """,
RESNET_START_DOCSTRING, RESNET_START_DOCSTRING,
) )
class ResNetBackbone(ResNetPreTrainedModel): class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment