Commit 3b66cc0f authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent 717a956a
...@@ -752,9 +752,9 @@ def main(): ...@@ -752,9 +752,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token # Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
orig_embeds_params[index_no_updates] index_no_updates
) ] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
......
...@@ -749,9 +749,9 @@ def main(): ...@@ -749,9 +749,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token # Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
orig_embeds_params[index_no_updates] index_no_updates
) ] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
......
...@@ -418,9 +418,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -418,9 +418,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
): ):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
...@@ -613,7 +610,12 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -613,7 +610,12 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# 2. Define call parameters # 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
......
...@@ -403,9 +403,6 @@ class CycleDiffusionPipeline(DiffusionPipeline): ...@@ -403,9 +403,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
): ):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
......
...@@ -339,9 +339,6 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -339,9 +339,6 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
): ):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
......
...@@ -414,9 +414,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -414,9 +414,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
): ):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment