Unverified Commit e391b789 authored by JinK's avatar JinK Committed by GitHub
Browse files

Support different strength for Stable Diffusion TensorRT Inpainting pipeline (#4216)

* Support different strength

* run make style
parent d0b8de12
...@@ -823,14 +823,14 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): ...@@ -823,14 +823,14 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
return self return self
def __initialize_timesteps(self, timesteps, strength): def __initialize_timesteps(self, num_inference_steps, strength):
self.scheduler.set_timesteps(timesteps) self.scheduler.set_timesteps(num_inference_steps)
offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 offset = self.scheduler.config.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
init_timestep = int(timesteps * strength) + offset init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, timesteps) init_timestep = min(init_timestep, num_inference_steps)
t_start = max(timesteps - init_timestep + offset, 0) t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].to(self.torch_device)
return timesteps, t_start return timesteps, num_inference_steps - t_start
def __preprocess_images(self, batch_size, images=()): def __preprocess_images(self, batch_size, images=()):
init_images = [] init_images = []
...@@ -953,7 +953,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): ...@@ -953,7 +953,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
strength: float = 0.75, strength: float = 1.0,
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
...@@ -1043,9 +1043,32 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): ...@@ -1043,9 +1043,32 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
latent_height = self.image_height // 8 latent_height = self.image_height // 8
latent_width = self.image_width // 8 latent_width = self.image_width // 8
# Pre-process input images
mask, masked_image, init_image = self.__preprocess_images(
batch_size,
prepare_mask_and_masked_image(
image,
mask_image,
self.image_height,
self.image_width,
return_image=True,
),
)
# print(mask)
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width))
mask = torch.cat([mask] * 2)
# Initialize timesteps
timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
latent_timestep = timesteps[:1].repeat(batch_size)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0
# Pre-initialize latents # Pre-initialize latents
num_channels_latents = self.vae.config.latent_channels num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents( latents_outputs = self.prepare_latents(
batch_size, batch_size,
num_channels_latents, num_channels_latents,
self.image_height, self.image_height,
...@@ -1053,16 +1076,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): ...@@ -1053,16 +1076,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
torch.float32, torch.float32,
self.torch_device, self.torch_device,
generator, generator,
image=init_image,
timestep=latent_timestep,
is_strength_max=is_strength_max,
) )
# Pre-process input images latents = latents_outputs[0]
mask, masked_image = self.__preprocess_images(batch_size, prepare_mask_and_masked_image(image, mask_image))
# print(mask)
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width))
mask = torch.cat([mask] * 2)
# Initialize timesteps
timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
# VAE encode masked image # VAE encode masked image
masked_latents = self.__encode_image(masked_image) masked_latents = self.__encode_image(masked_image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment