Unverified Commit 86856993 authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Mv weights name consts to diffusers.utils (#605)

parent f8100600
......@@ -28,12 +28,17 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .modeling_utils import WEIGHTS_NAME, load_state_dict
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
from .modeling_utils import load_state_dict
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
logging,
)
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
logger = logging.get_logger(__name__)
......
......@@ -24,10 +24,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
logger = logging.get_logger(__name__)
......
......@@ -24,16 +24,13 @@ import numpy as np
from huggingface_hub import hf_hub_download
from .utils import is_onnx_available, logging
from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
if is_onnx_available():
import onnxruntime as ort
ONNX_WEIGHTS_NAME = "model.onnx"
logger = logging.get_logger(__name__)
......
......@@ -30,10 +30,8 @@ from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .modeling_utils import WEIGHTS_NAME
from .onnx_utils import ONNX_WEIGHTS_NAME
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
INDEX_FILE = "diffusion_pytorch_model.bin"
......
......@@ -47,6 +47,9 @@ default_cache_path = os.path.join(hf_cache_home, "diffusers")
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
......
......@@ -46,11 +46,10 @@ from diffusers import (
UNet2DModel,
VQModel,
)
from diffusers.modeling_utils import WEIGHTS_NAME
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils import CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment