Unverified Commit 30b45320 authored by Jacky Lee's avatar Jacky Lee Committed by GitHub
Browse files

Enable multi-device for some models (#30207)



* feat: multidevice for resnet

* feat: yes! resnet

* fix: compare all elements in tuple

* feat: support for regnet

* feat: support for convnextv2

* feat: support for bit

* feat: support for cvt

* feat: add support for focalnet

* feat: support for yolos

* feat: support for glpn

* feat: support for imagegpt

* feat: support for levit

* feat: support for mgp_str

* feat: support for mobilnet_v1

* feat: support for mobilnet_v2

* feat: support for mobilevit

* feat: support for mobilevitv2

* feat: support for poolformer

* fix: copies

* fix: code quality check

* update: upstream changes from main

* fix: consistency check

* feat: support for sam

* feat: support for switchformer

* feat: support for swin

* feat: support for swinv2

* feat: support for timesformer

* feat: suport for trocr

* feat: support for upernet

* fix: check copies

* update: rerun CI

* update: rerun again, maybe

* update: one more rerun

---------
Co-authored-by: default avatarJacky Lee <jackylee328@gmail.com>
parent ecfe9be7
......@@ -884,6 +884,7 @@ class SwinPreTrainedModel(PreTrainedModel):
base_model_prefix = "swin"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["SwinStage"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -939,6 +939,7 @@ class Swinv2PreTrainedModel(PreTrainedModel):
base_model_prefix = "swinv2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["Swinv2Stage"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -472,6 +472,7 @@ class TimesformerPreTrainedModel(PreTrainedModel):
base_model_prefix = "timesformer"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["TimesformerLayer"]
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Conv2d)):
......
......@@ -407,6 +407,7 @@ class TrOCRPreTrainedModel(PreTrainedModel):
config_class = TrOCRConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["TrOCRDecoderLayer"]
def _init_weights(self, module):
std = self.config.init_std
......
......@@ -293,6 +293,7 @@ class UperNetPreTrainedModel(PreTrainedModel):
config_class = UperNetConfig
main_input_name = "pixel_values"
_no_split_modules = []
def _init_weights(self, module):
if isinstance(module, UperNetPreTrainedModel):
......
......@@ -533,6 +533,7 @@ class YolosPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
......
......@@ -2907,7 +2907,10 @@ class ModelTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_accelerate
@mark.accelerate_tests
......@@ -2939,7 +2942,10 @@ class ModelTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_accelerate
@mark.accelerate_tests
......@@ -2975,7 +2981,10 @@ class ModelTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_accelerate
@mark.accelerate_tests
......@@ -3011,7 +3020,10 @@ class ModelTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
def test_problem_types(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment