Unverified commit a7ca03aa, authored by Patrick von Platen, committed by GitHub

Replace flake8 with ruff and update black (#2279)

* before running make style

* remove leftovers from flake8

* finish

* make fix-copies

* final fix

* more fixes
parent f5ccffec
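Most of the hunks below are mechanical import re-ordering rather than behavior changes: ruff's import-sorting rules (the "I" entry in the `select` list added to pyproject.toml further down) group imports into standard-library, third-party, and first-party blocks, sort plain `import x` statements ahead of `from x import y` within each block, and treat `diffusers` itself as first-party via `known-first-party`. A minimal sketch of the layout being enforced — the module names are illustrative, not taken from any single file in this commit:

```python
import os  # standard-library block comes first
from pathlib import Path

import numpy as np  # third-party block: plain imports sort ahead of from-imports
import torch
from packaging import version

from diffusers import UNet2DModel  # first-party block ("known-first-party") comes last


# `lines-after-imports = 2` separates the imports from code with two blank lines.
# (Imports above are shown only for their ordering; real code would use or remove them.)
weights_dir = Path(os.getcwd()) / "weights"
```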
@@ -6,15 +6,22 @@ import random
 from pathlib import Path
 from typing import Optional

-import numpy as np
-import torch
-import torch.utils.checkpoint
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
+import torch
+import torch.utils.checkpoint
 import transformers
 from datasets import load_dataset
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed

 from diffusers import (
     FlaxAutoencoderKL,
     FlaxDDPMScheduler,
@@ -24,13 +31,6 @@ from diffusers import (
 )
 from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
 from diffusers.utils import check_min_version
-from flax import jax_utils
-from flax.training import train_state
-from flax.training.common_utils import shard
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...
@@ -22,28 +22,28 @@ import random
 from pathlib import Path
 from typing import Optional

+import datasets
 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-import datasets
-import diffusers
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from datasets import load_dataset
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer

+import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.cross_attention import LoRACrossAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...
@@ -22,17 +22,25 @@ from pathlib import Path
 from typing import Optional

 import numpy as np
+import PIL
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from torch.utils.data import Dataset
-import diffusers
-import PIL
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer

+import diffusers
 from diffusers import (
     AutoencoderKL,
     DDPMScheduler,
@@ -44,14 +52,6 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
-
-# TODO: remove and import from diffusers.utils when the new version of diffusers is released
-from packaging import version
-from PIL import Image
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer


 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
...
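The version check kept as context in the hunk above deserves a note: under PEP 440, pre-release and dev builds such as 9.1.0.dev0 compare as older than 9.1.0, so the script strips those segments via `base_version` before comparing. A small self-contained illustration (the version strings are made up):

```python
from packaging import version

v = version.parse("9.1.0.dev0")
print(v < version.parse("9.1.0"))  # True: dev releases sort before the final release
print(version.parse(v.base_version) >= version.parse("9.1.0"))  # True: ".dev0" stripped first
```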
@@ -6,25 +6,14 @@ import random
 from pathlib import Path
 from typing import Optional

-import numpy as np
-import torch
-import torch.utils.checkpoint
-from torch.utils.data import Dataset
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
 import PIL
+import torch
+import torch.utils.checkpoint
 import transformers
-from diffusers import (
-    FlaxAutoencoderKL,
-    FlaxDDPMScheduler,
-    FlaxPNDMScheduler,
-    FlaxStableDiffusionPipeline,
-    FlaxUNet2DConditionModel,
-)
-from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
-from diffusers.utils import check_min_version
 from flax import jax_utils
 from flax.training import train_state
 from flax.training.common_utils import shard
@@ -33,10 +22,21 @@ from huggingface_hub import HfFolder, Repository, create_repo, whoami

 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
 from packaging import version
 from PIL import Image
+from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed

+from diffusers import (
+    FlaxAutoencoderKL,
+    FlaxDDPMScheduler,
+    FlaxPNDMScheduler,
+    FlaxStableDiffusionPipeline,
+    FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version

 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
     PIL_INTERPOLATION = {
...
@@ -6,24 +6,24 @@ import os
 from pathlib import Path
 from typing import Optional

-import torch
-import torch.nn.functional as F
 import accelerate
 import datasets
-import diffusers
+import torch
+import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm

+import diffusers
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.13.0.dev0")
...
 [tool.black]
 line-length = 119
-target-version = ['py36']
+target-version = ['py37']
+
+[tool.ruff]
+# Never enforce `E501` (line length violations).
+ignore = ["E501", "E741", "W605"]
+select = ["E", "F", "I", "W"]
+line-length = 119
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+"src/diffusers/utils/dummy_*.py" = ["F401"]
+
+[tool.ruff.isort]
+lines-after-imports = 2
+known-first-party = ["diffusers"]
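A note on the per-file ignores: package `__init__.py` files import names purely to re-export them, which would otherwise trip F401 (imported but unused), F403 (star import), F811 (redefinition), and E402 (import not at top of file); the generated `dummy_*.py` placeholder modules get the same unused-import exemption. A hypothetical sketch of the pattern being exempted — this is not the actual diffusers `__init__.py`:

```python
# __init__.py of a hypothetical package: every import here exists purely
# to define the public API, so "imported but unused" (F401) is expected noise.
from .models import UNet2DModel  # would trip F401 without the ignore
from .schedulers import *        # would trip F403 without the ignore

__all__ = ["UNet2DModel"]
```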
@@ -19,9 +19,9 @@ import json
 import os

 import torch
+from transformers.file_utils import has_file

 from diffusers import UNet2DConditionModel, UNet2DModel
-from transformers.file_utils import has_file

 do_only_config = False
...
 import argparse

+import OmegaConf
 import torch
-import OmegaConf

 from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
...
@@ -5,11 +5,11 @@ import os
 from copy import deepcopy

 import torch
+from audio_diffusion.models import DiffusionAttnUnet1D
+from diffusion import sampling
 from torch import nn

-from audio_diffusion.models import DiffusionAttnUnet1D
 from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
-from diffusion import sampling

 MODELS_MAP = {
...
@@ -7,7 +7,6 @@ import os.path as osp
 import re

 import torch
 from safetensors.torch import load_file, save_file
...
@@ -2,9 +2,9 @@ import argparse
 import os

 import torch
+from torchvision.datasets.utils import download_url

 from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
-from torchvision.datasets.utils import download_url

 pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
...
 import argparse

-import torch
 import huggingface_hub
 import k_diffusion as K
+import torch

 from diffusers import UNet2DConditionModel
...
@@ -2,13 +2,13 @@ import argparse
 import tempfile

 import torch
 from accelerate import load_checkpoint_and_dispatch
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer

 from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
 from diffusers.models.prior_transformer import PriorTransformer
 from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
 from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
-from transformers import CLIPTextModelWithProjection, CLIPTokenizer

 """
@@ -249,7 +249,6 @@ DECODER_CONFIG = {
     "class_embed_type": "identity",
     "attention_head_dim": 64,
     "resnet_time_scale_shift": "scale_shift",
-    "class_embed_type": "identity",
 }
...
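Unlike the import reorderings, this hunk is a genuine fix the stricter linter surfaced: the literal key "class_embed_type" appeared twice in the dict. Both occurrences carried the same value here, so dropping one changes nothing, but the pattern is dangerous in general because the last key silently wins. A minimal illustration with deliberately conflicting, made-up values:

```python
config = {"class_embed_type": "identity", "class_embed_type": "projection"}
assert config == {"class_embed_type": "projection"}  # first entry silently discarded
```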
@@ -355,5 +355,5 @@ if __name__ == "__main__":
         pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
         pipe.save_pretrained(args.dump_path)
-    except:
+    except:  # noqa: E722
         model.save_pretrained(args.dump_path)
@@ -181,5 +181,5 @@ if __name__ == "__main__":
         pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
         pipe.save_pretrained(args.dump_path)
-    except:
+    except:  # noqa: E722
         model.save_pretrained(args.dump_path)
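In both conversion scripts above, the bare `except:` is kept deliberately (any failure to save the pipeline falls back to saving the raw model), and E722 is silenced with `# noqa` rather than rewritten. Had a behavior change been acceptable, a narrower handler would avoid swallowing `KeyboardInterrupt` and `SystemExit`; a sketch reusing the names from the hunk above:

```python
try:
    pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
    pipe.save_pretrained(args.dump_path)
except Exception:  # still broad, but lets Ctrl-C and interpreter exit propagate
    model.save_pretrained(args.dump_path)
```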
@@ -17,12 +17,12 @@ import os
 import shutil
 from pathlib import Path

+import onnx
 import torch
+from packaging import version
 from torch.onnx import export

-import onnx
 from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline
-from packaging import version

 is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
...
 import argparse

-from diffusers import UnCLIPImageVariationPipeline, UnCLIPPipeline
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

+from diffusers import UnCLIPImageVariationPipeline, UnCLIPPipeline


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
...
 import argparse
 import io

+import requests
 import torch
+from omegaconf import OmegaConf

-import requests
 from diffusers import AutoencoderKL
 from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
     assign_to_checkpoint,
@@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
     renew_vae_attention_paths,
     renew_vae_resnet_paths,
 )
-from omegaconf import OmegaConf


 def custom_convert_ldm_vae_checkpoint(checkpoint, config):
...
@@ -18,6 +18,12 @@ import argparse
 from argparse import Namespace

 import torch
+from transformers import (
+    CLIPFeatureExtractor,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)

 from diffusers import (
     AutoencoderKL,
@@ -31,12 +37,6 @@ from diffusers import (
     VersatileDiffusionPipeline,
 )
 from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel
-from transformers import (
-    CLIPFeatureExtractor,
-    CLIPTextModelWithProjection,
-    CLIPTokenizer,
-    CLIPVisionModelWithProjection,
-)

 SCHEDULER_CONFIG = Namespace(
...
@@ -36,14 +36,14 @@ import argparse
 import tempfile

 import torch
 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
-from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader

+from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
+from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings

 try:
     from omegaconf import OmegaConf
...