Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
34c90dbb
Unverified
Commit
34c90dbb
authored
Mar 28, 2024
by
YiYi Xu
Committed by
GitHub
Mar 29, 2024
Browse files
fix OOM for test_vae_tiling (#7510)
use float16 and add torch.no_grad()
parent
e49c04d5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
8 deletions
+12
-8
tests/models/autoencoders/test_models_vae.py
tests/models/autoencoders/test_models_vae.py
+8
-5
tests/pipelines/test_pipelines_common.py
tests/pipelines/test_pipelines_common.py
+4
-3
No files found.
tests/models/autoencoders/test_models_vae.py
View file @
34c90dbb
...
@@ -1118,8 +1118,10 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
...
@@ -1118,8 +1118,10 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
assert
torch_all_close
(
actual_output
,
expected_output
,
atol
=
5e-3
)
assert
torch_all_close
(
actual_output
,
expected_output
,
atol
=
5e-3
)
def
test_vae_tiling
(
self
):
def
test_vae_tiling
(
self
):
vae
=
ConsistencyDecoderVAE
.
from_pretrained
(
"openai/consistency-decoder"
)
vae
=
ConsistencyDecoderVAE
.
from_pretrained
(
"openai/consistency-decoder"
,
torch_dtype
=
torch
.
float16
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"runwayml/stable-diffusion-v1-5"
,
vae
=
vae
,
safety_checker
=
None
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"runwayml/stable-diffusion-v1-5"
,
vae
=
vae
,
safety_checker
=
None
,
torch_dtype
=
torch
.
float16
)
pipe
.
to
(
torch_device
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
@@ -1143,6 +1145,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
...
@@ -1143,6 +1145,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
# test that tiled decode works with various shapes
# test that tiled decode works with various shapes
shapes
=
[(
1
,
4
,
73
,
97
),
(
1
,
4
,
97
,
73
),
(
1
,
4
,
49
,
65
),
(
1
,
4
,
65
,
49
)]
shapes
=
[(
1
,
4
,
73
,
97
),
(
1
,
4
,
97
,
73
),
(
1
,
4
,
49
,
65
),
(
1
,
4
,
65
,
49
)]
for
shape
in
shapes
:
with
torch
.
no_grad
():
image
=
torch
.
zeros
(
shape
,
device
=
torch_device
)
for
shape
in
shapes
:
pipe
.
vae
.
decode
(
image
)
image
=
torch
.
zeros
(
shape
,
device
=
torch_device
)
pipe
.
vae
.
decode
(
image
)
tests/pipelines/test_pipelines_common.py
View file @
34c90dbb
...
@@ -124,9 +124,10 @@ class SDFunctionTesterMixin:
...
@@ -124,9 +124,10 @@ class SDFunctionTesterMixin:
# test that tiled decode works with various shapes
# test that tiled decode works with various shapes
shapes
=
[(
1
,
4
,
73
,
97
),
(
1
,
4
,
97
,
73
),
(
1
,
4
,
49
,
65
),
(
1
,
4
,
65
,
49
)]
shapes
=
[(
1
,
4
,
73
,
97
),
(
1
,
4
,
97
,
73
),
(
1
,
4
,
49
,
65
),
(
1
,
4
,
65
,
49
)]
for
shape
in
shapes
:
with
torch
.
no_grad
():
zeros
=
torch
.
zeros
(
shape
).
to
(
torch_device
)
for
shape
in
shapes
:
pipe
.
vae
.
decode
(
zeros
)
zeros
=
torch
.
zeros
(
shape
).
to
(
torch_device
)
pipe
.
vae
.
decode
(
zeros
)
def
test_freeu_enabled
(
self
):
def
test_freeu_enabled
(
self
):
components
=
self
.
get_dummy_components
()
components
=
self
.
get_dummy_components
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment