Unverified Commit 0f14335a authored by Aki Sakurai, committed by GitHub

StableDiffusionLongPromptWeightingPipeline: Do not hardcode pad token (#2832)

parent 8bdf4236
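In short: the padding helper previously filled everything after the prompt with repeated end-of-text tokens; after this change it pads with the tokenizer's pad token and keeps a single closing eos. A minimal sketch of the layout change (the token ids and max_length are made up for illustration; CLIP's tokenizer happens to reuse eos as pad, but other tokenizers may not):

# Illustrative only: invented token ids, max_length=8, a 3-token prompt.
bos, eos, pad = 49406, 49407, 0  # pad deliberately differs from eos here

prompt_ids = [320, 1125, 539]
max_length = 8

# Old layout: the tail after the prompt is all eos tokens.
old = [bos] + prompt_ids + [eos] * (max_length - 1 - len(prompt_ids))

# New layout: pad tokens fill the tail, with a single closing eos.
new = [bos] + prompt_ids + [pad] * (max_length - 1 - len(prompt_ids) - 1) + [eos]

assert old == [49406, 320, 1125, 539, 49407, 49407, 49407, 49407]
assert new == [49406, 320, 1125, 539, 0, 0, 0, 49407]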
@@ -179,14 +179,14 @@ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
     return tokens, weights


-def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
     max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
     weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
     for i in range(len(tokens)):
-        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+        tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
         if no_boseos_middle:
             weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
         else:
@@ -317,12 +317,14 @@ def get_weighted_text_embeddings(
     # pad the length of tokens and weights
     bos = pipe.tokenizer.bos_token_id
     eos = pipe.tokenizer.eos_token_id
+    pad = getattr(pipe.tokenizer, "pad_token_id", eos)
     prompt_tokens, prompt_weights = pad_tokens_and_weights(
         prompt_tokens,
         prompt_weights,
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
@@ -334,6 +336,7 @@ def get_weighted_text_embeddings(
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
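For context, a hedged usage sketch of the updated helper as defined above; the ids and weights below are invented for illustration:

tokens = [[320, 1125, 539]]      # one prompt, three content-token ids (made up)
weights = [[1.0, 1.2, 1.0]]      # per-token emphasis weights
bos, eos, pad = 49406, 49407, 0  # example special-token ids

tokens, weights = pad_tokens_and_weights(
    tokens,
    weights,
    max_length=77,
    bos=bos,
    eos=eos,
    pad=pad,
    no_boseos_middle=True,
    chunk_length=77,
)
assert len(tokens[0]) == len(weights[0]) == 77
assert tokens[0][0] == bos and tokens[0][-1] == eos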
...
@@ -196,14 +196,14 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
     return tokens, weights


-def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
     max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
     weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
     for i in range(len(tokens)):
-        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+        tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
         if no_boseos_middle:
             weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
         else:
@@ -342,12 +342,14 @@ def get_weighted_text_embeddings(
     # pad the length of tokens and weights
     bos = pipe.tokenizer.bos_token_id
     eos = pipe.tokenizer.eos_token_id
+    pad = getattr(pipe.tokenizer, "pad_token_id", eos)
     prompt_tokens, prompt_weights = pad_tokens_and_weights(
         prompt_tokens,
         prompt_weights,
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
@@ -359,6 +361,7 @@ def get_weighted_text_embeddings(
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
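Finally, a small sketch of the fallback both files use to resolve the pad id; MockTokenizer is hypothetical:

class MockTokenizer:
    # A hypothetical tokenizer that exposes eos but no pad_token_id attribute.
    bos_token_id = 0
    eos_token_id = 2

tok = MockTokenizer()
eos = tok.eos_token_id
# getattr falls back to eos when the attribute is absent entirely; a
# tokenizer that defines pad_token_id = None would still yield None here.
pad = getattr(tok, "pad_token_id", eos)
assert pad == eos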
...