Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
22b9cb08
Unverified
Commit
22b9cb08
authored
Dec 02, 2022
by
Patrick von Platen
Committed by
GitHub
Dec 02, 2022
Browse files
[From pretrained] Allow returning local path (#1450)
Allow returning local path
parent
25f850a2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
28 deletions
+62
-28
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+33
-28
tests/test_pipelines.py
tests/test_pipelines.py
+29
-0
No files found.
src/diffusers/pipeline_utils.py
View file @
22b9cb08
...
@@ -377,7 +377,8 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -377,7 +377,8 @@ class DiffusionPipeline(ConfigMixin):
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
setting this argument to `True` will raise an error.
return_cached_folder (`bool`, *optional*, defaults to `False`):
If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
kwargs (remaining dictionary of keyword arguments, *optional*):
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines
specific pipeline class. The overwritten components are then directly passed to the pipelines
...
@@ -430,33 +431,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -430,33 +431,7 @@ class DiffusionPipeline(ConfigMixin):
sess_options
=
kwargs
.
pop
(
"sess_options"
,
None
)
sess_options
=
kwargs
.
pop
(
"sess_options"
,
None
)
device_map
=
kwargs
.
pop
(
"device_map"
,
None
)
device_map
=
kwargs
.
pop
(
"device_map"
,
None
)
low_cpu_mem_usage
=
kwargs
.
pop
(
"low_cpu_mem_usage"
,
_LOW_CPU_MEM_USAGE_DEFAULT
)
low_cpu_mem_usage
=
kwargs
.
pop
(
"low_cpu_mem_usage"
,
_LOW_CPU_MEM_USAGE_DEFAULT
)
return_cached_folder
=
kwargs
.
pop
(
"return_cached_folder"
,
False
)
if
low_cpu_mem_usage
and
not
is_accelerate_available
():
low_cpu_mem_usage
=
False
logger
.
warning
(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with:
\n
```
\n
pip"
" install accelerate
\n
```
\n
."
)
if
device_map
is
not
None
and
not
is_torch_version
(
">="
,
"1.9.0"
):
raise
NotImplementedError
(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if
low_cpu_mem_usage
is
True
and
not
is_torch_version
(
">="
,
"1.9.0"
):
raise
NotImplementedError
(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if
low_cpu_mem_usage
is
False
and
device_map
is
not
None
:
raise
ValueError
(
f
"You cannot set `low_cpu_mem_usage` to False while using device_map=
{
device_map
}
for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
# 1. Download the checkpoints and configs
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
...
@@ -585,6 +560,33 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -585,6 +560,33 @@ class DiffusionPipeline(ConfigMixin):
f
"Keyword arguments
{
unused_kwargs
}
are not expected by
{
pipeline_class
.
__name__
}
and will be ignored."
f
"Keyword arguments
{
unused_kwargs
}
are not expected by
{
pipeline_class
.
__name__
}
and will be ignored."
)
)
if
low_cpu_mem_usage
and
not
is_accelerate_available
():
low_cpu_mem_usage
=
False
logger
.
warning
(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with:
\n
```
\n
pip"
" install accelerate
\n
```
\n
."
)
if
device_map
is
not
None
and
not
is_torch_version
(
">="
,
"1.9.0"
):
raise
NotImplementedError
(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if
low_cpu_mem_usage
is
True
and
not
is_torch_version
(
">="
,
"1.9.0"
):
raise
NotImplementedError
(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if
low_cpu_mem_usage
is
False
and
device_map
is
not
None
:
raise
ValueError
(
f
"You cannot set `low_cpu_mem_usage` to False while using device_map=
{
device_map
}
for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
# import it here to avoid circular import
# import it here to avoid circular import
from
diffusers
import
pipelines
from
diffusers
import
pipelines
...
@@ -704,6 +706,9 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -704,6 +706,9 @@ class DiffusionPipeline(ConfigMixin):
# 5. Instantiate the pipeline
# 5. Instantiate the pipeline
model
=
pipeline_class
(
**
init_kwargs
)
model
=
pipeline_class
(
**
init_kwargs
)
if
return_cached_folder
:
return
model
,
cached_folder
return
model
return
model
@
staticmethod
@
staticmethod
...
...
tests/test_pipelines.py
View file @
22b9cb08
...
@@ -95,6 +95,35 @@ class DownloadTests(unittest.TestCase):
...
@@ -95,6 +95,35 @@ class DownloadTests(unittest.TestCase):
# We need to never convert this tiny model to safetensors for this test to pass
# We need to never convert this tiny model to safetensors for this test to pass
assert
not
any
(
f
.
endswith
(
".safetensors"
)
for
f
in
files
)
assert
not
any
(
f
.
endswith
(
".safetensors"
)
for
f
in
files
)
def
test_returned_cached_folder
(
self
):
prompt
=
"hello"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
safety_checker
=
None
)
_
,
local_path
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
safety_checker
=
None
,
return_cached_folder
=
True
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
local_path
)
pipe
=
pipe
.
to
(
torch_device
)
pipe_2
=
pipe
.
to
(
torch_device
)
if
torch_device
==
"mps"
:
# device type MPS is not supported for torch.Generator() api.
generator
=
torch
.
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out
=
pipe
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
if
torch_device
==
"mps"
:
# device type MPS is not supported for torch.Generator() api.
generator
=
torch
.
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
def
test_download_safetensors
(
self
):
def
test_download_safetensors
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
# pipeline has Flax weights
# pipeline has Flax weights
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment