Unverified Commit fbff43ac authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

[FEAT] DDUF format (#10037)



* load and save dduf archive

* style

* switch to zip uncompressed

* updates

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* first draft

* remove print

* switch to dduf_file for consistency

* switch to huggingface hub api

* fix log

* add a basic test

* Update src/diffusers/configuration_utils.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* fix

* fix variant

* change saving logic

* DDUF - Load transformers components manually (#10171)

* update hfh version

* Load transformers components manually

* load encoder from_pretrained with state_dict

* working version with transformers and tokenizer !

* add generation_config case

* fix tests

* remove saving for now

* typing

* need next version from transformers

* Update src/diffusers/configuration_utils.py
Co-authored-by: default avatarLucain <lucain@huggingface.co>

* check path corectly

* Apply suggestions from code review
Co-authored-by: default avatarLucain <lucain@huggingface.co>

* udapte

* typing

* remove check for subfolder

* quality

* revert setup changes

* oups

* more readable condition

* add loading from the hub test

* add basic docs.

* Apply suggestions from code review
Co-authored-by: default avatarLucain <lucain@huggingface.co>

* add example

* add

* make functions private

* Apply suggestions from code review
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* minor.

* fixes

* fix

* change the precdence of parameterized.

* error out when custom pipeline is passed with dduf_file.

* updates

* fix

* updates

* fixes

* updates

* fix xfail condition.

* fix xfail

* fixes

* sharded checkpoint compat

* add test for sharded checkpoint

* add suggestions

* Update src/diffusers/models/model_loading_utils.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* from suggestions

* add class attributes to flag dduf tests

* last one

* fix logic

* remove comment

* revert changes

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarLucain <lucain@huggingface.co>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 3279751b
...@@ -31,6 +31,8 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM ...@@ -31,6 +31,8 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
) )
batch_params = frozenset(["prompt", "negative_prompt"]) batch_params = frozenset(["prompt", "negative_prompt"])
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
transformer = LuminaNextDiT2DModel( transformer = LuminaNextDiT2DModel(
......
...@@ -65,6 +65,8 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -65,6 +65,8 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
] ]
) )
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
......
...@@ -56,6 +56,8 @@ class KolorsPAGPipelineFastTests( ...@@ -56,6 +56,8 @@ class KolorsPAGPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
supports_dduf = False
# Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components # Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components
def get_dummy_components(self, time_cond_proj_dim=None): def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -53,6 +53,8 @@ class SanaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -53,6 +53,8 @@ class SanaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
) )
test_xformers_attention = False test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
transformer = SanaTransformer2DModel( transformer = SanaTransformer2DModel(
......
...@@ -82,6 +82,8 @@ class StableDiffusionXLPAGImg2ImgPipelineFastTests( ...@@ -82,6 +82,8 @@ class StableDiffusionXLPAGImg2ImgPipelineFastTests(
{"add_text_embeds", "add_time_ids", "add_neg_time_ids"} {"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
) )
supports_dduf = False
# based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_img2img_pipeline.get_dummy_components # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_img2img_pipeline.get_dummy_components
def get_dummy_components( def get_dummy_components(
self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False
......
...@@ -82,6 +82,8 @@ class StableDiffusionXLPAGInpaintPipelineFastTests( ...@@ -82,6 +82,8 @@ class StableDiffusionXLPAGInpaintPipelineFastTests(
{"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"} {"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"}
) )
supports_dduf = False
# based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components
def get_dummy_components( def get_dummy_components(
self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False
......
...@@ -46,6 +46,8 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -46,6 +46,8 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = frozenset([]) # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess image_params = frozenset([]) # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
......
...@@ -50,6 +50,8 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -50,6 +50,8 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
] ]
test_xformers_attention = False test_xformers_attention = False
supports_dduf = False
@property @property
def text_embedder_hidden_size(self): def text_embedder_hidden_size(self):
return 16 return 16
......
...@@ -70,6 +70,7 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -70,6 +70,7 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
) )
# There is not xformers version of the StableAudioPipeline custom attention processor # There is not xformers version of the StableAudioPipeline custom attention processor
test_xformers_attention = False test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -76,6 +76,8 @@ class StableDiffusionDepth2ImgPipelineFastTests( ...@@ -76,6 +76,8 @@ class StableDiffusionDepth2ImgPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"depth_mask"}) callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"depth_mask"})
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
......
...@@ -389,6 +389,8 @@ class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterM ...@@ -389,6 +389,8 @@ class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterM
class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase): class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
supports_dduf = False
def get_dummy_components(self, time_cond_proj_dim=None): def get_dummy_components(self, time_cond_proj_dim=None):
return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim) return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)
......
...@@ -66,6 +66,8 @@ class GligenTextImagePipelineFastTests( ...@@ -66,6 +66,8 @@ class GligenTextImagePipelineFastTests(
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
......
...@@ -58,6 +58,8 @@ class StableDiffusionImageVariationPipelineFastTests( ...@@ -58,6 +58,8 @@ class StableDiffusionImageVariationPipelineFastTests(
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([]) image_latents_params = frozenset([])
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
......
...@@ -422,6 +422,8 @@ class StableDiffusionXLAdapterPipelineFastTests( ...@@ -422,6 +422,8 @@ class StableDiffusionXLAdapterPipelineFastTests(
class StableDiffusionXLMultiAdapterPipelineFastTests( class StableDiffusionXLMultiAdapterPipelineFastTests(
StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
): ):
supports_dduf = False
def get_dummy_components(self, time_cond_proj_dim=None): def get_dummy_components(self, time_cond_proj_dim=None):
return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim) return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)
......
...@@ -77,6 +77,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests( ...@@ -77,6 +77,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(
{"add_text_embeds", "add_time_ids", "add_neg_time_ids"} {"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
) )
supports_dduf = False
def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None): def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
......
...@@ -72,6 +72,8 @@ class StableDiffusionXLInpaintPipelineFastTests( ...@@ -72,6 +72,8 @@ class StableDiffusionXLInpaintPipelineFastTests(
} }
) )
supports_dduf = False
def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None): def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
......
...@@ -51,6 +51,8 @@ class StableUnCLIPImg2ImgPipelineFastTests( ...@@ -51,6 +51,8 @@ class StableUnCLIPImg2ImgPipelineFastTests(
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([]) image_latents_params = frozenset([])
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
embedder_hidden_size = 32 embedder_hidden_size = 32
embedder_projection_dim = embedder_hidden_size embedder_projection_dim = embedder_hidden_size
......
...@@ -58,6 +58,8 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa ...@@ -58,6 +58,8 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa
] ]
) )
supports_dduf = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNetSpatioTemporalConditionModel( unet = UNetSpatioTemporalConditionModel(
......
...@@ -75,9 +75,11 @@ from diffusers.utils.testing_utils import ( ...@@ -75,9 +75,11 @@ from diffusers.utils.testing_utils import (
nightly, nightly,
require_compel, require_compel,
require_flax, require_flax,
require_hf_hub_version_greater,
require_onnxruntime, require_onnxruntime,
require_torch_2, require_torch_2,
require_torch_gpu, require_torch_gpu,
require_transformers_version_greater,
run_test_in_subprocess, run_test_in_subprocess,
slow, slow,
torch_device, torch_device,
...@@ -981,6 +983,18 @@ class DownloadTests(unittest.TestCase): ...@@ -981,6 +983,18 @@ class DownloadTests(unittest.TestCase):
assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files) assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files)
assert len(files) == 14 assert len(files) == 14
def test_download_dduf_with_custom_pipeline_raises_error(self):
with self.assertRaises(NotImplementedError):
_ = DiffusionPipeline.download(
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline"
)
def test_download_dduf_with_connected_pipeline_raises_error(self):
with self.assertRaises(NotImplementedError):
_ = DiffusionPipeline.download(
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True
)
def test_get_pipeline_class_from_flax(self): def test_get_pipeline_class_from_flax(self):
flax_config = {"_class_name": "FlaxStableDiffusionPipeline"} flax_config = {"_class_name": "FlaxStableDiffusionPipeline"}
config = {"_class_name": "StableDiffusionPipeline"} config = {"_class_name": "StableDiffusionPipeline"}
...@@ -1802,6 +1816,55 @@ class PipelineFastTests(unittest.TestCase): ...@@ -1802,6 +1816,55 @@ class PipelineFastTests(unittest.TestCase):
sd.maybe_free_model_hooks() sd.maybe_free_model_hooks()
assert sd._offload_gpu_id == 5 assert sd._offload_gpu_id == 5
@parameterized.expand([torch.float32, torch.float16])
@require_hf_hub_version_greater("0.26.5")
@require_transformers_version_greater("4.47.1")
def test_load_dduf_from_hub(self, dtype):
with tempfile.TemporaryDirectory() as tmpdir:
pipe = DiffusionPipeline.from_pretrained(
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, torch_dtype=dtype
).to(torch_device)
out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
pipe.save_pretrained(tmpdir)
loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=dtype).to(torch_device)
out_2 = loaded_pipe(
prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
).images
self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
@require_hf_hub_version_greater("0.26.5")
@require_transformers_version_greater("4.47.1")
def test_load_dduf_from_hub_local_files_only(self):
with tempfile.TemporaryDirectory() as tmpdir:
pipe = DiffusionPipeline.from_pretrained(
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir
).to(torch_device)
out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
local_files_pipe = DiffusionPipeline.from_pretrained(
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, local_files_only=True
).to(torch_device)
out_2 = local_files_pipe(
prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
).images
self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
def test_dduf_raises_error_with_custom_pipeline(self):
with self.assertRaises(NotImplementedError):
_ = DiffusionPipeline.from_pretrained(
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline"
)
def test_dduf_raises_error_with_connected_pipeline(self):
with self.assertRaises(NotImplementedError):
_ = DiffusionPipeline.from_pretrained(
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True
)
def test_wrong_model(self): def test_wrong_model(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
with self.assertRaises(ValueError) as error_context: with self.assertRaises(ValueError) as error_context:
...@@ -1812,6 +1875,27 @@ class PipelineFastTests(unittest.TestCase): ...@@ -1812,6 +1875,27 @@ class PipelineFastTests(unittest.TestCase):
assert "is of type" in str(error_context.exception) assert "is of type" in str(error_context.exception)
assert "but should be" in str(error_context.exception) assert "but should be" in str(error_context.exception)
@require_hf_hub_version_greater("0.26.5")
@require_transformers_version_greater("4.47.1")
def test_dduf_load_sharded_checkpoint_diffusion_model(self):
with tempfile.TemporaryDirectory() as tmpdir:
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-flux-dev-pipe-sharded-checkpoint-DDUF",
dduf_file="tiny-flux-dev-pipe-sharded-checkpoint.dduf",
cache_dir=tmpdir,
).to(torch_device)
out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
pipe.save_pretrained(tmpdir)
loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir).to(torch_device)
out_2 = loaded_pipe(
prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
).images
self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -43,7 +43,9 @@ from diffusers.utils.testing_utils import ( ...@@ -43,7 +43,9 @@ from diffusers.utils.testing_utils import (
CaptureLogger, CaptureLogger,
require_accelerate_version_greater, require_accelerate_version_greater,
require_accelerator, require_accelerator,
require_hf_hub_version_greater,
require_torch, require_torch,
require_transformers_version_greater,
skip_mps, skip_mps,
torch_device, torch_device,
) )
...@@ -986,6 +988,8 @@ class PipelineTesterMixin: ...@@ -986,6 +988,8 @@ class PipelineTesterMixin:
test_xformers_attention = True test_xformers_attention = True
supports_dduf = True
def get_generator(self, seed): def get_generator(self, seed):
device = torch_device if torch_device != "mps" else "cpu" device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device).manual_seed(seed) generator = torch.Generator(device).manual_seed(seed)
...@@ -1990,6 +1994,39 @@ class PipelineTesterMixin: ...@@ -1990,6 +1994,39 @@ class PipelineTesterMixin:
) )
) )
@require_hf_hub_version_greater("0.26.5")
@require_transformers_version_greater("4.47.1")
def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):
if not self.supports_dduf:
return
from huggingface_hub import export_folder_as_dduf
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device="cpu")
inputs.pop("generator")
inputs["generator"] = torch.manual_seed(0)
pipeline_out = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
pipe.save_pretrained(tmpdir, safe_serialization=True)
export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)
inputs["generator"] = torch.manual_seed(0)
loaded_pipeline_out = loaded_pipe(**inputs)[0]
if isinstance(pipeline_out, np.ndarray) and isinstance(loaded_pipeline_out, np.ndarray):
assert np.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol)
elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor):
assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol)
@is_staging_test @is_staging_test
class PipelinePushToHubTester(unittest.TestCase): class PipelinePushToHubTester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment