Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7f724a93
Unverified
Commit
7f724a93
authored
Mar 31, 2024
by
YiYi Xu
Committed by
GitHub
Apr 01, 2024
Browse files
fix the cpu offload tests (#7544)
fix
parent
9bef9f4b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
22 deletions
+27
-22
tests/pipelines/test_pipelines_common.py
tests/pipelines/test_pipelines_common.py
+27
-22
No files found.
tests/pipelines/test_pipelines_common.py
View file @
7f724a93
...
@@ -1144,20 +1144,24 @@ class PipelineTesterMixin:
...
@@ -1144,20 +1144,24 @@ class PipelineTesterMixin:
self
.
assertLess
(
self
.
assertLess
(
max_diff
,
expected_max_diff
,
"running CPU offloading 2nd time should not affect the inference results"
max_diff
,
expected_max_diff
,
"running CPU offloading 2nd time should not affect the inference results"
)
)
offloaded_modules
=
[
offloaded_modules
=
{
v
k
:
v
for
k
,
v
in
pipe
.
components
.
items
()
for
k
,
v
in
pipe
.
components
.
items
()
if
isinstance
(
v
,
torch
.
nn
.
Module
)
and
k
not
in
pipe
.
_exclude_from_cpu_offload
if
isinstance
(
v
,
torch
.
nn
.
Module
)
and
k
not
in
pipe
.
_exclude_from_cpu_offload
]
}
(
self
.
assertTrue
(
self
.
assertTrue
(
all
(
v
.
device
.
type
==
"cpu"
for
v
in
offloaded_modules
)),
all
(
v
.
device
.
type
==
"cpu"
for
v
in
offloaded_modules
.
values
(
)),
f
"Not offloaded:
{
[
v
for
v
in
offloaded_modules
if
v
.
device
.
type
!=
'cpu'
]
}
"
,
f
"Not offloaded:
{
[
k
for
k
,
v
in
offloaded_modules
.
items
()
if
v
.
device
.
type
!=
'cpu'
]
}
"
,
)
)
offloaded_modules_with_hooks
=
[
v
for
v
in
offloaded_modules
if
hasattr
(
v
,
"_hf_hook"
)]
offloaded_modules_with_incorrect_hooks
=
{}
(
for
k
,
v
in
offloaded_modules
.
items
():
self
.
assertTrue
(
all
(
isinstance
(
v
,
accelerate
.
hooks
.
CpuOffload
)
for
v
in
offloaded_modules_with_hooks
)),
if
hasattr
(
v
,
"_hf_hook"
)
and
not
isinstance
(
v
.
_hf_hook
,
accelerate
.
hooks
.
CpuOffload
):
f
"Not installed correct hook:
{
[
v
for
v
in
offloaded_modules_with_hooks
if
not
isinstance
(
v
,
accelerate
.
hooks
.
CpuOffload
)]
}
"
,
offloaded_modules_with_incorrect_hooks
[
k
]
=
type
(
v
.
_hf_hook
)
self
.
assertTrue
(
len
(
offloaded_modules_with_incorrect_hooks
)
==
0
,
f
"Not installed correct hook:
{
offloaded_modules_with_incorrect_hooks
}
"
,
)
)
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
@@ -1189,22 +1193,23 @@ class PipelineTesterMixin:
...
@@ -1189,22 +1193,23 @@ class PipelineTesterMixin:
self
.
assertLess
(
self
.
assertLess
(
max_diff
,
expected_max_diff
,
"running sequential offloading second time should have the inference results"
max_diff
,
expected_max_diff
,
"running sequential offloading second time should have the inference results"
)
)
offloaded_modules
=
[
offloaded_modules
=
{
v
k
:
v
for
k
,
v
in
pipe
.
components
.
items
()
for
k
,
v
in
pipe
.
components
.
items
()
if
isinstance
(
v
,
torch
.
nn
.
Module
)
and
k
not
in
pipe
.
_exclude_from_cpu_offload
if
isinstance
(
v
,
torch
.
nn
.
Module
)
and
k
not
in
pipe
.
_exclude_from_cpu_offload
]
}
(
self
.
assertTrue
(
self
.
assertTrue
(
all
(
v
.
device
.
type
==
"meta"
for
v
in
offloaded_modules
)),
all
(
v
.
device
.
type
==
"meta"
for
v
in
offloaded_modules
.
values
(
)),
f
"Not offloaded:
{
[
v
for
v
in
offloaded_modules
if
v
.
device
.
type
!=
'meta'
]
}
"
,
f
"Not offloaded:
{
[
k
for
k
,
v
in
offloaded_modules
.
items
()
if
v
.
device
.
type
!=
'meta'
]
}
"
,
)
)
offloaded_modules_with_incorrect_hooks
=
{}
for
k
,
v
in
offloaded_modules
.
items
():
if
hasattr
(
v
,
"_hf_hook"
)
and
not
isinstance
(
v
.
_hf_hook
,
accelerate
.
hooks
.
AlignDevicesHook
):
offloaded_modules_with_incorrect_hooks
[
k
]
=
type
(
v
.
_hf_hook
)
offloaded_modules_with_hooks
=
[
v
for
v
in
offloaded_modules
if
hasattr
(
v
,
"_hf_hook"
)]
self
.
assertTrue
(
(
len
(
offloaded_modules_with_incorrect_hooks
)
==
0
,
self
.
assertTrue
(
f
"Not installed correct hook:
{
offloaded_modules_with_incorrect_hooks
}
"
,
all
(
isinstance
(
v
,
accelerate
.
hooks
.
AlignDevicesHook
)
for
v
in
offloaded_modules_with_hooks
)
),
f
"Not installed correct hook:
{
[
v
for
v
in
offloaded_modules_with_hooks
if
not
isinstance
(
v
,
accelerate
.
hooks
.
AlignDevicesHook
)]
}
"
,
)
)
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment