Unverified Commit 8c48ec05 authored by Alrott SlimRG's avatar Alrott SlimRG Committed by GitHub
Browse files

Fix bf15/fp16 for pipeline_wan_vace.py (#12143)



* Fix bf15/fp16 for pipeline_wan_vace.py

* Update pipeline_wan_vace.py

* try removing xfail decorator

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
parent a6d2fc2c
...@@ -525,8 +525,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): ...@@ -525,8 +525,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0) latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype) latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
else: else:
mask = mask.to(dtype=vae_dtype) mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
mask = torch.where(mask > 0.5, 1.0, 0.0)
inactive = video * (1 - mask) inactive = video * (1 - mask)
reactive = video * mask reactive = video * mask
inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax") inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
......
...@@ -18,7 +18,6 @@ import tempfile ...@@ -18,7 +18,6 @@ import tempfile
import unittest import unittest
import numpy as np import numpy as np
import pytest
import safetensors.torch import safetensors.torch
import torch import torch
from PIL import Image from PIL import Image
...@@ -160,11 +159,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -160,11 +159,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_simple_inference_with_text_lora_save_load(self): def test_simple_inference_with_text_lora_save_load(self):
pass pass
@pytest.mark.xfail(
condition=True,
reason="RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same",
strict=True,
)
def test_layerwise_casting_inference_denoiser(self): def test_layerwise_casting_inference_denoiser(self):
super().test_layerwise_casting_inference_denoiser() super().test_layerwise_casting_inference_denoiser()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment