Unverified Commit 626284f8 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[StableDiffusionXLAdapterPipeline] add adapter_conditioning_factor (#4937)

add adapter_conditioning_factor
parent 9800cc5e
...@@ -656,6 +656,7 @@ class StableDiffusionXLAdapterPipeline( ...@@ -656,6 +656,7 @@ class StableDiffusionXLAdapterPipeline(
crops_coords_top_left: Tuple[int, int] = (0, 0), crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None, target_size: Optional[Tuple[int, int]] = None,
adapter_conditioning_scale: Union[float, List[float]] = 1.0, adapter_conditioning_scale: Union[float, List[float]] = 1.0,
adapter_conditioning_factor: float = 1.0,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -767,6 +768,10 @@ class StableDiffusionXLAdapterPipeline( ...@@ -767,6 +768,10 @@ class StableDiffusionXLAdapterPipeline(
The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
residual in the original unet. If multiple adapters are specified in init, you can set the residual in the original unet. If multiple adapters are specified in init, you can set the
corresponding scale as a list. corresponding scale as a list.
adapter_conditioning_factor (`float`, *optional*, defaults to 1.0):
The fraction of timesteps for which adapter should be applied. If `adapter_conditioning_factor` is
`0.0`, adapter is not applied at all. If `adapter_conditioning_factor` is `1.0`, adapter is applied for
all timesteps. If `adapter_conditioning_factor` is `0.5`, adapter is applied for half of the timesteps.
Examples: Examples:
Returns: Returns:
...@@ -904,6 +909,12 @@ class StableDiffusionXLAdapterPipeline( ...@@ -904,6 +909,12 @@ class StableDiffusionXLAdapterPipeline(
# predict the noise residual # predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
if i < int(num_inference_steps * adapter_conditioning_factor):
down_block_additional_residuals = [state.clone() for state in adapter_state]
else:
down_block_additional_residuals = None
noise_pred = self.unet( noise_pred = self.unet(
latent_model_input, latent_model_input,
t, t,
...@@ -911,7 +922,7 @@ class StableDiffusionXLAdapterPipeline( ...@@ -911,7 +922,7 @@ class StableDiffusionXLAdapterPipeline(
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
return_dict=False, return_dict=False,
down_block_additional_residuals=[state.clone() for state in adapter_state], down_block_additional_residuals=down_block_additional_residuals,
)[0] )[0]
# perform guidance # perform guidance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment