"docs/zh_cn/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "d3e2cee41e75bda8a645c521e8b06e8f29b84e47"
Unverified Commit 7c3e7fed authored by Jacqui Wei's avatar Jacqui Wei Committed by GitHub
Browse files

Fix `use_onnx` parameter usage in `from_pretrained` func and update...

Fix `use_onnx` parameter usage in `from_pretrained` func and update `test_download_no_onnx_by_default` test (#4508)

* add missing use_onnx in from_pretrained func

* fix test_download_no_onnx_by_default test func

* address comments

* split test cases
parent 029fb416
...@@ -924,6 +924,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -924,6 +924,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
...@@ -940,6 +941,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -940,6 +941,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
revision=revision, revision=revision,
from_flax=from_flax, from_flax=from_flax,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
use_onnx=use_onnx,
custom_pipeline=custom_pipeline, custom_pipeline=custom_pipeline,
custom_revision=custom_revision, custom_revision=custom_revision,
variant=variant, variant=variant,
......
...@@ -76,6 +76,7 @@ from diffusers.utils.testing_utils import ( ...@@ -76,6 +76,7 @@ from diffusers.utils.testing_utils import (
load_numpy, load_numpy,
require_compel, require_compel,
require_flax, require_flax,
require_onnxruntime,
require_torch_gpu, require_torch_gpu,
run_test_in_subprocess, run_test_in_subprocess,
) )
...@@ -327,28 +328,30 @@ class DownloadTests(unittest.TestCase): ...@@ -327,28 +328,30 @@ class DownloadTests(unittest.TestCase):
def test_download_no_onnx_by_default(self): def test_download_no_onnx_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download( tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline", "hf-internal-testing/tiny-stable-diffusion-xl-pipe",
cache_dir=tmpdirname, cache_dir=tmpdirname,
use_safetensors=False,
) )
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# make sure that by default no onnx weights are downloaded # make sure that by default no onnx weights are downloaded for non-ONNX pipelines
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files) assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files) assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
@require_onnxruntime
def test_download_onnx_by_default_for_onnx_pipelines(self):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download( tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline", "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname, cache_dir=tmpdirname,
use_onnx=True,
) )
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# if `use_onnx` is specified make sure weights are downloaded # make sure that by default onnx weights are downloaded for ONNX pipelines
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files) assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert any((f.endswith(".onnx")) for f in files) assert any((f.endswith(".onnx")) for f in files)
assert any((f.endswith(".pb")) for f in files) assert any((f.endswith(".pb")) for f in files)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment