Commit 4d0f8c05 (unverified), authored by Younes Belkada, committed by GitHub
Browse files

Add `accelerate` support for ViLT (#18683)

parent 9393f966
...@@ -491,7 +491,7 @@ class ViltLayer(nn.Module): ...@@ -491,7 +491,7 @@ class ViltLayer(nn.Module):
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
# first residual connection # first residual connection
hidden_states = attention_output + hidden_states hidden_states = attention_output + hidden_states.to(attention_output.device)
# in ViLT, layernorm is also applied after self-attention # in ViLT, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states) layer_output = self.layernorm_after(hidden_states)
...@@ -573,6 +573,7 @@ class ViltPreTrainedModel(PreTrainedModel): ...@@ -573,6 +573,7 @@ class ViltPreTrainedModel(PreTrainedModel):
config_class = ViltConfig config_class = ViltConfig
base_model_prefix = "vilt" base_model_prefix = "vilt"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["ViltSelfAttention"]
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
......
...@@ -772,7 +772,6 @@ class CaptureStd: ...@@ -772,7 +772,6 @@ class CaptureStd:
```""" ```"""
def __init__(self, out=True, err=True, replay=True): def __init__(self, out=True, err=True, replay=True):
self.replay = replay self.replay = replay
if out: if out:
...@@ -1122,7 +1121,6 @@ class TestCasePlus(unittest.TestCase): ...@@ -1122,7 +1121,6 @@ class TestCasePlus(unittest.TestCase):
tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
""" """
if tmp_dir is not None: if tmp_dir is not None:
# defining the most likely desired behavior for when a custom path is provided. # defining the most likely desired behavior for when a custom path is provided.
# this most likely indicates the debug mode where we want an easily locatable dir that: # this most likely indicates the debug mode where we want an easily locatable dir that:
# 1. gets cleared out before the test (if it already exists) # 1. gets cleared out before the test (if it already exists)
...@@ -1200,7 +1198,6 @@ class TestCasePlus(unittest.TestCase): ...@@ -1200,7 +1198,6 @@ class TestCasePlus(unittest.TestCase):
return max_rss return max_rss
def tearDown(self): def tearDown(self):
# get_auto_remove_tmp_dir feature: remove registered temp dirs # get_auto_remove_tmp_dir feature: remove registered temp dirs
for path in self.teardown_tmp_dirs: for path in self.teardown_tmp_dirs:
shutil.rmtree(path, ignore_errors=True) shutil.rmtree(path, ignore_errors=True)
...@@ -1472,7 +1469,6 @@ async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=Fals ...@@ -1472,7 +1469,6 @@ async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=Fals
def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput: def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
result = loop.run_until_complete( result = loop.run_until_complete(
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo) _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
......
...@@ -215,7 +215,6 @@ class ViltModelTester: ...@@ -215,7 +215,6 @@ class ViltModelTester:
@require_torch @require_torch
class ViltModelTest(ModelTesterMixin, unittest.TestCase): class ViltModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
ViltModel, ViltModel,
...@@ -512,7 +511,6 @@ class ViltModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -512,7 +511,6 @@ class ViltModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch @require_torch
class ViltForImagesAndTextClassificationModelTest(ViltModelTest, unittest.TestCase): class ViltForImagesAndTextClassificationModelTest(ViltModelTest, unittest.TestCase):
all_model_classes = (ViltForImagesAndTextClassification,) if is_torch_available() else () all_model_classes = (ViltForImagesAndTextClassification,) if is_torch_available() else ()
def setUp(self): def setUp(self):
......
...@@ -2307,6 +2307,7 @@ class ModelTesterMixin: ...@@ -2307,6 +2307,7 @@ class ModelTesterMixin:
inputs_dict = self._prepare_for_class(inputs_dict, model_class) inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval() model = model_class(config).eval()
model = model.to(torch_device) model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict) base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""] model_size = compute_module_sizes(model)[""]
...@@ -2324,6 +2325,7 @@ class ModelTesterMixin: ...@@ -2324,6 +2325,7 @@ class ModelTesterMixin:
) )
self.check_device_map_is_respected(new_model, new_model.hf_device_map) self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0])) self.assertTrue(torch.allclose(base_output[0], new_output[0]))
...@@ -2340,6 +2342,8 @@ class ModelTesterMixin: ...@@ -2340,6 +2342,8 @@ class ModelTesterMixin:
inputs_dict = self._prepare_for_class(inputs_dict, model_class) inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval() model = model_class(config).eval()
model = model.to(torch_device) model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict) base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""] model_size = compute_module_sizes(model)[""]
...@@ -2355,6 +2359,8 @@ class ModelTesterMixin: ...@@ -2355,6 +2359,8 @@ class ModelTesterMixin:
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
self.check_device_map_is_respected(new_model, new_model.hf_device_map) self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0])) self.assertTrue(torch.allclose(base_output[0], new_output[0]))
...@@ -2371,6 +2377,8 @@ class ModelTesterMixin: ...@@ -2371,6 +2377,8 @@ class ModelTesterMixin:
inputs_dict = self._prepare_for_class(inputs_dict, model_class) inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval() model = model_class(config).eval()
model = model.to(torch_device) model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict) base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""] model_size = compute_module_sizes(model)[""]
...@@ -2386,6 +2394,8 @@ class ModelTesterMixin: ...@@ -2386,6 +2394,8 @@ class ModelTesterMixin:
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
self.check_device_map_is_respected(new_model, new_model.hf_device_map) self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0])) self.assertTrue(torch.allclose(base_output[0], new_output[0]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment