Unverified Commit edc154da authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Update Ruff to latest Version (#10919)

* update

* update

* update

* update
parent 552cd320
...@@ -783,7 +783,7 @@ def main(args): ...@@ -783,7 +783,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir) lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
......
...@@ -26,8 +26,7 @@ ...@@ -26,8 +26,7 @@
"%load_ext autoreload\n", "%load_ext autoreload\n",
"%autoreload 2\n", "%autoreload 2\n",
"\n", "\n",
"import torch\n", "from diffusers import StableDiffusionGLIGENPipeline"
"from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
] ]
}, },
{ {
...@@ -36,28 +35,25 @@ ...@@ -36,28 +35,25 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "from transformers import CLIPTextModel, CLIPTokenizer\n",
"\n",
"import diffusers\n", "import diffusers\n",
"from diffusers import (\n", "from diffusers import (\n",
" AutoencoderKL,\n", " AutoencoderKL,\n",
" DDPMScheduler,\n", " DDPMScheduler,\n",
" UNet2DConditionModel,\n",
" UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n", " EulerDiscreteScheduler,\n",
" UNet2DConditionModel,\n",
")\n", ")\n",
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n", "\n",
"\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n", "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n", "\n",
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n", "pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
"\n", "\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n", "tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n", "noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
"text_encoder = CLIPTextModel.from_pretrained(\n", "text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
" pretrained_model_name_or_path, subfolder=\"text_encoder\"\n", "vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
")\n",
"vae = AutoencoderKL.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"vae\"\n",
")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n", "# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n", "# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n", "# )\n",
...@@ -71,9 +67,7 @@ ...@@ -71,9 +67,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"unet = UNet2DConditionModel.from_pretrained(\n", "unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
" '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
")"
] ]
}, },
{ {
...@@ -108,6 +102,9 @@ ...@@ -108,6 +102,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n",
"\n",
"\n",
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n", "# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n", "# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
"\n", "\n",
...@@ -117,10 +114,8 @@ ...@@ -117,10 +114,8 @@
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n", "# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n", "# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n", "\n",
"prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n", "prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
"gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n", "gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
"\n",
"import numpy as np\n",
"\n", "\n",
"boxes = np.array([x[1] for x in gen_boxes])\n", "boxes = np.array([x[1] for x in gen_boxes])\n",
"boxes = boxes / 512\n", "boxes = boxes / 512\n",
...@@ -166,7 +161,7 @@ ...@@ -166,7 +161,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"diffusers.utils.make_image_grid(images, 4, len(images)//4)" "diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
] ]
}, },
{ {
...@@ -179,7 +174,7 @@ ...@@ -179,7 +174,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "densecaption", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
...@@ -197,5 +192,5 @@ ...@@ -197,5 +192,5 @@
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 2 "nbformat_minor": 4
} }
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
# limitations under the License. # limitations under the License.
""" """
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix. Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
""" """
import argparse import argparse
......
...@@ -763,9 +763,9 @@ def main(args): ...@@ -763,9 +763,9 @@ def main(args):
# Parse instance and class inputs, and double check that lengths match # Parse instance and class inputs, and double check that lengths match
instance_data_dir = args.instance_data_dir.split(",") instance_data_dir = args.instance_data_dir.split(",")
instance_prompt = args.instance_prompt.split(",") instance_prompt = args.instance_prompt.split(",")
assert all( assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)] "Instance data dir and prompt inputs are not of the same length."
), "Instance data dir and prompt inputs are not of the same length." )
if args.with_prior_preservation: if args.with_prior_preservation:
class_data_dir = args.class_data_dir.split(",") class_data_dir = args.class_data_dir.split(",")
...@@ -788,9 +788,9 @@ def main(args): ...@@ -788,9 +788,9 @@ def main(args):
negative_validation_prompts.append(None) negative_validation_prompts.append(None)
args.validation_negative_prompt = negative_validation_prompts args.validation_negative_prompt = negative_validation_prompts
assert num_of_validation_prompts == len( assert num_of_validation_prompts == len(negative_validation_prompts), (
negative_validation_prompts "The length of negative prompts for validation is greater than the number of validation prompts."
), "The length of negative prompts for validation is greater than the number of validation prompts." )
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
......
...@@ -830,9 +830,9 @@ def main(): ...@@ -830,9 +830,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token # Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = get_mask(tokenizer, accelerator) index_no_updates = get_mask(tokenizer, accelerator)
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
......
...@@ -886,9 +886,9 @@ def main(): ...@@ -886,9 +886,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
......
...@@ -663,8 +663,7 @@ class PromptDiffusionPipeline( ...@@ -663,8 +663,7 @@ class PromptDiffusionPipeline(
self.check_image(image, prompt, prompt_embeds) self.check_image(image, prompt, prompt_embeds)
else: else:
raise ValueError( raise ValueError(
f"You have passed a list of images of length {len(image_pair)}." f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
f"Make sure the list size equals to two."
) )
# Check `controlnet_conditioning_scale` # Check `controlnet_conditioning_scale`
......
...@@ -173,7 +173,7 @@ class TrainSD: ...@@ -173,7 +173,7 @@ class TrainSD:
if not dataloader_exception: if not dataloader_exception:
xm.wait_device_ops() xm.wait_device_ops()
total_time = time.time() - last_time total_time = time.time() - last_time
print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}") print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
else: else:
print("dataloader exception happen, skip result") print("dataloader exception happen, skip result")
return return
...@@ -622,7 +622,7 @@ def main(args): ...@@ -622,7 +622,7 @@ def main(args):
num_devices_per_host = num_devices // num_hosts num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal(): if xm.is_master_ordinal():
print("***** Running training *****") print("***** Running training *****")
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }") print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
print( print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}" f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
) )
......
...@@ -1057,7 +1057,7 @@ def main(args): ...@@ -1057,7 +1057,7 @@ def main(args):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError( raise ValueError(
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
) )
# Enable TF32 for faster training on Ampere GPUs, # Enable TF32 for faster training on Ampere GPUs,
......
...@@ -1021,7 +1021,7 @@ def main(args): ...@@ -1021,7 +1021,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
......
...@@ -118,7 +118,7 @@ def save_model_card( ...@@ -118,7 +118,7 @@ def save_model_card(
) )
model_description = f""" model_description = f"""
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} # {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
<Gallery /> <Gallery />
...@@ -1336,7 +1336,7 @@ def main(args): ...@@ -1336,7 +1336,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
......
...@@ -750,7 +750,7 @@ def main(args): ...@@ -750,7 +750,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
......
...@@ -765,7 +765,7 @@ def main(args): ...@@ -765,7 +765,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
......
...@@ -767,7 +767,7 @@ def main(args): ...@@ -767,7 +767,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
......
...@@ -910,9 +910,9 @@ def main(): ...@@ -910,9 +910,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
......
...@@ -965,12 +965,12 @@ def main(): ...@@ -965,12 +965,12 @@ def main():
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
index_no_updates_2 orig_embeds_params_2[index_no_updates_2]
] = orig_embeds_params_2[index_no_updates_2] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
......
...@@ -177,7 +177,7 @@ class TextToImage(ExamplesTestsAccelerate): ...@@ -177,7 +177,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path} --model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path} --discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1 --checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir} --output_dir {tmpdir}
--seed=0 --seed=0
""".split() """.split()
...@@ -262,7 +262,7 @@ class TextToImage(ExamplesTestsAccelerate): ...@@ -262,7 +262,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path} --model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path} --discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1 --checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir} --output_dir {tmpdir}
--use_ema --use_ema
--seed=0 --seed=0
...@@ -377,7 +377,7 @@ class TextToImage(ExamplesTestsAccelerate): ...@@ -377,7 +377,7 @@ class TextToImage(ExamplesTestsAccelerate):
--discriminator_config_name_or_path {discriminator_config_path} --discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir} --output_dir {tmpdir}
--checkpointing_steps=2 --checkpointing_steps=2
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--checkpoints_total_limit=2 --checkpoints_total_limit=2
--seed=0 --seed=0
""".split() """.split()
......
...@@ -653,15 +653,15 @@ def main(): ...@@ -653,15 +653,15 @@ def main():
try: try:
# Gets the resolution of the timm transformation after centercrop # Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1] timm_centercrop_transform = timm_transform.transforms[1]
assert isinstance( assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
timm_centercrop_transform, transforms.CenterCrop f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." )
timm_model_resolution = timm_centercrop_transform.size[0] timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization # Gets final normalization
timm_model_normalization = timm_transform.transforms[-1] timm_model_normalization = timm_transform.transforms[-1]
assert isinstance( assert isinstance(timm_model_normalization, transforms.Normalize), (
timm_model_normalization, transforms.Normalize f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." )
except AssertionError as e: except AssertionError as e:
raise NotImplementedError(e) raise NotImplementedError(e)
# Enable flash attention if asked # Enable flash attention if asked
......
...@@ -3,7 +3,7 @@ line-length = 119 ...@@ -3,7 +3,7 @@ line-length = 119
[tool.ruff.lint] [tool.ruff.lint]
# Never enforce `E501` (line length violations). # Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "F402", "F823"] ignore = ["C901", "E501", "E721", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"] select = ["C", "E", "F", "I", "W"]
# Ignore import violations in all `__init__.py` files. # Ignore import violations in all `__init__.py` files.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment