Unverified commit 6bfd13f0, authored by Álvaro Somoza, committed by GitHub

[SD3 Training] T5 token limit (#8564)



* initial commit

* default back to 77

* better text

* text correction

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent eeb70033
@@ -106,6 +106,9 @@ To better track our training experiments, we're using the following flags in the
 * `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
 * `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+> [!NOTE]
+> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
 > [!TIP]
 > You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
...
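The note above documents the `--max_sequence_length` flag added by this commit. As a quick illustration (a sketch, not part of the patch), the snippet below counts how many T5 tokens a prompt actually uses, which is a handy way to decide whether the default limit of 77 is enough or whether to raise the flag. The tokenizer checkpoint name is an assumption; substitute the `tokenizer_3` of the SD3 checkpoint you train on.

```python
# Sketch only: estimate how many T5 tokens a prompt uses, to decide whether the
# default --max_sequence_length of 77 is enough. "google/t5-v1_1-xxl" is an
# assumed stand-in for the tokenizer_3 shipped with your SD3 checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")

prompt = (
    "a photo of sks dog wearing a tiny raincoat, sitting on a wooden bench "
    "in a rainy park, cinematic lighting, shallow depth of field, 35mm film look"
)
num_tokens = len(tokenizer(prompt, add_special_tokens=True).input_ids)
print(f"{num_tokens} T5 tokens (tokens beyond --max_sequence_length are truncated)")
```

If the count comes out above 77, pass a larger `--max_sequence_length` (up to 512), bearing in mind the extra memory use and slower steps the note warns about.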
@@ -298,6 +298,12 @@ def parse_args(input_args=None):
         default=None,
         help="The prompt to specify images in the same class as provided instance images.",
     )
+    parser.add_argument(
+        "--max_sequence_length",
+        type=int,
+        default=77,
+        help="Maximum sequence length to use with the T5 text encoder",
+    )
     parser.add_argument(
         "--validation_prompt",
         type=str,
@@ -830,6 +836,7 @@ def tokenize_prompt(tokenizer, prompt):
 def _encode_prompt_with_t5(
     text_encoder,
     tokenizer,
+    max_sequence_length,
     prompt=None,
     num_images_per_prompt=1,
     device=None,
@@ -840,7 +847,7 @@ def _encode_prompt_with_t5(
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
-        max_length=77,
+        max_length=max_sequence_length,
         truncation=True,
         add_special_tokens=True,
         return_tensors="pt",
@@ -897,6 +904,7 @@ def encode_prompt(
     text_encoders,
     tokenizers,
     prompt: str,
+    max_sequence_length,
     device=None,
     num_images_per_prompt: int = 1,
 ):
@@ -924,6 +932,7 @@ def encode_prompt(
     t5_prompt_embed = _encode_prompt_with_t5(
         text_encoders[-1],
         tokenizers[-1],
+        max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[-1].device,
@@ -1297,7 +1306,9 @@ def main(args):
     def compute_text_embeddings(prompt, text_encoders, tokenizers):
         with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(
+                text_encoders, tokenizers, prompt, args.max_sequence_length
+            )
             prompt_embeds = prompt_embeds.to(accelerator.device)
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
             return prompt_embeds, pooled_prompt_embeds
...
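The script changes above thread `max_sequence_length` from `parse_args` through `encode_prompt` and into the T5 tokenizer call in `_encode_prompt_with_t5`. As a rough sketch of the effect (again, not part of the patch), the snippet below mirrors the modified tokenizer call with a small public T5 checkpoint standing in for SD3's much larger `text_encoder_3`: because of `padding="max_length"` and `truncation=True`, the T5 prompt embeddings always come out with exactly `max_sequence_length` tokens, which is where the extra memory and compute for larger limits comes from.

```python
# Sketch only: mirrors the tokenizer call changed above, using a small public
# T5 checkpoint ("google/t5-v1_1-small") as an assumed stand-in for SD3's
# text_encoder_3 / tokenizer_3.
import torch
from transformers import AutoTokenizer, T5EncoderModel

checkpoint = "google/t5-v1_1-small"  # stand-in; the training script uses the pipeline's own T5
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
text_encoder = T5EncoderModel.from_pretrained(checkpoint)

prompt = "a photo of sks dog in a bucket"
for max_sequence_length in (77, 256):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        prompt_embeds = text_encoder(text_inputs.input_ids)[0]
    # Shape is [1, max_sequence_length, hidden_size]: a larger limit means a
    # longer text sequence flowing through the model, hence more memory/compute.
    print(max_sequence_length, tuple(prompt_embeds.shape))
```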
@@ -297,6 +297,12 @@ def parse_args(input_args=None):
         default=None,
         help="The prompt to specify images in the same class as provided instance images.",
     )
+    parser.add_argument(
+        "--max_sequence_length",
+        type=int,
+        default=77,
+        help="Maximum sequence length to use with the T5 text encoder",
+    )
     parser.add_argument(
         "--validation_prompt",
         type=str,
@@ -828,6 +834,7 @@ def tokenize_prompt(tokenizer, prompt):
 def _encode_prompt_with_t5(
     text_encoder,
     tokenizer,
+    max_sequence_length,
     prompt=None,
     num_images_per_prompt=1,
     device=None,
@@ -838,7 +845,7 @@ def _encode_prompt_with_t5(
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
-        max_length=77,
+        max_length=max_sequence_length,
         truncation=True,
         add_special_tokens=True,
         return_tensors="pt",
@@ -895,6 +902,7 @@ def encode_prompt(
     text_encoders,
     tokenizers,
     prompt: str,
+    max_sequence_length,
     device=None,
     num_images_per_prompt: int = 1,
 ):
@@ -922,6 +930,7 @@ def encode_prompt(
     t5_prompt_embed = _encode_prompt_with_t5(
         text_encoders[-1],
         tokenizers[-1],
+        max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[-1].device,
@@ -1324,7 +1333,9 @@ def main(args):
     def compute_text_embeddings(prompt, text_encoders, tokenizers):
         with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(
+                text_encoders, tokenizers, prompt, args.max_sequence_length
+            )
             prompt_embeds = prompt_embeds.to(accelerator.device)
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
             return prompt_embeds, pooled_prompt_embeds
...