Unverified Commit 7c3e7fed authored by Jacqui Wei's avatar Jacqui Wei Committed by GitHub
Browse files

Fix `use_onnx` parameter usage in `from_pretrained` func and update...

Fix `use_onnx` parameter usage in `from_pretrained` func and update `test_download_no_onnx_by_default` test (#4508)

* add missing use_onnx in from_pretrained func

* fix test_download_no_onnx_by_default test func

* address comments

* split test cases
parent 029fb416
......@@ -924,6 +924,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
# 1. Download the checkpoints and configs
......@@ -940,6 +941,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
revision=revision,
from_flax=from_flax,
use_safetensors=use_safetensors,
use_onnx=use_onnx,
custom_pipeline=custom_pipeline,
custom_revision=custom_revision,
variant=variant,
......
......@@ -76,6 +76,7 @@ from diffusers.utils.testing_utils import (
load_numpy,
require_compel,
require_flax,
require_onnxruntime,
require_torch_gpu,
run_test_in_subprocess,
)
......@@ -327,28 +328,30 @@ class DownloadTests(unittest.TestCase):
def test_download_no_onnx_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
"hf-internal-testing/tiny-stable-diffusion-xl-pipe",
cache_dir=tmpdirname,
use_safetensors=False,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# make sure that by default no onnx weights are downloaded
# make sure that by default no onnx weights are downloaded for non-ONNX pipelines
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
@require_onnxruntime
def test_download_onnx_by_default_for_onnx_pipelines(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
use_onnx=True,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# if `use_onnx` is specified make sure weights are downloaded
# make sure that by default onnx weights are downloaded for ONNX pipelines
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert any((f.endswith(".onnx")) for f in files)
assert any((f.endswith(".pb")) for f in files)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment