Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4bb07647
Unverified
Commit
4bb07647
authored
Nov 17, 2022
by
Younes Belkada
Committed by
GitHub
Nov 17, 2022
Browse files
refactor test (#20300)
- simplifies the devce checking test
parent
700e0cd6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
17 deletions
+2
-17
tests/mixed_int8/test_mixed_int8.py
tests/mixed_int8/test_mixed_int8.py
+2
-17
No files found.
tests/mixed_int8/test_mixed_int8.py
View file @
4bb07647
...
...
@@ -215,23 +215,8 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
self
.
model_name
,
load_in_8bit
=
True
,
max_memory
=
memory_mapping
,
device_map
=
"auto"
)
def
get_list_devices
(
model
):
list_devices
=
[]
for
_
,
module
in
model
.
named_children
():
if
len
(
list
(
module
.
children
()))
>
0
:
list_devices
.
extend
(
get_list_devices
(
module
))
else
:
# Do a try except since we can encounter Dropout modules that does not
# have any device set
try
:
list_devices
.
append
(
next
(
module
.
parameters
()).
device
.
index
)
except
BaseException
:
continue
return
list_devices
list_devices
=
get_list_devices
(
model_parallel
)
# Check that we have dispatched the model into 2 separate devices
self
.
assertTrue
((
1
in
list_devices
)
and
(
0
in
list_devices
))
# Check correct device map
self
.
assertEqual
(
set
(
model_parallel
.
hf_device_map
.
values
()),
{
0
,
1
})
# Check that inference pass works on the model
encoded_input
=
self
.
tokenizer
(
self
.
input_text
,
return_tensors
=
"pt"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment