Unverified Commit 86856993 authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Mv weights name consts to diffusers.utils (#605)

parent f8100600
...@@ -28,12 +28,17 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R ...@@ -28,12 +28,17 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError from requests import HTTPError
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .modeling_utils import WEIGHTS_NAME, load_state_dict from .modeling_utils import load_state_dict
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
logging,
)
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -24,10 +24,7 @@ from huggingface_hub import hf_hub_download ...@@ -24,10 +24,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError from requests import HTTPError
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -24,16 +24,13 @@ import numpy as np ...@@ -24,16 +24,13 @@ import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from .utils import is_onnx_available, logging from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
if is_onnx_available(): if is_onnx_available():
import onnxruntime as ort import onnxruntime as ort
ONNX_WEIGHTS_NAME = "model.onnx"
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -30,10 +30,8 @@ from PIL import Image ...@@ -30,10 +30,8 @@ from PIL import Image
from tqdm.auto import tqdm from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .modeling_utils import WEIGHTS_NAME
from .onnx_utils import ONNX_WEIGHTS_NAME
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
INDEX_FILE = "diffusion_pytorch_model.bin" INDEX_FILE = "diffusion_pytorch_model.bin"
......
...@@ -47,6 +47,9 @@ default_cache_path = os.path.join(hf_cache_home, "diffusers") ...@@ -47,6 +47,9 @@ default_cache_path = os.path.join(hf_cache_home, "diffusers")
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
......
...@@ -46,11 +46,10 @@ from diffusers import ( ...@@ -46,11 +46,10 @@ from diffusers import (
UNet2DModel, UNet2DModel,
VQModel, VQModel,
) )
from diffusers.modeling_utils import WEIGHTS_NAME
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils import CONFIG_NAME from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME
from PIL import Image from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment