Unverified Commit db5194a4 authored by C Q's avatar C Q Committed by GitHub
Browse files

Fix Compatibility Issues in stable_diffusion_xl_reference.py (#6251)



* Fix Compatibility Issues in stable_diffusion_xl_reference.py

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent e6c9c251
...@@ -507,7 +507,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -507,7 +507,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
return hidden_states, output_states return hidden_states, output_states
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs): def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):
eps = 1e-6 eps = 1e-6
output_states = () output_states = ()
...@@ -686,8 +686,17 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -686,8 +686,17 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
# 10. Prepare added time ids & embeddings # 10. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids( add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
) )
if do_classifier_free_guidance: if do_classifier_free_guidance:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment