Unverified Commit 880c0fdd authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[advanced dreambooth lora training script][bug_fix] change token_abstraction type to str (#6040)



* improve help tags

* style fix

* changes token_abstraction type to string.
support multiple concepts for pivotal using a comma separated string.

* style fixup

* changed logger to warning (not yet available)

* moved the token_abstraction parsing to be in the same block as where we create the mapping of identifier to token

---------
Co-authored-by: Linoy <linoy@huggingface.co>
parent c36f1c31
...@@ -300,16 +300,18 @@ def parse_args(input_args=None): ...@@ -300,16 +300,18 @@ def parse_args(input_args=None):
) )
parser.add_argument( parser.add_argument(
"--token_abstraction", "--token_abstraction",
type=str,
default="TOK", default="TOK",
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
"captions - e.g. TOK", "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. "
"'TOK,TOK2,TOK3' etc.",
) )
parser.add_argument( parser.add_argument(
"--num_new_tokens_per_abstraction", "--num_new_tokens_per_abstraction",
type=int, type=int,
default=2, default=2,
help="number of new tokens inserted to the tokenizers per token_abstraction value when " help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
"--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
"tokens - <si><si+1> ", "tokens - <si><si+1> ",
) )
...@@ -660,17 +662,6 @@ def parse_args(input_args=None): ...@@ -660,17 +662,6 @@ def parse_args(input_args=None):
"inversion training check `--train_text_encoder_ti`" "inversion training check `--train_text_encoder_ti`"
) )
if args.train_text_encoder_ti:
if isinstance(args.token_abstraction, str):
args.token_abstraction = [args.token_abstraction]
elif isinstance(args.token_abstraction, List):
args.token_abstraction = args.token_abstraction
else:
raise ValueError(
f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. "
f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)"
)
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank: if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank args.local_rank = env_local_rank
...@@ -1155,9 +1146,14 @@ def main(args): ...@@ -1155,9 +1146,14 @@ def main(args):
) )
if args.train_text_encoder_ti: if args.train_text_encoder_ti:
# we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
# TOK2" -> ["TOK", "TOK2"] etc.
token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
logger.info(f"list of token identifiers: {token_abstraction_list}")
token_abstraction_dict = {} token_abstraction_dict = {}
token_idx = 0 token_idx = 0
for i, token in enumerate(args.token_abstraction): for i, token in enumerate(token_abstraction_list):
token_abstraction_dict[token] = [ token_abstraction_dict[token] = [
f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction) f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction)
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment