Unverified Commit 30b45320 authored by Jacky Lee's avatar Jacky Lee Committed by GitHub
Browse files

Enable multi-device for some models (#30207)



* feat: multidevice for resnet

* feat: yes! resnet

* fix: compare all elements in tuple

* feat: support for regnet

* feat: support for convnextv2

* feat: support for bit

* feat: support for cvt

* feat: add support for focalnet

* feat: support for yolos

* feat: support for glpn

* feat: support for imagegpt

* feat: support for levit

* feat: support for mgp_str

* feat: support for mobilnet_v1

* feat: support for mobilnet_v2

* feat: support for mobilevit

* feat: support for mobilevitv2

* feat: support for poolformer

* fix: copies

* fix: code quality check

* update: upstream changes from main

* fix: consistency check

* feat: support for sam

* feat: support for switchformer

* feat: support for swin

* feat: support for swinv2

* feat: support for timesformer

* feat: suport for trocr

* feat: support for upernet

* fix: check copies

* update: rerun CI

* update: rerun again, maybe

* update: one more rerun

---------
Co-authored-by: default avatarJacky Lee <jackylee328@gmail.com>
parent ecfe9be7
...@@ -658,6 +658,7 @@ class BitPreTrainedModel(PreTrainedModel): ...@@ -658,6 +658,7 @@ class BitPreTrainedModel(PreTrainedModel):
config_class = BitConfig config_class = BitConfig
base_model_prefix = "bit" base_model_prefix = "bit"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = ["BitEmbeddings"]
def _init_weights(self, module): def _init_weights(self, module):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
......
...@@ -280,6 +280,7 @@ class ConvNextPreTrainedModel(PreTrainedModel): ...@@ -280,6 +280,7 @@ class ConvNextPreTrainedModel(PreTrainedModel):
config_class = ConvNextConfig config_class = ConvNextConfig
base_model_prefix = "convnext" base_model_prefix = "convnext"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = ["ConvNextLayer"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -301,6 +301,7 @@ class ConvNextV2PreTrainedModel(PreTrainedModel): ...@@ -301,6 +301,7 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
config_class = ConvNextV2Config config_class = ConvNextV2Config
base_model_prefix = "convnextv2" base_model_prefix = "convnextv2"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = ["ConvNextV2Layer"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -534,6 +534,7 @@ class CvtPreTrainedModel(PreTrainedModel): ...@@ -534,6 +534,7 @@ class CvtPreTrainedModel(PreTrainedModel):
config_class = CvtConfig config_class = CvtConfig
base_model_prefix = "cvt" base_model_prefix = "cvt"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = ["CvtLayer"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -809,6 +809,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel): ...@@ -809,6 +809,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
base_model_prefix = "swin" base_model_prefix = "swin"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["DonutSwinStage"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -636,6 +636,7 @@ class FocalNetPreTrainedModel(PreTrainedModel): ...@@ -636,6 +636,7 @@ class FocalNetPreTrainedModel(PreTrainedModel):
base_model_prefix = "focalnet" base_model_prefix = "focalnet"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["FocalNetStage"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -426,6 +426,7 @@ class GLPNPreTrainedModel(PreTrainedModel): ...@@ -426,6 +426,7 @@ class GLPNPreTrainedModel(PreTrainedModel):
config_class = GLPNConfig config_class = GLPNConfig
base_model_prefix = "glpn" base_model_prefix = "glpn"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = []
# Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights # Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights
def _init_weights(self, module): def _init_weights(self, module):
......
...@@ -491,6 +491,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel): ...@@ -491,6 +491,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer" base_model_prefix = "transformer"
main_input_name = "input_ids" main_input_name = "input_ids"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["ImageGPTBlock"]
def __init__(self, *inputs, **kwargs): def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
......
...@@ -491,6 +491,7 @@ class LevitPreTrainedModel(PreTrainedModel): ...@@ -491,6 +491,7 @@ class LevitPreTrainedModel(PreTrainedModel):
config_class = LevitConfig config_class = LevitConfig
base_model_prefix = "levit" base_model_prefix = "levit"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = ["LevitResidualLayer"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -735,6 +735,7 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel): ...@@ -735,6 +735,7 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
base_model_prefix = "model" base_model_prefix = "model"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["MaskFormerSwinStage"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -317,6 +317,7 @@ class MgpstrPreTrainedModel(PreTrainedModel): ...@@ -317,6 +317,7 @@ class MgpstrPreTrainedModel(PreTrainedModel):
config_class = MgpstrConfig config_class = MgpstrConfig
base_model_prefix = "mgp_str" base_model_prefix = "mgp_str"
_no_split_modules = []
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -254,6 +254,7 @@ class MobileNetV1PreTrainedModel(PreTrainedModel): ...@@ -254,6 +254,7 @@ class MobileNetV1PreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilenet_v1" base_model_prefix = "mobilenet_v1"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = False supports_gradient_checkpointing = False
_no_split_modules = []
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None: def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -453,6 +453,7 @@ class MobileNetV2PreTrainedModel(PreTrainedModel): ...@@ -453,6 +453,7 @@ class MobileNetV2PreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilenet_v2" base_model_prefix = "mobilenet_v2"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = False supports_gradient_checkpointing = False
_no_split_modules = []
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None: def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -644,6 +644,7 @@ class MobileViTPreTrainedModel(PreTrainedModel): ...@@ -644,6 +644,7 @@ class MobileViTPreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilevit" base_model_prefix = "mobilevit"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["MobileViTLayer"]
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -606,6 +606,7 @@ class MobileViTV2PreTrainedModel(PreTrainedModel): ...@@ -606,6 +606,7 @@ class MobileViTV2PreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilevitv2" base_model_prefix = "mobilevitv2"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["MobileViTV2Layer"]
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -268,6 +268,7 @@ class PoolFormerPreTrainedModel(PreTrainedModel): ...@@ -268,6 +268,7 @@ class PoolFormerPreTrainedModel(PreTrainedModel):
config_class = PoolFormerConfig config_class = PoolFormerConfig
base_model_prefix = "poolformer" base_model_prefix = "poolformer"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = ["PoolFormerLayer"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -281,6 +281,7 @@ class RegNetPreTrainedModel(PreTrainedModel): ...@@ -281,6 +281,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
config_class = RegNetConfig config_class = RegNetConfig
base_model_prefix = "regnet" base_model_prefix = "regnet"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = ["RegNetYLayer"]
# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
def _init_weights(self, module): def _init_weights(self, module):
......
...@@ -272,6 +272,7 @@ class ResNetPreTrainedModel(PreTrainedModel): ...@@ -272,6 +272,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
config_class = ResNetConfig config_class = ResNetConfig
base_model_prefix = "resnet" base_model_prefix = "resnet"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = ["ResNetConvLayer", "ResNetShortCut"]
def _init_weights(self, module): def _init_weights(self, module):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
......
...@@ -1074,6 +1074,7 @@ class SamPreTrainedModel(PreTrainedModel): ...@@ -1074,6 +1074,7 @@ class SamPreTrainedModel(PreTrainedModel):
config_class = SamConfig config_class = SamConfig
base_model_prefix = "sam" base_model_prefix = "sam"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = ["SamVisionAttention"]
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
......
...@@ -428,6 +428,7 @@ class SwiftFormerPreTrainedModel(PreTrainedModel): ...@@ -428,6 +428,7 @@ class SwiftFormerPreTrainedModel(PreTrainedModel):
base_model_prefix = "swiftformer" base_model_prefix = "swiftformer"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["SwiftFormerEncoderBlock"]
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights""" """Initialize the weights"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment