Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
22b9cb08
Unverified
Commit
22b9cb08
authored
Dec 02, 2022
by
Patrick von Platen
Committed by
GitHub
Dec 02, 2022
Browse files
[From pretrained] Allow returning local path (#1450)
Allow returning local path
parent
25f850a2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
28 deletions
+62
-28
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+33
-28
tests/test_pipelines.py
tests/test_pipelines.py
+29
-0
No files found.
src/diffusers/pipeline_utils.py
View file @
22b9cb08
...
...
@@ -377,7 +377,8 @@ class DiffusionPipeline(ConfigMixin):
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
return_cached_folder (`bool`, *optional*, defaults to `False`):
If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines
...
...
@@ -430,33 +431,7 @@ class DiffusionPipeline(ConfigMixin):
sess_options
=
kwargs
.
pop
(
"sess_options"
,
None
)
device_map
=
kwargs
.
pop
(
"device_map"
,
None
)
low_cpu_mem_usage
=
kwargs
.
pop
(
"low_cpu_mem_usage"
,
_LOW_CPU_MEM_USAGE_DEFAULT
)
if
low_cpu_mem_usage
and
not
is_accelerate_available
():
low_cpu_mem_usage
=
False
logger
.
warning
(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with:
\n
```
\n
pip"
" install accelerate
\n
```
\n
."
)
if
device_map
is
not
None
and
not
is_torch_version
(
">="
,
"1.9.0"
):
raise
NotImplementedError
(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if
low_cpu_mem_usage
is
True
and
not
is_torch_version
(
">="
,
"1.9.0"
):
raise
NotImplementedError
(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if
low_cpu_mem_usage
is
False
and
device_map
is
not
None
:
raise
ValueError
(
f
"You cannot set `low_cpu_mem_usage` to False while using device_map=
{
device_map
}
for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
return_cached_folder
=
kwargs
.
pop
(
"return_cached_folder"
,
False
)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
...
...
@@ -585,6 +560,33 @@ class DiffusionPipeline(ConfigMixin):
f
"Keyword arguments
{
unused_kwargs
}
are not expected by
{
pipeline_class
.
__name__
}
and will be ignored."
)
if
low_cpu_mem_usage
and
not
is_accelerate_available
():
low_cpu_mem_usage
=
False
logger
.
warning
(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with:
\n
```
\n
pip"
" install accelerate
\n
```
\n
."
)
if
device_map
is
not
None
and
not
is_torch_version
(
">="
,
"1.9.0"
):
raise
NotImplementedError
(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if
low_cpu_mem_usage
is
True
and
not
is_torch_version
(
">="
,
"1.9.0"
):
raise
NotImplementedError
(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if
low_cpu_mem_usage
is
False
and
device_map
is
not
None
:
raise
ValueError
(
f
"You cannot set `low_cpu_mem_usage` to False while using device_map=
{
device_map
}
for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
# import it here to avoid circular import
from
diffusers
import
pipelines
...
...
@@ -704,6 +706,9 @@ class DiffusionPipeline(ConfigMixin):
# 5. Instantiate the pipeline
model
=
pipeline_class
(
**
init_kwargs
)
if
return_cached_folder
:
return
model
,
cached_folder
return
model
@
staticmethod
...
...
tests/test_pipelines.py
View file @
22b9cb08
...
...
@@ -95,6 +95,35 @@ class DownloadTests(unittest.TestCase):
# We need to never convert this tiny model to safetensors for this test to pass
assert
not
any
(
f
.
endswith
(
".safetensors"
)
for
f
in
files
)
def test_returned_cached_folder(self):
    """Verify that `return_cached_folder=True` yields a local path from which an
    equivalent pipeline can be re-loaded.

    Loads the same tiny checkpoint twice (once normally, once requesting the
    cached folder), builds a second pipeline directly from the returned local
    path, and asserts both pipelines produce numerically matching images for
    the same seed.
    """
    prompt = "hello"

    pipe = StableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
    )
    # `return_cached_folder=True` makes `from_pretrained` return
    # (pipeline, local_cache_path) instead of just the pipeline.
    _, local_path = StableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-torch",
        safety_checker=None,
        return_cached_folder=True,
    )
    # Re-load purely from the returned local folder — this is the behavior
    # under test.
    pipe_2 = StableDiffusionPipeline.from_pretrained(local_path)

    pipe = pipe.to(torch_device)
    # BUG FIX: was `pipe_2 = pipe.to(torch_device)`, which rebound `pipe_2`
    # to the first pipeline and made the comparison below compare `pipe`
    # against itself, so the locally-loaded pipeline was never exercised.
    pipe_2 = pipe_2.to(torch_device)

    if torch_device == "mps":
        # device type MPS is not supported for torch.Generator() api.
        generator = torch.manual_seed(0)
    else:
        generator = torch.Generator(device=torch_device).manual_seed(0)

    out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

    if torch_device == "mps":
        # device type MPS is not supported for torch.Generator() api.
        generator = torch.manual_seed(0)
    else:
        generator = torch.Generator(device=torch_device).manual_seed(0)

    out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

    # Identical weights + identical seed => images must match to tolerance.
    assert np.max(np.abs(out - out_2)) < 1e-3
def
test_download_safetensors
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
# pipeline has Flax weights
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment