Unverified Commit a7ca03aa authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Replace flake8 with ruff and update black (#2279)

* before running make style

* remove left overs from flake8

* finish

* make fix-copies

* final fix

* more fixes
parent f5ccffec
......@@ -27,9 +27,8 @@ jobs:
pip install .[quality]
- name: Check quality
run: |
black --check --preview examples tests src utils scripts
isort --check-only examples tests src utils scripts
flake8 examples tests src utils scripts
black --check examples tests src utils scripts
ruff examples tests src utils scripts
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
check_repository_consistency:
......
......@@ -169,3 +169,6 @@ tags
# dependencies
/transformers
# ruff
.ruff_cache
......@@ -177,7 +177,7 @@ Follow these steps to start contributing ([supported Python versions](https://gi
$ make style
```
🧨 Diffusers also uses `flake8` and a few custom scripts to check for coding mistakes. Quality
🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
control runs in CI, however you can also run the same checks with:
```bash
......
......@@ -9,9 +9,8 @@ modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@if test -n "$(modified_py_files)"; then \
echo "Checking/fixing $(modified_py_files)"; \
black --preview $(modified_py_files); \
isort $(modified_py_files); \
flake8 $(modified_py_files); \
black $(modified_py_files); \
ruff $(modified_py_files); \
else \
echo "No library .py files were modified"; \
fi
......@@ -41,9 +40,8 @@ repo-consistency:
# this target runs checks on all files
quality:
black --check --preview $(check_dirs)
isort --check-only $(check_dirs)
flake8 $(check_dirs)
black --check $(check_dirs)
ruff $(check_dirs)
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
python utils/check_doc_toc.py
......@@ -57,8 +55,8 @@ extra_style_checks:
# this target runs checks on all files and potentially modifies some of them
style:
black --preview $(check_dirs)
isort $(check_dirs)
black $(check_dirs)
ruff $(check_dirs) --fix
${MAKE} autogenerate_code
${MAKE} extra_style_checks
......
......@@ -177,7 +177,7 @@ Follow these steps to start contributing ([supported Python versions](https://gi
$ make style
```
🧨 Diffusers also uses `flake8` and a few custom scripts to check for coding mistakes. Quality
🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
control runs in CI, however you can also run the same checks with:
```bash
......
......@@ -210,6 +210,7 @@ torch.set_grad_enabled(False)
n_experiments = 2
unet_runs_per_experiment = 50
# load inputs
def generate_inputs():
sample = torch.randn(2, 4, 64, 64).half().cuda()
......@@ -288,6 +289,8 @@ pipe = StableDiffusionPipeline.from_pretrained(
# use jitted unet
unet_traced = torch.jit.load("unet_traced.pt")
# del pipe.unet
class TracedUNet(torch.nn.Module):
def __init__(self):
......
from typing import Optional, Tuple, Union
import torch
from einops import rearrange, reduce
from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
from einops import rearrange, reduce
BITS = 8
......
......@@ -10,10 +10,11 @@ from diffusers.utils import is_safetensors_available
if is_safetensors_available():
import safetensors.torch
from huggingface_hub import snapshot_download
from diffusers import DiffusionPipeline, __version__
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
from huggingface_hub import snapshot_download
class CheckpointMergerPipeline(DiffusionPipeline):
......
......@@ -4,6 +4,8 @@ from typing import List, Optional, Union
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
......@@ -14,8 +16,6 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
class MakeCutouts(nn.Module):
......
......@@ -16,6 +16,8 @@ import inspect
from typing import Callable, List, Optional, Union
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
......@@ -29,8 +31,6 @@ from diffusers.schedulers import (
PNDMScheduler,
)
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
......
......@@ -7,11 +7,16 @@ import warnings
from typing import List, Optional, Union
import numpy as np
import PIL
import torch
import torch.nn.functional as F
import PIL
from accelerate import Accelerator
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
......@@ -19,11 +24,6 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
......
......@@ -2,9 +2,10 @@ import inspect
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import PIL
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
......@@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
......@@ -13,7 +14,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -3,16 +3,16 @@ import re
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import diffusers
import PIL
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import deprecate, logging
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
try:
......
......@@ -3,15 +3,15 @@ import re
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
import diffusers
import PIL
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import deprecate, logging
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
try:
......
from typing import Union
import torch
from PIL import Image
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
......@@ -10,10 +14,6 @@ from diffusers import (
PNDMScheduler,
UNet2DConditionModel,
)
from PIL import Image
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
class MagicMixPipeline(DiffusionPipeline):
......
......@@ -2,14 +2,6 @@ import inspect
from typing import Callable, List, Optional, Union
import torch
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
......@@ -19,6 +11,14 @@ from transformers import (
pipeline,
)
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -17,11 +17,11 @@ import warnings
from typing import Callable, List, Optional, Union
import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import is_accelerate_available, logging
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -5,6 +5,7 @@ import inspect
from typing import Callable, List, Optional, Union
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
......@@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -2,6 +2,13 @@ import inspect
from typing import Callable, List, Optional, Union
import torch
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
WhisperForConditionalGeneration,
WhisperProcessor,
)
from diffusers import (
AutoencoderKL,
......@@ -14,13 +21,6 @@ from diffusers import (
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import logging
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
WhisperForConditionalGeneration,
WhisperProcessor,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment