Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
b33bd91f
Unverified
Commit
b33bd91f
authored
Mar 21, 2023
by
1lint
Committed by
GitHub
Mar 21, 2023
Browse files
Add option to set dtype in pipeline.to() method (#2317)
add test_to_dtype to check pipe.to(fp16)
parent
1fcf279d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
7 deletions
+21
-7
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+8
-3
tests/test_pipelines_common.py
tests/test_pipelines_common.py
+13
-4
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
b33bd91f
...
@@ -512,8 +512,13 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -512,8 +512,13 @@ class DiffusionPipeline(ConfigMixin):
save_method
(
os
.
path
.
join
(
save_directory
,
pipeline_component_name
),
**
save_kwargs
)
save_method
(
os
.
path
.
join
(
save_directory
,
pipeline_component_name
),
**
save_kwargs
)
def
to
(
self
,
torch_device
:
Optional
[
Union
[
str
,
torch
.
device
]]
=
None
,
silence_dtype_warnings
:
bool
=
False
):
def
to
(
if
torch_device
is
None
:
self
,
torch_device
:
Optional
[
Union
[
str
,
torch
.
device
]]
=
None
,
torch_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
silence_dtype_warnings
:
bool
=
False
,
):
if
torch_device
is
None
and
torch_dtype
is
None
:
return
self
return
self
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
...
@@ -550,6 +555,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -550,6 +555,7 @@ class DiffusionPipeline(ConfigMixin):
for
name
in
module_names
.
keys
():
for
name
in
module_names
.
keys
():
module
=
getattr
(
self
,
name
)
module
=
getattr
(
self
,
name
)
if
isinstance
(
module
,
torch
.
nn
.
Module
):
if
isinstance
(
module
,
torch
.
nn
.
Module
):
module
.
to
(
torch_device
,
torch_dtype
)
if
(
if
(
module
.
dtype
==
torch
.
float16
module
.
dtype
==
torch
.
float16
and
str
(
torch_device
)
in
[
"cpu"
]
and
str
(
torch_device
)
in
[
"cpu"
]
...
@@ -563,7 +569,6 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -563,7 +569,6 @@ class DiffusionPipeline(ConfigMixin):
" support for`float16` operations on this device in PyTorch. Please, remove the"
" support for`float16` operations on this device in PyTorch. Please, remove the"
" `torch_dtype=torch.float16` argument, or use another device for inference."
" `torch_dtype=torch.float16` argument, or use another device for inference."
)
)
module
.
to
(
torch_device
)
return
self
return
self
@
property
@
property
...
...
tests/test_pipelines_common.py
View file @
b33bd91f
...
@@ -344,11 +344,8 @@ class PipelineTesterMixin:
...
@@ -344,11 +344,8 @@ class PipelineTesterMixin:
pipe
.
to
(
torch_device
)
pipe
.
to
(
torch_device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
for
name
,
module
in
components
.
items
():
if
hasattr
(
module
,
"half"
):
components
[
name
]
=
module
.
half
()
pipe_fp16
=
self
.
pipeline_class
(
**
components
)
pipe_fp16
=
self
.
pipeline_class
(
**
components
)
pipe_fp16
.
to
(
torch_device
)
pipe_fp16
.
to
(
torch_device
,
torch
.
float16
)
pipe_fp16
.
set_progress_bar_config
(
disable
=
None
)
pipe_fp16
.
set_progress_bar_config
(
disable
=
None
)
output
=
pipe
(
**
self
.
get_dummy_inputs
(
torch_device
))[
0
]
output
=
pipe
(
**
self
.
get_dummy_inputs
(
torch_device
))[
0
]
...
@@ -447,6 +444,18 @@ class PipelineTesterMixin:
...
@@ -447,6 +444,18 @@ class PipelineTesterMixin:
output_cuda
=
pipe
(
**
self
.
get_dummy_inputs
(
"cuda"
))[
0
]
output_cuda
=
pipe
(
**
self
.
get_dummy_inputs
(
"cuda"
))[
0
]
self
.
assertTrue
(
np
.
isnan
(
output_cuda
).
sum
()
==
0
)
self
.
assertTrue
(
np
.
isnan
(
output_cuda
).
sum
()
==
0
)
def
test_to_dtype
(
self
):
components
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
model_dtypes
=
[
component
.
dtype
for
component
in
components
.
values
()
if
hasattr
(
component
,
"dtype"
)]
self
.
assertTrue
(
all
(
dtype
==
torch
.
float32
for
dtype
in
model_dtypes
))
pipe
.
to
(
torch_dtype
=
torch
.
float16
)
model_dtypes
=
[
component
.
dtype
for
component
in
components
.
values
()
if
hasattr
(
component
,
"dtype"
)]
self
.
assertTrue
(
all
(
dtype
==
torch
.
float16
for
dtype
in
model_dtypes
))
def
test_attention_slicing_forward_pass
(
self
):
def
test_attention_slicing_forward_pass
(
self
):
self
.
_test_attention_slicing_forward_pass
()
self
.
_test_attention_slicing_forward_pass
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment