Unverified Commit 4f0141a6 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Fix ONNX checkpoint loading (#2544)

* Revert "Disable ONNX tests (#2509)"

This reverts commit a0549fea.

* add external weights

* + pb

* style
parent 10219293
......@@ -31,6 +31,11 @@ jobs:
runner: docker-cpu
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: Fast ONNXRuntime CPU tests on Ubuntu
framework: onnxruntime
runner: docker-cpu
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
runner: docker-cpu
......
......@@ -29,6 +29,11 @@ jobs:
runner: docker-tpu
image: diffusers/diffusers-flax-tpu
report: flax_tpu
- name: Slow ONNXRuntime CUDA tests on Ubuntu
framework: onnxruntime
runner: docker-gpu
image: diffusers/diffusers-onnxruntime-cuda
report: onnx_cuda
name: ${{ matrix.config.name }}
......
......@@ -29,6 +29,11 @@ jobs:
runner: docker-cpu
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: Fast ONNXRuntime CPU tests on Ubuntu
framework: onnxruntime
runner: docker-cpu
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
runner: docker-cpu
......
......@@ -63,7 +63,7 @@ if is_transformers_available():
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
INDEX_FILE = "diffusion_pytorch_model.bin"
......@@ -176,7 +176,13 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
filenames = set(sibling.rfilename for sibling in info.siblings)
weight_names = [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME]
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
......@@ -604,7 +610,7 @@ class DiffusionPipeline(ConfigMixin):
]
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", ".onnx"]
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment