"torchvision/vscode:/vscode.git/clone" did not exist on "16d62e3072955bd92b76a4ae73fefa73ecc9ee3e"
Unverified Commit e9636216 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[PixArt] fix small nits in pixart sigma (#7767)

fix small nits in pixart sigma
parent 39215aa3
...@@ -273,15 +273,6 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -273,15 +273,6 @@ class PixArtAlphaPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index
else:
masked_feature = emb * mask[:, None, :, None]
return masked_feature, emb.shape[2]
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt( def encode_prompt(
self, self,
......
...@@ -199,16 +199,7 @@ class PixArtSigmaPipeline(DiffusionPipeline): ...@@ -199,16 +199,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.py # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index
else:
masked_feature = emb * mask[:, None, :, None]
return masked_feature, emb.shape[2]
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt( def encode_prompt(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
...@@ -369,7 +360,7 @@ class PixArtSigmaPipeline(DiffusionPipeline): ...@@ -369,7 +360,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
# copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.py # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs
def check_inputs( def check_inputs(
self, self,
prompt, prompt,
...@@ -462,7 +453,7 @@ class PixArtSigmaPipeline(DiffusionPipeline): ...@@ -462,7 +453,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
return [process(t) for t in text] return [process(t) for t in text]
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline._clean_caption # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption): def _clean_caption(self, caption):
caption = str(caption) caption = str(caption)
caption = ul.unquote_plus(caption) caption = ul.unquote_plus(caption)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment