Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4d0f8c05
Unverified
Commit
4d0f8c05
authored
Sep 22, 2022
by
Younes Belkada
Committed by
GitHub
Sep 22, 2022
Browse files
Add `accelerate` support for ViLT (#18683)
parent
9393f966
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
7 deletions
+12
-7
src/transformers/models/vilt/modeling_vilt.py
src/transformers/models/vilt/modeling_vilt.py
+2
-1
src/transformers/testing_utils.py
src/transformers/testing_utils.py
+0
-4
tests/models/vilt/test_modeling_vilt.py
tests/models/vilt/test_modeling_vilt.py
+0
-2
tests/test_modeling_common.py
tests/test_modeling_common.py
+10
-0
No files found.
src/transformers/models/vilt/modeling_vilt.py
View file @
4d0f8c05
...
...
@@ -491,7 +491,7 @@ class ViltLayer(nn.Module):
outputs
=
self_attention_outputs
[
1
:]
# add self attentions if we output attention weights
# first residual connection
hidden_states
=
attention_output
+
hidden_states
hidden_states
=
attention_output
+
hidden_states
.
to
(
attention_output
.
device
)
# in ViLT, layernorm is also applied after self-attention
layer_output
=
self
.
layernorm_after
(
hidden_states
)
...
...
@@ -573,6 +573,7 @@ class ViltPreTrainedModel(PreTrainedModel):
config_class
=
ViltConfig
base_model_prefix
=
"vilt"
supports_gradient_checkpointing
=
True
_no_split_modules
=
[
"ViltSelfAttention"
]
def
_init_weights
(
self
,
module
):
"""Initialize the weights"""
...
...
src/transformers/testing_utils.py
View file @
4d0f8c05
...
...
@@ -772,7 +772,6 @@ class CaptureStd:
```"""
def
__init__
(
self
,
out
=
True
,
err
=
True
,
replay
=
True
):
self
.
replay
=
replay
if
out
:
...
...
@@ -1122,7 +1121,6 @@ class TestCasePlus(unittest.TestCase):
tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
"""
if
tmp_dir
is
not
None
:
# defining the most likely desired behavior for when a custom path is provided.
# this most likely indicates the debug mode where we want an easily locatable dir that:
# 1. gets cleared out before the test (if it already exists)
...
...
@@ -1200,7 +1198,6 @@ class TestCasePlus(unittest.TestCase):
return
max_rss
def
tearDown
(
self
):
# get_auto_remove_tmp_dir feature: remove registered temp dirs
for
path
in
self
.
teardown_tmp_dirs
:
shutil
.
rmtree
(
path
,
ignore_errors
=
True
)
...
...
@@ -1472,7 +1469,6 @@ async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=Fals
def
execute_subprocess_async
(
cmd
,
env
=
None
,
stdin
=
None
,
timeout
=
180
,
quiet
=
False
,
echo
=
True
)
->
_RunOutput
:
loop
=
asyncio
.
get_event_loop
()
result
=
loop
.
run_until_complete
(
_stream_subprocess
(
cmd
,
env
=
env
,
stdin
=
stdin
,
timeout
=
timeout
,
quiet
=
quiet
,
echo
=
echo
)
...
...
tests/models/vilt/test_modeling_vilt.py
View file @
4d0f8c05
...
...
@@ -215,7 +215,6 @@ class ViltModelTester:
@
require_torch
class
ViltModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
(
ViltModel
,
...
...
@@ -512,7 +511,6 @@ class ViltModelTest(ModelTesterMixin, unittest.TestCase):
@
require_torch
class
ViltForImagesAndTextClassificationModelTest
(
ViltModelTest
,
unittest
.
TestCase
):
all_model_classes
=
(
ViltForImagesAndTextClassification
,)
if
is_torch_available
()
else
()
def
setUp
(
self
):
...
...
tests/test_modeling_common.py
View file @
4d0f8c05
...
...
@@ -2307,6 +2307,7 @@ class ModelTesterMixin:
inputs_dict
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
model
=
model_class
(
config
).
eval
()
model
=
model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
base_output
=
model
(
**
inputs_dict
)
model_size
=
compute_module_sizes
(
model
)[
""
]
...
...
@@ -2324,6 +2325,7 @@ class ModelTesterMixin:
)
self
.
check_device_map_is_respected
(
new_model
,
new_model
.
hf_device_map
)
torch
.
manual_seed
(
0
)
new_output
=
new_model
(
**
inputs_dict
)
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
]))
...
...
@@ -2340,6 +2342,8 @@ class ModelTesterMixin:
inputs_dict
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
model
=
model_class
(
config
).
eval
()
model
=
model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
base_output
=
model
(
**
inputs_dict
)
model_size
=
compute_module_sizes
(
model
)[
""
]
...
...
@@ -2355,6 +2359,8 @@ class ModelTesterMixin:
self
.
assertSetEqual
(
set
(
new_model
.
hf_device_map
.
values
()),
{
0
,
"cpu"
})
self
.
check_device_map_is_respected
(
new_model
,
new_model
.
hf_device_map
)
torch
.
manual_seed
(
0
)
new_output
=
new_model
(
**
inputs_dict
)
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
]))
...
...
@@ -2371,6 +2377,8 @@ class ModelTesterMixin:
inputs_dict
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
model
=
model_class
(
config
).
eval
()
model
=
model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
base_output
=
model
(
**
inputs_dict
)
model_size
=
compute_module_sizes
(
model
)[
""
]
...
...
@@ -2386,6 +2394,8 @@ class ModelTesterMixin:
self
.
assertSetEqual
(
set
(
new_model
.
hf_device_map
.
values
()),
{
0
,
1
})
self
.
check_device_map_is_respected
(
new_model
,
new_model
.
hf_device_map
)
torch
.
manual_seed
(
0
)
new_output
=
new_model
(
**
inputs_dict
)
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
]))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment