Unverified Commit edc154da authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Update Ruff to latest Version (#10919)

* update

* update

* update

* update
parent 552cd320
...@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter ...@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter # Set alpha parameter
if "lora_down" in kohya_key: if "lora_down" in kohya_key:
alpha_key = f'{kohya_key.split(".")[0]}.alpha' alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict return kohya_ss_state_dict
......
...@@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images): ...@@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
total = 0 total = 0
pbar = tqdm(desc="downloading real regularization images", total=num_class_images) pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open( with (
f"{class_data_dir}/images.txt", "w" open(f"{class_data_dir}/caption.txt", "w") as f1,
) as f3: open(f"{class_data_dir}/urls.txt", "w") as f2,
open(f"{class_data_dir}/images.txt", "w") as f3,
):
while total < num_class_images: while total < num_class_images:
images = class_images[count] images = class_images[count]
count += 1 count += 1
......
...@@ -731,18 +731,18 @@ def main(args): ...@@ -731,18 +731,18 @@ def main(args):
if not class_images_dir.exists(): if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True) class_images_dir.mkdir(parents=True, exist_ok=True)
if args.real_prior: if args.real_prior:
assert ( assert (class_images_dir / "images").exists(), (
class_images_dir / "images" f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
assert ( assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
len(list((class_images_dir / "images").iterdir())) == args.num_class_images f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
assert ( assert (class_images_dir / "caption.txt").exists(), (
class_images_dir / "caption.txt" f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
assert ( assert (class_images_dir / "images.txt").exists(), (
class_images_dir / "images.txt" f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt") concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt") concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
args.concepts_list[i] = concept args.concepts_list[i] = concept
......
...@@ -1014,7 +1014,7 @@ def main(args): ...@@ -1014,7 +1014,7 @@ def main(args):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError( raise ValueError(
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
) )
# Enable TF32 for faster training on Ampere GPUs, # Enable TF32 for faster training on Ampere GPUs,
......
...@@ -982,7 +982,7 @@ def main(args): ...@@ -982,7 +982,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
......
...@@ -1294,7 +1294,7 @@ def main(args): ...@@ -1294,7 +1294,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir) lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
......
...@@ -1053,7 +1053,7 @@ def main(args): ...@@ -1053,7 +1053,7 @@ def main(args):
lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir) lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
......
...@@ -1064,7 +1064,7 @@ def main(args): ...@@ -1064,7 +1064,7 @@ def main(args):
lora_state_dict = SanaPipeline.lora_state_dict(input_dir) lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
......
...@@ -1355,7 +1355,7 @@ def main(args): ...@@ -1355,7 +1355,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
......
...@@ -118,7 +118,7 @@ def save_model_card( ...@@ -118,7 +118,7 @@ def save_model_card(
) )
model_description = f""" model_description = f"""
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} # {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
<Gallery /> <Gallery />
...@@ -1286,7 +1286,7 @@ def main(args): ...@@ -1286,7 +1286,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
......
...@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f ...@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
) )
pipeline.load_lora_weights(args.output_dir) pipeline.load_lora_weights(args.output_dir)
assert ( assert pipeline.transformer.config.in_channels == initial_channels * 2, (
pipeline.transformer.config.in_channels == initial_channels * 2 f"{pipeline.transformer.config.in_channels=}"
), f"{pipeline.transformer.config.in_channels=}" )
pipeline.to(accelerator.device) pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
...@@ -954,7 +954,7 @@ def main(args): ...@@ -954,7 +954,7 @@ def main(args):
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir) lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
transformer_lora_state_dict = { transformer_lora_state_dict = {
f'{k.replace("transformer.", "")}': v f"{k.replace('transformer.', '')}": v
for k, v in lora_state_dict.items() for k, v in lora_state_dict.items()
if k.startswith("transformer.") and "lora" in k if k.startswith("transformer.") and "lora" in k
} }
......
...@@ -1081,9 +1081,9 @@ class AutoConfig: ...@@ -1081,9 +1081,9 @@ class AutoConfig:
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}" f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
) )
pretrained_model_name_or_paths[ pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
pretrained_model_name_or_paths.index(search_word) textual_inversion_path.model_path
] = textual_inversion_path.model_path )
self.load_textual_inversion( self.load_textual_inversion(
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
......
...@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string): ...@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
return_tensors="pt", return_tensors="pt",
) )
tokens = batch_encoding["input_ids"] tokens = batch_encoding["input_ids"]
assert ( assert torch.count_nonzero(tokens - 49407) == 2, (
torch.count_nonzero(tokens - 49407) == 2 f"String '{string}' maps to more than a single token. Please use another string"
), f"String '{string}' maps to more than a single token. Please use another string" )
return tokens[0, 1] return tokens[0, 1]
......
...@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module): ...@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
assert ( assert H == self.img_size[0] and W == self.img_size[1], (
H == self.img_size[0] and W == self.img_size[1] f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." )
x = self.proj(x).flatten(2).permute(0, 2, 1) x = self.proj(x).flatten(2).permute(0, 2, 1)
return x return x
......
...@@ -619,7 +619,7 @@ def main(args): ...@@ -619,7 +619,7 @@ def main(args):
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0]) logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0])
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1) progress_bar.update(1)
global_step += 1 global_step += 1
......
...@@ -803,21 +803,20 @@ def parse_args(input_args=None): ...@@ -803,21 +803,20 @@ def parse_args(input_args=None):
"--control_type", "--control_type",
type=str, type=str,
default="canny", default="canny",
help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."), help=("The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."),
) )
parser.add_argument( parser.add_argument(
"--transformer_layers_per_block", "--transformer_layers_per_block",
type=str, type=str,
default=None, default=None,
help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."), help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."),
) )
parser.add_argument( parser.add_argument(
"--old_style_controlnet", "--old_style_controlnet",
action="store_true", action="store_true",
default=False, default=False,
help=( help=(
"Use the old style controlnet, which is a single transformer layer with" "Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
" a single head. Defaults to False."
), ),
) )
......
...@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st ...@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
# create pipeline # create pipeline
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
......
...@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path( ...@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation: if is_final_validation:
if args.mixed_precision == "fp16": if args.mixed_precision == "fp16":
......
...@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path( ...@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation: if is_final_validation:
if args.mixed_precision == "fp16": if args.mixed_precision == "fp16":
...@@ -683,7 +683,7 @@ def main(args): ...@@ -683,7 +683,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
......
...@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path( ...@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation: if is_final_validation:
if args.mixed_precision == "fp16": if args.mixed_precision == "fp16":
...@@ -790,7 +790,7 @@ def main(args): ...@@ -790,7 +790,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment