Unverified Commit edc154da authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Update Ruff to latest Version (#10919)

* update

* update

* update

* update
parent 552cd320
......@@ -783,7 +783,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -26,8 +26,7 @@
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import torch\n",
"from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
"from diffusers import StableDiffusionGLIGENPipeline"
]
},
{
......@@ -36,28 +35,25 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from transformers import CLIPTextModel, CLIPTokenizer\n",
"\n",
"import diffusers\n",
"from diffusers import (\n",
" AutoencoderKL,\n",
" DDPMScheduler,\n",
" UNet2DConditionModel,\n",
" UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n",
" UNet2DConditionModel,\n",
")\n",
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
"\n",
"\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n",
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
"pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
"text_encoder = CLIPTextModel.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
")\n",
"vae = AutoencoderKL.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"vae\"\n",
")\n",
"text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
"vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n",
......@@ -71,9 +67,7 @@
"metadata": {},
"outputs": [],
"source": [
"unet = UNet2DConditionModel.from_pretrained(\n",
" '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
")"
"unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
]
},
{
......@@ -108,6 +102,9 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
"\n",
......@@ -117,10 +114,8 @@
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n",
"prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
"gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
"\n",
"import numpy as np\n",
"prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
"gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
"\n",
"boxes = np.array([x[1] for x in gen_boxes])\n",
"boxes = boxes / 512\n",
......@@ -166,7 +161,7 @@
"metadata": {},
"outputs": [],
"source": [
"diffusers.utils.make_image_grid(images, 4, len(images)//4)"
"diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
]
},
{
......@@ -179,7 +174,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "densecaption",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
......@@ -197,5 +192,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
......@@ -15,8 +15,8 @@
# limitations under the License.
"""
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
import argparse
......
......@@ -763,9 +763,9 @@ def main(args):
# Parse instance and class inputs, and double check that lengths match
instance_data_dir = args.instance_data_dir.split(",")
instance_prompt = args.instance_prompt.split(",")
assert all(
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
), "Instance data dir and prompt inputs are not of the same length."
assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
"Instance data dir and prompt inputs are not of the same length."
)
if args.with_prior_preservation:
class_data_dir = args.class_data_dir.split(",")
......@@ -788,9 +788,9 @@ def main(args):
negative_validation_prompts.append(None)
args.validation_negative_prompt = negative_validation_prompts
assert num_of_validation_prompts == len(
negative_validation_prompts
), "The length of negative prompts for validation is greater than the number of validation prompts."
assert num_of_validation_prompts == len(negative_validation_prompts), (
"The length of negative prompts for validation is greater than the number of validation prompts."
)
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
......
......@@ -830,9 +830,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = get_mask(tokenizer, accelerator)
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
......
......@@ -886,9 +886,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
......
......@@ -663,8 +663,7 @@ class PromptDiffusionPipeline(
self.check_image(image, prompt, prompt_embeds)
else:
raise ValueError(
f"You have passed a list of images of length {len(image_pair)}."
f"Make sure the list size equals to two."
f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
)
# Check `controlnet_conditioning_scale`
......
......@@ -173,7 +173,7 @@ class TrainSD:
if not dataloader_exception:
xm.wait_device_ops()
total_time = time.time() - last_time
print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
else:
print("dataloader exception happen, skip result")
return
......@@ -622,7 +622,7 @@ def main(args):
num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal():
print("***** Running training *****")
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
)
......
......@@ -1057,7 +1057,7 @@ def main(args):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
......
......@@ -1021,7 +1021,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
......
......@@ -118,7 +118,7 @@ def save_model_card(
)
model_description = f"""
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
<Gallery />
......@@ -1336,7 +1336,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
......
......@@ -750,7 +750,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
......
......@@ -765,7 +765,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
......
......@@ -767,7 +767,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
......
......@@ -910,9 +910,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
......
......@@ -965,12 +965,12 @@ def main():
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
index_no_updates_2
] = orig_embeds_params_2[index_no_updates_2]
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
orig_embeds_params_2[index_no_updates_2]
)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
......
......@@ -177,7 +177,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir}
--seed=0
""".split()
......@@ -262,7 +262,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir}
--use_ema
--seed=0
......@@ -377,7 +377,7 @@ class TextToImage(ExamplesTestsAccelerate):
--discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--checkpoints_total_limit=2
--seed=0
""".split()
......
......@@ -653,15 +653,15 @@ def main():
try:
# Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1]
assert isinstance(
timm_centercrop_transform, transforms.CenterCrop
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
)
timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization
timm_model_normalization = timm_transform.transforms[-1]
assert isinstance(
timm_model_normalization, transforms.Normalize
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
assert isinstance(timm_model_normalization, transforms.Normalize), (
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
)
except AssertionError as e:
raise NotImplementedError(e)
# Enable flash attention if asked
......
......@@ -3,7 +3,7 @@ line-length = 119
[tool.ruff.lint]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "F402", "F823"]
ignore = ["C901", "E501", "E721", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"]
# Ignore import violations in all `__init__.py` files.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment