"src/vscode:/vscode.git/clone" did not exist on "4f00d5ac6fa408a9ca73141db5e8d0cbb1881d92"
Unverified commit beb8f216 authored by dg845, committed by GitHub

Clean up LCM Pipeline and Test Code. (#5641)

* Clean up LCM pipeline and pipeline test code.

* Add comment for LCM img2img sampling loop.
parent 7ad70cee
@@ -518,6 +518,36 @@ class LatentConsistencyModelImg2ImgPipeline(
         return timesteps, num_inference_steps - t_start

+    def check_inputs(
+        self,
+        prompt: Union[str, List[str]],
+        strength: float,
+        callback_steps: int,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
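To see what the new guards do, here is a small self-contained sketch: `check_inputs` below is a standalone, hypothetical copy of the method added above (a free function rather than a pipeline method, with abridged error messages), and the loop feeds it arguments that each violate one rule.

```python
from typing import List, Optional, Union

import torch


def check_inputs(
    prompt: Union[str, List[str]],
    strength: float,
    callback_steps: int,
    prompt_embeds: Optional[torch.FloatTensor] = None,
):
    # Standalone stand-in for the validation added in this commit.
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps}.")
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
    if prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")


# Each call violates exactly one guard and raises ValueError:
for bad in (
    dict(prompt="a cat", strength=1.5, callback_steps=1),  # strength out of [0, 1]
    dict(prompt="a cat", strength=0.5, callback_steps=0),  # non-positive callback_steps
    dict(prompt=None, strength=0.5, callback_steps=1),     # neither prompt nor embeds
    dict(prompt=123, strength=0.5, callback_steps=1),      # wrong prompt type
):
    try:
        check_inputs(**bad)
    except ValueError as err:
        print(f"rejected: {err}")
```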
@@ -602,16 +632,9 @@ class LatentConsistencyModelImg2ImgPipeline(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
-        # 1. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        device = self._execution_device
-        #
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, strength, callback_steps, prompt_embeds)

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
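The surviving `# 2. Define call parameters` block infers the batch size from whichever prompt input was provided. A minimal standalone illustration (the `infer_batch_size` helper is invented here; the pipeline does this inline):

```python
import torch


def infer_batch_size(prompt=None, prompt_embeds=None):
    # str prompt -> one image; list prompt -> one image per entry;
    # otherwise fall back to the leading dim of the precomputed embeddings.
    if isinstance(prompt, str):
        return 1
    if isinstance(prompt, list):
        return len(prompt)
    return prompt_embeds.shape[0]


assert infer_batch_size(prompt="a cat") == 1
assert infer_batch_size(prompt=["a cat", "a dog"]) == 2
assert infer_batch_size(prompt_embeds=torch.zeros(4, 77, 768)) == 4
```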
@@ -641,10 +664,10 @@ class LatentConsistencyModelImg2ImgPipeline(
             clip_skip=clip_skip,
         )

-        # 3.5 encode image
+        # 4. Encode image
         image = self.image_processor.preprocess(image)

-        # 4. Prepare timesteps
+        # 5. Prepare timesteps
         self.scheduler.set_timesteps(
             num_inference_steps, device, original_inference_steps=original_inference_steps, strength=strength
         )
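For img2img, `strength` decides how much of the schedule actually runs. Below is a sketch of the standard diffusers truncation pattern, consistent with the `return timesteps, num_inference_steps - t_start` line in the first hunk; the LCM pipeline delegates the details to `LCMScheduler.set_timesteps`, so treat this as an approximation rather than the exact LCM code.

```python
import numpy as np


def get_timesteps(timesteps, num_inference_steps, strength):
    # Keep only the tail of the schedule: higher strength -> more noise added
    # to the init image -> more denoising steps actually executed.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return timesteps[t_start:], num_inference_steps - t_start


timesteps = np.array([799, 699, 599, 499, 399, 299, 199, 99])
kept, n = get_timesteps(timesteps, num_inference_steps=8, strength=0.5)
print(kept, n)  # [399 299 199  99] 4
```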
@@ -674,6 +697,7 @@ class LatentConsistencyModelImg2ImgPipeline(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)

+        # 8. LCM Multistep Sampling Loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
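The new `# 8. LCM Multistep Sampling Loop` comment labels the loop that follows. Here is a toy, runnable sketch of its shape, with stand-ins for the UNet and scheduler (`fake_unet` and `fake_scheduler_step` are invented for illustration; only the loop structure mirrors the pipeline): each step turns the model output into a denoised estimate, which is re-noised to form the next latents, and the final denoised estimate is what gets decoded.

```python
import torch

torch.manual_seed(0)


def fake_unet(latents, t):
    # Stand-in for the conditioned UNet forward pass.
    return torch.randn_like(latents)


def fake_scheduler_step(model_pred, t, latents):
    # Stand-in for LCMScheduler.step: returns (next_latents, denoised).
    denoised = latents - 0.1 * model_pred  # consistency estimate of the clean sample
    next_latents = denoised + 0.05 * torch.randn_like(latents)  # re-noise between steps
    return next_latents, denoised


latents = torch.randn(1, 4, 8, 8)
denoised = latents
for t in (399, 299, 199, 99):  # LCM runs a handful of widely spaced timesteps
    model_pred = fake_unet(latents, t)
    latents, denoised = fake_scheduler_step(model_pred, t, latents)

print(denoised.shape)  # the last denoised estimate is decoded into the image
```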
@@ -136,9 +136,8 @@ class LatentConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, Pipelin
         assert image.shape == (1, 64, 64, 3)
         image_slice = image[0, -3:, -3:, -1]

-        # TODO: get expected slice
-        expected_slice = np.array([0.1540, 0.5205, 0.5458, 0.1200, 0.3983, 0.4350, 0.5386, 0.3522, 0.3614])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+        expected_slice = np.array([0.1403, 0.5072, 0.5316, 0.1202, 0.3865, 0.4211, 0.5363, 0.3557, 0.3645])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=5e-4)
@@ -150,7 +150,7 @@ class LatentConsistencyModelImg2ImgPipelineFastTests(
         image_slice = image[0, -3:, -3:, -1]

         expected_slice = np.array([0.4903, 0.3304, 0.3503, 0.5241, 0.5153, 0.4585, 0.3222, 0.4764, 0.4891])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=5e-4)
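Both test classes tighten the slice tolerance from 2e-2 to 1e-3 now that reliable expected values are pinned. The comparison pattern itself, in isolation (random data and a copied slice stand in for the real pinned values):

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((1, 64, 64, 3))

# Compare the bottom-right 3x3 corner of the last channel against a pinned slice.
image_slice = image[0, -3:, -3:, -1]
expected_slice = image_slice.copy()  # stand-in for the hard-coded expected values
assert np.abs(image_slice.flatten() - expected_slice.flatten()).max() < 1e-3
```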