Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
c39a511b
Unverified
Commit
c39a511b
authored
Nov 02, 2022
by
Patrick von Platen
Committed by
GitHub
Nov 02, 2022
Browse files
[Loading] Ignore unneeded files (#1107)
* [Loading] Ignore unneeded files * up
parent
cbcd0512
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
2 deletions
+51
-2
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+11
-1
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+4
-0
tests/test_pipelines.py
tests/test_pipelines.py
+16
-0
tests/test_pipelines_flax.py
tests/test_pipelines_flax.py
+20
-1
No files found.
src/diffusers/pipeline_flax_utils.py
View file @
c39a511b
...
@@ -302,10 +302,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -302,10 +302,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns
=
[
os
.
path
.
join
(
k
,
"*"
)
for
k
in
folder_names
]
allow_patterns
=
[
os
.
path
.
join
(
k
,
"*"
)
for
k
in
folder_names
]
allow_patterns
+=
[
FLAX_WEIGHTS_NAME
,
SCHEDULER_CONFIG_NAME
,
CONFIG_NAME
,
cls
.
config_name
]
allow_patterns
+=
[
FLAX_WEIGHTS_NAME
,
SCHEDULER_CONFIG_NAME
,
CONFIG_NAME
,
cls
.
config_name
]
# make sure we don't download PyTorch weights
ignore_patterns
=
"*.bin"
if
cls
!=
FlaxDiffusionPipeline
:
if
cls
!=
FlaxDiffusionPipeline
:
requested_pipeline_class
=
cls
.
__name__
requested_pipeline_class
=
cls
.
__name__
else
:
else
:
requested_pipeline_class
=
config_dict
.
get
(
"_class_name"
,
cls
.
__name__
)
requested_pipeline_class
=
config_dict
.
get
(
"_class_name"
,
cls
.
__name__
)
requested_pipeline_class
=
(
requested_pipeline_class
if
requested_pipeline_class
.
startswith
(
"Flax"
)
else
"Flax"
+
requested_pipeline_class
)
user_agent
=
{
"pipeline_class"
:
requested_pipeline_class
}
user_agent
=
{
"pipeline_class"
:
requested_pipeline_class
}
user_agent
=
http_user_agent
(
user_agent
)
user_agent
=
http_user_agent
(
user_agent
)
...
@@ -319,6 +328,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -319,6 +328,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
use_auth_token
=
use_auth_token
,
use_auth_token
=
use_auth_token
,
revision
=
revision
,
revision
=
revision
,
allow_patterns
=
allow_patterns
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
user_agent
=
user_agent
,
user_agent
=
user_agent
,
)
)
else
:
else
:
...
@@ -337,7 +347,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -337,7 +347,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
if
config_dict
[
"_class_name"
].
startswith
(
"Flax"
)
if
config_dict
[
"_class_name"
].
startswith
(
"Flax"
)
else
"Flax"
+
config_dict
[
"_class_name"
]
else
"Flax"
+
config_dict
[
"_class_name"
]
)
)
pipeline_class
=
getattr
(
diffusers_module
,
config_dict
[
"_
class_name
"
]
)
pipeline_class
=
getattr
(
diffusers_module
,
class_name
)
# some modules can be passed directly to the init
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# in this case they are already instantiated in `kwargs`
...
...
src/diffusers/pipeline_utils.py
View file @
c39a511b
...
@@ -395,6 +395,9 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -395,6 +395,9 @@ class DiffusionPipeline(ConfigMixin):
allow_patterns
=
[
os
.
path
.
join
(
k
,
"*"
)
for
k
in
folder_names
]
allow_patterns
=
[
os
.
path
.
join
(
k
,
"*"
)
for
k
in
folder_names
]
allow_patterns
+=
[
WEIGHTS_NAME
,
SCHEDULER_CONFIG_NAME
,
CONFIG_NAME
,
ONNX_WEIGHTS_NAME
,
cls
.
config_name
]
allow_patterns
+=
[
WEIGHTS_NAME
,
SCHEDULER_CONFIG_NAME
,
CONFIG_NAME
,
ONNX_WEIGHTS_NAME
,
cls
.
config_name
]
# make sure we don't download flax weights
ignore_patterns
=
"*.msgpack"
if
custom_pipeline
is
not
None
:
if
custom_pipeline
is
not
None
:
allow_patterns
+=
[
CUSTOM_PIPELINE_FILE_NAME
]
allow_patterns
+=
[
CUSTOM_PIPELINE_FILE_NAME
]
...
@@ -417,6 +420,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -417,6 +420,7 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token
=
use_auth_token
,
use_auth_token
=
use_auth_token
,
revision
=
revision
,
revision
=
revision
,
allow_patterns
=
allow_patterns
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
user_agent
=
user_agent
,
user_agent
=
user_agent
,
)
)
else
:
else
:
...
...
tests/test_pipelines.py
View file @
c39a511b
...
@@ -73,6 +73,22 @@ def test_progress_bar(capsys):
...
@@ -73,6 +73,22 @@ def test_progress_bar(capsys):
assert
captured
.
err
==
""
,
"Progress bar should be disabled"
assert
captured
.
err
==
""
,
"Progress bar should be disabled"
class
DownloadTests
(
unittest
.
TestCase
):
def
test_download_only_pytorch
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
# pipeline has Flax weights
_
=
DiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
,
safety_checker
=
None
,
cache_dir
=
tmpdirname
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
os
.
path
.
join
(
tmpdirname
,
os
.
listdir
(
tmpdirname
)[
0
],
"snapshots"
))]
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
# None of the downloaded files should be a flax file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
assert
not
any
(
f
.
endswith
(
".msgpack"
)
for
f
in
files
)
class
CustomPipelineTests
(
unittest
.
TestCase
):
class
CustomPipelineTests
(
unittest
.
TestCase
):
def
test_load_custom_pipeline
(
self
):
def
test_load_custom_pipeline
(
self
):
pipeline
=
DiffusionPipeline
.
from_pretrained
(
pipeline
=
DiffusionPipeline
.
from_pretrained
(
...
...
tests/test_pipelines_flax.py
View file @
c39a511b
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
tempfile
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
...
@@ -24,12 +26,29 @@ from diffusers.utils.testing_utils import require_flax, slow
...
@@ -24,12 +26,29 @@ from diffusers.utils.testing_utils import require_flax, slow
if
is_flax_available
():
if
is_flax_available
():
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
diffusers
import
FlaxDDIMScheduler
,
FlaxStableDiffusionPipeline
from
diffusers
import
FlaxDDIMScheduler
,
FlaxDiffusionPipeline
,
FlaxStableDiffusionPipeline
from
flax.jax_utils
import
replicate
from
flax.jax_utils
import
replicate
from
flax.training.common_utils
import
shard
from
flax.training.common_utils
import
shard
from
jax
import
pmap
from
jax
import
pmap
@
require_flax
class
DownloadTests
(
unittest
.
TestCase
):
def
test_download_only_pytorch
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
# pipeline has Flax weights
_
=
FlaxDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
,
safety_checker
=
None
,
cache_dir
=
tmpdirname
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
os
.
path
.
join
(
tmpdirname
,
os
.
listdir
(
tmpdirname
)[
0
],
"snapshots"
))]
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
# None of the downloaded files should be a PyTorch file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
assert
not
any
(
f
.
endswith
(
".bin"
)
for
f
in
files
)
@
slow
@
slow
@
require_flax
@
require_flax
class
FlaxPipelineTests
(
unittest
.
TestCase
):
class
FlaxPipelineTests
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment