Unverified Commit ba74a8be authored by SkyTNT, committed by GitHub

[Community Pipelines] Fix pad_tokens_and_weights in lpw_stable_diffusion (#925)

[Community Pipelines] fix pad_tokens_and_weights in lpw_stable_diffusion
parent 6f6eef74
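Context for the fix: pad_tokens_and_weights lays each text-encoder chunk out as a BOS token, up to chunk_length - 2 prompt tokens, and an EOS token, but the old code sliced the per-token weights in strides of the full chunk_length, so for prompts longer than one chunk every later chunk's weights drifted away from the tokens they belonged to. The new code iterates max_embeddings_multiples times and slices in strides of chunk_length - 2. The commit also makes get_prompts_with_weights log a warning when a prompt is truncated, mirrors both changes in the ONNX pipeline, and reflows several calls across multiple lines. Below is a minimal standalone sketch of the corrected slicing with hypothetical sample values (chunk_length of 77 as in CLIP, a 100-token prompt weighted 1.1 throughout); the real function also derives the padded length from its no_boseos_middle flag:

    chunk_length = 77                 # per-chunk capacity of the text encoder
    max_embeddings_multiples = 2      # number of chunks
    token_weights = [1.1] * 100       # hypothetical per-token weights

    w = []
    for j in range(max_embeddings_multiples):
        w.append(1.0)  # weight for the starting (BOS) token of this chunk
        w += token_weights[j * (chunk_length - 2) : min(len(token_weights), (j + 1) * (chunk_length - 2))]
        w.append(1.0)  # weight for the ending (EOS) token of this chunk
    w += [1.0] * (chunk_length * max_embeddings_multiples - len(w))  # pad with neutral weights

    # The old stride of chunk_length packed 77 weights into a chunk that has
    # only 75 content slots, so every chunk after the first was misaligned by
    # a further 2 tokens.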
@@ -132,6 +132,7 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -140,21 +141,21 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word).input_ids[1:-1]
             text_token += token
             # copy the weight by length of token
             text_weight += [weight] * len(token)
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]
         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
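A rough illustration of when the new warning fires, assuming CLIP's model_max_length of 77 (i.e. 75 content tokens per chunk); the exact max_length is supplied by the caller and is not shown in this hunk:

    max_embeddings_multiples = 3                             # hypothetical setting
    truncation_limit = (77 - 2) * max_embeddings_multiples   # 225 prompt tokens
    # A prompt tokenizing to more than 225 tokens is still truncated, but the
    # user is now told about it instead of the tail being dropped silently.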
@@ -173,9 +174,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
         if len(weights[i]) == 0:
             w = [1.0] * weights_length
         else:
-            for j in range((len(weights[i]) - 1) // chunk_length + 1):
+            for j in range(max_embeddings_multiples):
                 w.append(1.0)  # weight for starting token in this chunk
-                w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                 w.append(1.0)  # weight for ending token in this chunk
             w += [1.0] * (weights_length - len(w))
         weights[i] = w[:]
@@ -184,7 +185,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe: DiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -285,7 +289,8 @@ def get_weighted_text_embeddings(
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))
     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
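A worked example of the arithmetic above, with hypothetical numbers and CLIP's model_max_length of 77:

    model_max_length = 77
    max_length = 100  # longest tokenized prompt in the batch, hypothetically
    multiples = (max_length - 1) // (model_max_length - 2) + 1   # 99 // 75 + 1 == 2
    max_length = (model_max_length - 2) * multiples + 2          # 75 * 2 + 2 == 152
    # i.e. room for 150 content tokens plus the sequence's BOS and EOS slots.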
@@ -317,12 +322,18 @@ def get_weighted_text_embeddings(
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
@@ -632,16 +643,29 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (
+            batch_size * num_images_per_prompt,
+            self.unet.in_channels,
+            height // 8,
+            width // 8,
+        )
         if latents is None:
             if self.device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
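For concreteness, the latent shape with illustrative numbers (the 4-channel figure appears explicitly in the ONNX hunk further down; the VAE's spatial downsampling factor of 8 is the // 8 in the code):

    batch_size, num_images_per_prompt = 1, 2
    in_channels = 4               # self.unet.in_channels for Stable Diffusion
    height = width = 512
    latents_shape = (
        batch_size * num_images_per_prompt,
        in_channels,
        height // 8,
        width // 8,
    )  # == (2, 4, 64, 64)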
@@ -684,11 +708,19 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             # add noise to latents using the timesteps
             if self.device.type == "mps":
                 # randn does not exist on mps
-                noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
             latents = self.scheduler.add_noise(init_latents, noise, timesteps)
             t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -741,7 +773,8 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                 self.device
             )
             image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
             )
         else:
             has_nsfw_concept = None
...
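The second file in the commit applies the same set of changes to the ONNX variant of the pipeline (the OnnxStableDiffusionLongPromptWeightingPipeline in the hunks below), which mirrors the torch implementation using numpy arrays.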
@@ -130,6 +130,7 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -138,21 +139,21 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
             text_token += list(token)
             # copy the weight by length of token
             text_weight += [weight] * len(token)
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]
         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
@@ -171,9 +172,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
         if len(weights[i]) == 0:
             w = [1.0] * weights_length
         else:
-            for j in range((len(weights[i]) - 1) // chunk_length + 1):
+            for j in range(max_embeddings_multiples):
                 w.append(1.0)  # weight for starting token in this chunk
-                w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                 w.append(1.0)  # weight for ending token in this chunk
             w += [1.0] * (weights_length - len(w))
         weights[i] = w[:]
@@ -182,7 +183,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
 def get_unweighted_text_embeddings(
-    pipe, text_input: np.array, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe,
+    text_input: np.array,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -276,7 +280,10 @@ def get_weighted_text_embeddings(
         uncond_tokens = [
             token[1:-1]
             for token in pipe.tokenizer(
-                uncond_prompt, max_length=max_length, truncation=True, return_tensors="np"
+                uncond_prompt,
+                max_length=max_length,
+                truncation=True,
+                return_tensors="np",
             ).input_ids
         ]
         uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
@@ -287,7 +294,8 @@ def get_weighted_text_embeddings(
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))
     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -319,12 +327,18 @@ def get_weighted_text_embeddings(
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
    )
     prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
@@ -559,7 +573,12 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         noise = None
         if init_image is None:
-            latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
+            latents_shape = (
+                batch_size * num_images_per_prompt,
+                4,
+                height // 8,
+                width // 8,
+            )
             if latents is None:
                 latents = generator.randn(*latents_shape).astype(latents_dtype)
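Unlike the torch pipeline, the ONNX pipeline draws its initial latents from the generator directly. A sketch assuming a numpy RandomState-like generator, which matches the randn call above (the generator's actual type is not shown in this diff):

    import numpy as np

    generator = np.random.RandomState(0)   # assumed RandomState-style generator
    latents_dtype = np.float32
    latents = generator.randn(2, 4, 64, 64).astype(latents_dtype)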
@@ -625,7 +644,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             # predict the noise residual
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input,
+                timestep=np.array([t]),
+                encoder_hidden_states=text_embeddings,
             )
             noise_pred = noise_pred[0]
@@ -640,7 +661,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             if mask is not None:
                 # masking
                 init_latents_proper = self.scheduler.add_noise(
-                    torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
+                    torch.from_numpy(init_latents_orig),
+                    torch.from_numpy(noise),
+                    torch.tensor([t]),
                 ).numpy()
                 latents = (init_latents_proper * mask) + (latents * (1 - mask))
...