Unverified Commit 65ef7a0c authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Fix prompt bug in AnimateDiff (#5702)

* fix prompt bug

* add test
parent 6e68c715
...@@ -498,7 +498,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo ...@@ -498,7 +498,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 16, num_frames: Optional[int] = 16,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
......
...@@ -220,6 +220,17 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -220,6 +220,17 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_prompt_embeds(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
inputs.pop("prompt")
inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
pipe(**inputs)
@slow @slow
@require_torch_gpu @require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment