chenpangpang / transformers · Commits

Commit 04c7c176 (unverified)
Authored May 24, 2024 by Fanli Lin; committed by GitHub on May 24, 2024
Parent: 42d8dd87

[tests] make `test_model_parallelism` device-agnostic (#30844)

* enable on xpu
* fix style
* add comment and mps

Showing 1 changed file with 6 additions and 3 deletions.

tests/test_modeling_common.py (+6, -3)
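Note: the assertion added in the second hunk below relies on the backend-agnostic `torch_device` string from `transformers.testing_utils` (already part of the file's testing_utils import, outside the hunks shown). A minimal sketch of what it provides; the printed values depend on the host machine:

from transformers.testing_utils import torch_device
import torch

# torch_device is a plain string naming the active backend, e.g. "cuda",
# "xpu", "npu" or "cpu" -- the exact value depends on what is installed.
print(torch_device)

# A fully-qualified torch.device for index 0 on that backend:
print(torch.device(f"{torch_device}:0"))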
@@ -76,6 +76,7 @@ from transformers.testing_utils import (
     require_safetensors,
     require_torch,
     require_torch_gpu,
+    require_torch_multi_accelerator,
     require_torch_multi_gpu,
     require_torch_sdpa,
     slow,
@@ -3009,8 +3010,11 @@ class ModelTesterMixin:
             param_device = device_map[param_name]
             if param_device in ["cpu", "disk"]:
                 self.assertEqual(param.device, torch.device("meta"))
+            elif param_device in ["mps"]:
+                self.assertEqual(param.device, torch.device("mps"))
             else:
-                self.assertEqual(param.device, torch.device(param_device))
+                # when loaded with device_map, `param_device` are integer values for cuda/xpu/npu/mlu
+                self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}"))
 
     @require_accelerate
     @mark.accelerate_tests
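The added comment and f-string matter because accelerate records non-CPU, non-disk placements in the device map as bare integers, and constructing a `torch.device` from an integer always yields a CUDA device. A small demonstration of the mismatch the new branch avoids (the "xpu" value is illustrative; any non-CUDA backend behaves the same way, and the backend name must be one the installed PyTorch build recognizes):

import torch

# An integer always maps to CUDA, which is what the old assertion implicitly assumed:
print(torch.device(0))                    # device(type='cuda', index=0)

# Prefixing with the active backend string keeps the comparison device-agnostic:
torch_device = "xpu"                      # illustrative; on a CUDA host this would be "cuda"
print(torch.device(f"{torch_device}:0"))  # device(type='xpu', index=0)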
@@ -3129,7 +3133,7 @@ class ModelTesterMixin:
     @require_accelerate
     @mark.accelerate_tests
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_model_parallelism(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
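Swapping `require_torch_multi_gpu` for `require_torch_multi_accelerator` lets the test run on any backend that exposes at least two devices, not only multi-GPU CUDA hosts. The real decorator lives in `transformers.testing_utils`; the sketch below only illustrates the idea with plain `unittest` and public `torch` APIs and is not the library's implementation:

import unittest
import torch

def _accelerator_count() -> int:
    # Count devices for whichever accelerator backend is present.
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    # torch.xpu exists only in newer PyTorch builds, hence the guard.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.device_count()
    return 0

# Skip the decorated test unless two or more accelerators are visible.
require_two_accelerators = unittest.skipUnless(
    _accelerator_count() >= 2, "test requires at least 2 accelerators"
)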
@@ -3155,7 +3159,6 @@ class ModelTesterMixin:
                 new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                 # Making sure part of the model will actually end up offloaded
                 self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
                 torch.manual_seed(0)
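For reference, the pattern this test exercises: loading with `device_map="auto"` plus per-device `max_memory` caps makes accelerate split the checkpoint across devices 0 and 1, and the resulting placement is recorded in `hf_device_map`. A minimal usage sketch, assuming a machine with two accelerators and accelerate installed (the checkpoint name and memory values are illustrative, not taken from the diff):

import torch
from transformers import AutoModel

# Per-device budgets; the test picks values small enough to force a split across 0 and 1.
max_memory = {0: "300MB", 1: "300MB"}
model = AutoModel.from_pretrained(
    "hf-internal-testing/tiny-random-bert",   # illustrative tiny checkpoint
    device_map="auto",
    max_memory=max_memory,
)

# Module-name -> device mapping, with values drawn from {0, 1} when the split succeeds.
print(model.hf_device_map)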