Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
86856993
Unverified
Commit
86856993
authored
Sep 21, 2022
by
Mishig Davaadorj
Committed by
GitHub
Sep 21, 2022
Browse files
Mv weights name consts to diffusers.utils (#605)
parent
f8100600
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
16 additions
and
17 deletions
+16
-17
src/diffusers/modeling_flax_utils.py
src/diffusers/modeling_flax_utils.py
+9
-4
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+1
-4
src/diffusers/onnx_utils.py
src/diffusers/onnx_utils.py
+1
-4
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+1
-3
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+3
-0
tests/test_pipelines.py
tests/test_pipelines.py
+1
-2
No files found.
src/diffusers/modeling_flax_utils.py
View file @
86856993
...
...
@@ -28,12 +28,17 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from
requests
import
HTTPError
from
.modeling_flax_pytorch_utils
import
convert_pytorch_state_dict_to_flax
from
.modeling_utils
import
WEIGHTS_NAME
,
load_state_dict
from
.utils
import
CONFIG_NAME
,
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
logging
from
.modeling_utils
import
load_state_dict
from
.utils
import
(
CONFIG_NAME
,
DIFFUSERS_CACHE
,
FLAX_WEIGHTS_NAME
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
WEIGHTS_NAME
,
logging
,
)
FLAX_WEIGHTS_NAME
=
"diffusion_flax_model.msgpack"
logger
=
logging
.
get_logger
(
__name__
)
...
...
src/diffusers/modeling_utils.py
View file @
86856993
...
...
@@ -24,10 +24,7 @@ from huggingface_hub import hf_hub_download
from
huggingface_hub.utils
import
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
from
requests
import
HTTPError
from
.utils
import
CONFIG_NAME
,
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
logging
WEIGHTS_NAME
=
"diffusion_pytorch_model.bin"
from
.utils
import
CONFIG_NAME
,
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
WEIGHTS_NAME
,
logging
logger
=
logging
.
get_logger
(
__name__
)
...
...
src/diffusers/onnx_utils.py
View file @
86856993
...
...
@@ -24,16 +24,13 @@ import numpy as np
from
huggingface_hub
import
hf_hub_download
from
.utils
import
is_onnx_available
,
logging
from
.utils
import
ONNX_WEIGHTS_NAME
,
is_onnx_available
,
logging
if
is_onnx_available
():
import
onnxruntime
as
ort
ONNX_WEIGHTS_NAME
=
"model.onnx"
logger
=
logging
.
get_logger
(
__name__
)
...
...
src/diffusers/pipeline_utils.py
View file @
86856993
...
...
@@ -30,10 +30,8 @@ from PIL import Image
from
tqdm.auto
import
tqdm
from
.configuration_utils
import
ConfigMixin
from
.modeling_utils
import
WEIGHTS_NAME
from
.onnx_utils
import
ONNX_WEIGHTS_NAME
from
.schedulers.scheduling_utils
import
SCHEDULER_CONFIG_NAME
from
.utils
import
CONFIG_NAME
,
DIFFUSERS_CACHE
,
BaseOutput
,
logging
from
.utils
import
CONFIG_NAME
,
DIFFUSERS_CACHE
,
ONNX_WEIGHTS_NAME
,
WEIGHTS_NAME
,
BaseOutput
,
logging
INDEX_FILE
=
"diffusion_pytorch_model.bin"
...
...
src/diffusers/utils/__init__.py
View file @
86856993
...
...
@@ -47,6 +47,9 @@ default_cache_path = os.path.join(hf_cache_home, "diffusers")
CONFIG_NAME
=
"config.json"
WEIGHTS_NAME
=
"diffusion_pytorch_model.bin"
FLAX_WEIGHTS_NAME
=
"diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME
=
"model.onnx"
HUGGINGFACE_CO_RESOLVE_ENDPOINT
=
"https://huggingface.co"
DIFFUSERS_CACHE
=
default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME
=
"diffusers_modules"
...
...
tests/test_pipelines.py
View file @
86856993
...
...
@@ -46,11 +46,10 @@ from diffusers import (
UNet2DModel
,
VQModel
,
)
from
diffusers.modeling_utils
import
WEIGHTS_NAME
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.schedulers.scheduling_utils
import
SCHEDULER_CONFIG_NAME
from
diffusers.testing_utils
import
floats_tensor
,
load_image
,
slow
,
torch_device
from
diffusers.utils
import
CONFIG_NAME
from
diffusers.utils
import
CONFIG_NAME
,
WEIGHTS_NAME
from
PIL
import
Image
from
transformers
import
CLIPTextConfig
,
CLIPTextModel
,
CLIPTokenizer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment