Unverified Commit 326de419 authored by Ben Sherman's avatar Ben Sherman Committed by GitHub
Browse files

Trivial fix for undefined symbol in train_dreambooth.py (#1598)

easy fix for undefined name in train_dreambooth.py

import_model_class_from_model_name_or_path loads a pretrained model
and refers to args.revision in a context where args is undefined. I modified
the function to take revision as an argument and modified the invocation
of the function to pass in the revision from args. Seems like this was caused
by a cut and paste.
parent eb1abee6
...@@ -30,11 +30,11 @@ check_min_version("0.10.0.dev0") ...@@ -30,11 +30,11 @@ check_min_version("0.10.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str): def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained( text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, pretrained_model_name_or_path,
subfolder="text_encoder", subfolder="text_encoder",
revision=args.revision, revision=revision,
) )
model_class = text_encoder_config.architectures[0] model_class = text_encoder_config.architectures[0]
...@@ -469,7 +469,7 @@ def main(args): ...@@ -469,7 +469,7 @@ def main(args):
) )
# import correct text encoder class # import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path) text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load models and create wrapper for stable diffusion # Load models and create wrapper for stable diffusion
text_encoder = text_encoder_cls.from_pretrained( text_encoder = text_encoder_cls.from_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment