Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b33bd91f
Unverified
Commit
b33bd91f
authored
Mar 21, 2023
by
1lint
Committed by
GitHub
Mar 21, 2023
Browse files
Add option to set dtype in pipeline.to() method (#2317)
add test_to_dtype to check pipe.to(fp16)
parent
1fcf279d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
7 deletions
+21
-7
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+8
-3
tests/test_pipelines_common.py
tests/test_pipelines_common.py
+13
-4
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
b33bd91f
...
...
@@ -512,8 +512,13 @@ class DiffusionPipeline(ConfigMixin):
save_method
(
os
.
path
.
join
(
save_directory
,
pipeline_component_name
),
**
save_kwargs
)
def
to
(
self
,
torch_device
:
Optional
[
Union
[
str
,
torch
.
device
]]
=
None
,
silence_dtype_warnings
:
bool
=
False
):
if
torch_device
is
None
:
def
to
(
self
,
torch_device
:
Optional
[
Union
[
str
,
torch
.
device
]]
=
None
,
torch_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
silence_dtype_warnings
:
bool
=
False
,
):
if
torch_device
is
None
and
torch_dtype
is
None
:
return
self
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
...
...
@@ -550,6 +555,7 @@ class DiffusionPipeline(ConfigMixin):
for
name
in
module_names
.
keys
():
module
=
getattr
(
self
,
name
)
if
isinstance
(
module
,
torch
.
nn
.
Module
):
module
.
to
(
torch_device
,
torch_dtype
)
if
(
module
.
dtype
==
torch
.
float16
and
str
(
torch_device
)
in
[
"cpu"
]
...
...
@@ -563,7 +569,6 @@ class DiffusionPipeline(ConfigMixin):
" support for`float16` operations on this device in PyTorch. Please, remove the"
" `torch_dtype=torch.float16` argument, or use another device for inference."
)
module
.
to
(
torch_device
)
return
self
@
property
...
...
tests/test_pipelines_common.py
View file @
b33bd91f
...
...
@@ -344,11 +344,8 @@ class PipelineTesterMixin:
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
for
name
,
module
in
components
.
items
():
if
hasattr
(
module
,
"half"
):
components
[
name
]
=
module
.
half
()
pipe_fp16
=
self
.
pipeline_class
(
**
components
)
pipe_fp16
.
to
(
torch_device
)
pipe_fp16
.
to
(
torch_device
,
torch
.
float16
)
pipe_fp16
.
set_progress_bar_config
(
disable
=
None
)
output
=
pipe
(
**
self
.
get_dummy_inputs
(
torch_device
))[
0
]
...
...
@@ -447,6 +444,18 @@ class PipelineTesterMixin:
output_cuda
=
pipe
(
**
self
.
get_dummy_inputs
(
"cuda"
))[
0
]
self
.
assertTrue
(
np
.
isnan
(
output_cuda
).
sum
()
==
0
)
def test_to_dtype(self):
    """Check that ``pipe.to(torch_dtype=...)`` casts every dtype-bearing component.

    Builds the pipeline from dummy components (all expected to start in
    float32), then converts via ``to(torch_dtype=torch.float16)`` and
    verifies each component that exposes a ``dtype`` attribute was cast.
    """
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.set_progress_bar_config(disable=None)

    # Collect the dtype of every component that has one (models do,
    # schedulers/tokenizers typically do not).
    dtypes_before = [c.dtype for c in components.values() if hasattr(c, "dtype")]
    self.assertTrue(all(d == torch.float32 for d in dtypes_before))

    # Cast in place; components dict holds the same module objects,
    # so their dtypes reflect the conversion.
    pipe.to(torch_dtype=torch.float16)

    dtypes_after = [c.dtype for c in components.values() if hasattr(c, "dtype")]
    self.assertTrue(all(d == torch.float16 for d in dtypes_after))
def test_attention_slicing_forward_pass(self):
    """Delegate to the shared attention-slicing forward-pass check."""
    self._test_attention_slicing_forward_pass()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment