Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c39a511b
Unverified
Commit
c39a511b
authored
Nov 02, 2022
by
Patrick von Platen
Committed by
GitHub
Nov 02, 2022
Browse files
[Loading] Ignore unneeded files (#1107)
* [Loading] Ignore unneeded files * up
parent
cbcd0512
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
2 deletions
+51
-2
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+11
-1
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+4
-0
tests/test_pipelines.py
tests/test_pipelines.py
+16
-0
tests/test_pipelines_flax.py
tests/test_pipelines_flax.py
+20
-1
No files found.
src/diffusers/pipeline_flax_utils.py
View file @
c39a511b
...
@@ -302,10 +302,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -302,10 +302,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns
=
[
os
.
path
.
join
(
k
,
"*"
)
for
k
in
folder_names
]
allow_patterns
=
[
os
.
path
.
join
(
k
,
"*"
)
for
k
in
folder_names
]
allow_patterns
+=
[
FLAX_WEIGHTS_NAME
,
SCHEDULER_CONFIG_NAME
,
CONFIG_NAME
,
cls
.
config_name
]
allow_patterns
+=
[
FLAX_WEIGHTS_NAME
,
SCHEDULER_CONFIG_NAME
,
CONFIG_NAME
,
cls
.
config_name
]
# make sure we don't download PyTorch weights
ignore_patterns
=
"*.bin"
if
cls
!=
FlaxDiffusionPipeline
:
if
cls
!=
FlaxDiffusionPipeline
:
requested_pipeline_class
=
cls
.
__name__
requested_pipeline_class
=
cls
.
__name__
else
:
else
:
requested_pipeline_class
=
config_dict
.
get
(
"_class_name"
,
cls
.
__name__
)
requested_pipeline_class
=
config_dict
.
get
(
"_class_name"
,
cls
.
__name__
)
requested_pipeline_class
=
(
requested_pipeline_class
if
requested_pipeline_class
.
startswith
(
"Flax"
)
else
"Flax"
+
requested_pipeline_class
)
user_agent
=
{
"pipeline_class"
:
requested_pipeline_class
}
user_agent
=
{
"pipeline_class"
:
requested_pipeline_class
}
user_agent
=
http_user_agent
(
user_agent
)
user_agent
=
http_user_agent
(
user_agent
)
...
@@ -319,6 +328,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -319,6 +328,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
use_auth_token
=
use_auth_token
,
use_auth_token
=
use_auth_token
,
revision
=
revision
,
revision
=
revision
,
allow_patterns
=
allow_patterns
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
user_agent
=
user_agent
,
user_agent
=
user_agent
,
)
)
else
:
else
:
...
@@ -337,7 +347,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -337,7 +347,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
if
config_dict
[
"_class_name"
].
startswith
(
"Flax"
)
if
config_dict
[
"_class_name"
].
startswith
(
"Flax"
)
else
"Flax"
+
config_dict
[
"_class_name"
]
else
"Flax"
+
config_dict
[
"_class_name"
]
)
)
pipeline_class
=
getattr
(
diffusers_module
,
config_dict
[
"_
class_name
"
]
)
pipeline_class
=
getattr
(
diffusers_module
,
class_name
)
# some modules can be passed directly to the init
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# in this case they are already instantiated in `kwargs`
...
...
src/diffusers/pipeline_utils.py
View file @
c39a511b
...
@@ -395,6 +395,9 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -395,6 +395,9 @@ class DiffusionPipeline(ConfigMixin):
allow_patterns
=
[
os
.
path
.
join
(
k
,
"*"
)
for
k
in
folder_names
]
allow_patterns
=
[
os
.
path
.
join
(
k
,
"*"
)
for
k
in
folder_names
]
allow_patterns
+=
[
WEIGHTS_NAME
,
SCHEDULER_CONFIG_NAME
,
CONFIG_NAME
,
ONNX_WEIGHTS_NAME
,
cls
.
config_name
]
allow_patterns
+=
[
WEIGHTS_NAME
,
SCHEDULER_CONFIG_NAME
,
CONFIG_NAME
,
ONNX_WEIGHTS_NAME
,
cls
.
config_name
]
# make sure we don't download flax weights
ignore_patterns
=
"*.msgpack"
if
custom_pipeline
is
not
None
:
if
custom_pipeline
is
not
None
:
allow_patterns
+=
[
CUSTOM_PIPELINE_FILE_NAME
]
allow_patterns
+=
[
CUSTOM_PIPELINE_FILE_NAME
]
...
@@ -417,6 +420,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -417,6 +420,7 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token
=
use_auth_token
,
use_auth_token
=
use_auth_token
,
revision
=
revision
,
revision
=
revision
,
allow_patterns
=
allow_patterns
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
user_agent
=
user_agent
,
user_agent
=
user_agent
,
)
)
else
:
else
:
...
...
tests/test_pipelines.py
View file @
c39a511b
...
@@ -73,6 +73,22 @@ def test_progress_bar(capsys):
...
@@ -73,6 +73,22 @@ def test_progress_bar(capsys):
assert
captured
.
err
==
""
,
"Progress bar should be disabled"
assert
captured
.
err
==
""
,
"Progress bar should be disabled"
class
DownloadTests
(
unittest
.
TestCase
):
def
test_download_only_pytorch
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
# pipeline has Flax weights
_
=
DiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
,
safety_checker
=
None
,
cache_dir
=
tmpdirname
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
os
.
path
.
join
(
tmpdirname
,
os
.
listdir
(
tmpdirname
)[
0
],
"snapshots"
))]
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
# None of the downloaded files should be a flax file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
assert
not
any
(
f
.
endswith
(
".msgpack"
)
for
f
in
files
)
class
CustomPipelineTests
(
unittest
.
TestCase
):
class
CustomPipelineTests
(
unittest
.
TestCase
):
def
test_load_custom_pipeline
(
self
):
def
test_load_custom_pipeline
(
self
):
pipeline
=
DiffusionPipeline
.
from_pretrained
(
pipeline
=
DiffusionPipeline
.
from_pretrained
(
...
...
tests/test_pipelines_flax.py
View file @
c39a511b
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
tempfile
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
...
@@ -24,12 +26,29 @@ from diffusers.utils.testing_utils import require_flax, slow
...
@@ -24,12 +26,29 @@ from diffusers.utils.testing_utils import require_flax, slow
if
is_flax_available
():
if
is_flax_available
():
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
diffusers
import
FlaxDDIMScheduler
,
FlaxStableDiffusionPipeline
from
diffusers
import
FlaxDDIMScheduler
,
FlaxDiffusionPipeline
,
FlaxStableDiffusionPipeline
from
flax.jax_utils
import
replicate
from
flax.jax_utils
import
replicate
from
flax.training.common_utils
import
shard
from
flax.training.common_utils
import
shard
from
jax
import
pmap
from
jax
import
pmap
@
require_flax
class
DownloadTests
(
unittest
.
TestCase
):
def
test_download_only_pytorch
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
# pipeline has Flax weights
_
=
FlaxDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
,
safety_checker
=
None
,
cache_dir
=
tmpdirname
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
os
.
path
.
join
(
tmpdirname
,
os
.
listdir
(
tmpdirname
)[
0
],
"snapshots"
))]
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
# None of the downloaded files should be a PyTorch file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
assert
not
any
(
f
.
endswith
(
".bin"
)
for
f
in
files
)
@
slow
@
slow
@
require_flax
@
require_flax
class
FlaxPipelineTests
(
unittest
.
TestCase
):
class
FlaxPipelineTests
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment