"vscode:/vscode.git/clone" did not exist on "a4cacf13c2faa3fe12d6ad6d8a8b6cd4b067edbf"
Unverified Commit a7ca03aa authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Replace flake8 with ruff and update black (#2279)

* before running make style

* remove left overs from flake8

* finish

* make fix-copies

* final fix

* more fixes
parent f5ccffec
...@@ -27,9 +27,8 @@ jobs: ...@@ -27,9 +27,8 @@ jobs:
pip install .[quality] pip install .[quality]
- name: Check quality - name: Check quality
run: | run: |
black --check --preview examples tests src utils scripts black --check examples tests src utils scripts
isort --check-only examples tests src utils scripts ruff examples tests src utils scripts
flake8 examples tests src utils scripts
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
check_repository_consistency: check_repository_consistency:
......
...@@ -169,3 +169,6 @@ tags ...@@ -169,3 +169,6 @@ tags
# dependencies # dependencies
/transformers /transformers
# ruff
.ruff_cache
...@@ -177,7 +177,7 @@ Follow these steps to start contributing ([supported Python versions](https://gi ...@@ -177,7 +177,7 @@ Follow these steps to start contributing ([supported Python versions](https://gi
$ make style $ make style
``` ```
🧨 Diffusers also uses `flake8` and a few custom scripts to check for coding mistakes. Quality 🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
control runs in CI, however you can also run the same checks with: control runs in CI, however you can also run the same checks with:
```bash ```bash
......
...@@ -9,9 +9,8 @@ modified_only_fixup: ...@@ -9,9 +9,8 @@ modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@if test -n "$(modified_py_files)"; then \ @if test -n "$(modified_py_files)"; then \
echo "Checking/fixing $(modified_py_files)"; \ echo "Checking/fixing $(modified_py_files)"; \
black --preview $(modified_py_files); \ black $(modified_py_files); \
isort $(modified_py_files); \ ruff $(modified_py_files); \
flake8 $(modified_py_files); \
else \ else \
echo "No library .py files were modified"; \ echo "No library .py files were modified"; \
fi fi
...@@ -41,9 +40,8 @@ repo-consistency: ...@@ -41,9 +40,8 @@ repo-consistency:
# this target runs checks on all files # this target runs checks on all files
quality: quality:
black --check --preview $(check_dirs) black --check $(check_dirs)
isort --check-only $(check_dirs) ruff $(check_dirs)
flake8 $(check_dirs)
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
python utils/check_doc_toc.py python utils/check_doc_toc.py
...@@ -57,8 +55,8 @@ extra_style_checks: ...@@ -57,8 +55,8 @@ extra_style_checks:
# this target runs checks on all files and potentially modifies some of them # this target runs checks on all files and potentially modifies some of them
style: style:
black --preview $(check_dirs) black $(check_dirs)
isort $(check_dirs) ruff $(check_dirs) --fix
${MAKE} autogenerate_code ${MAKE} autogenerate_code
${MAKE} extra_style_checks ${MAKE} extra_style_checks
......
...@@ -177,7 +177,7 @@ Follow these steps to start contributing ([supported Python versions](https://gi ...@@ -177,7 +177,7 @@ Follow these steps to start contributing ([supported Python versions](https://gi
$ make style $ make style
``` ```
🧨 Diffusers also uses `flake8` and a few custom scripts to check for coding mistakes. Quality 🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
control runs in CI, however you can also run the same checks with: control runs in CI, however you can also run the same checks with:
```bash ```bash
......
...@@ -210,6 +210,7 @@ torch.set_grad_enabled(False) ...@@ -210,6 +210,7 @@ torch.set_grad_enabled(False)
n_experiments = 2 n_experiments = 2
unet_runs_per_experiment = 50 unet_runs_per_experiment = 50
# load inputs # load inputs
def generate_inputs(): def generate_inputs():
sample = torch.randn(2, 4, 64, 64).half().cuda() sample = torch.randn(2, 4, 64, 64).half().cuda()
...@@ -288,6 +289,8 @@ pipe = StableDiffusionPipeline.from_pretrained( ...@@ -288,6 +289,8 @@ pipe = StableDiffusionPipeline.from_pretrained(
# use jitted unet # use jitted unet
unet_traced = torch.jit.load("unet_traced.pt") unet_traced = torch.jit.load("unet_traced.pt")
# del pipe.unet # del pipe.unet
class TracedUNet(torch.nn.Module): class TracedUNet(torch.nn.Module):
def __init__(self): def __init__(self):
......
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
from einops import rearrange, reduce
from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
from einops import rearrange, reduce
BITS = 8 BITS = 8
......
...@@ -10,10 +10,11 @@ from diffusers.utils import is_safetensors_available ...@@ -10,10 +10,11 @@ from diffusers.utils import is_safetensors_available
if is_safetensors_available(): if is_safetensors_available():
import safetensors.torch import safetensors.torch
from huggingface_hub import snapshot_download
from diffusers import DiffusionPipeline, __version__ from diffusers import DiffusionPipeline, __version__
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
from huggingface_hub import snapshot_download
class CheckpointMergerPipeline(DiffusionPipeline): class CheckpointMergerPipeline(DiffusionPipeline):
......
...@@ -4,6 +4,8 @@ from typing import List, Optional, Union ...@@ -4,6 +4,8 @@ from typing import List, Optional, Union
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
...@@ -14,8 +16,6 @@ from diffusers import ( ...@@ -14,8 +16,6 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
class MakeCutouts(nn.Module): class MakeCutouts(nn.Module):
......
...@@ -16,6 +16,8 @@ import inspect ...@@ -16,6 +16,8 @@ import inspect
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict from diffusers.configuration_utils import FrozenDict
...@@ -29,8 +31,6 @@ from diffusers.schedulers import ( ...@@ -29,8 +31,6 @@ from diffusers.schedulers import (
PNDMScheduler, PNDMScheduler,
) )
from diffusers.utils import is_accelerate_available from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...utils import deprecate, logging from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
......
...@@ -7,11 +7,16 @@ import warnings ...@@ -7,11 +7,16 @@ import warnings
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
import PIL
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import PIL
from accelerate import Accelerator from accelerate import Accelerator
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
...@@ -19,11 +24,6 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS ...@@ -19,11 +24,6 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging from diffusers.utils import deprecate, logging
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = { PIL_INTERPOLATION = {
......
...@@ -2,9 +2,10 @@ import inspect ...@@ -2,9 +2,10 @@ import inspect
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import PIL
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
...@@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput ...@@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging from diffusers.utils import deprecate, logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Union ...@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict from diffusers.configuration_utils import FrozenDict
...@@ -13,7 +14,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput ...@@ -13,7 +14,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging from diffusers.utils import deprecate, logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -3,16 +3,16 @@ import re ...@@ -3,16 +3,16 @@ import re
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import PIL
import torch import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import diffusers import diffusers
import PIL
from diffusers import SchedulerMixin, StableDiffusionPipeline from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import deprecate, logging from diffusers.utils import deprecate, logging
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
try: try:
......
...@@ -3,15 +3,15 @@ import re ...@@ -3,15 +3,15 @@ import re
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import PIL
import torch import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
import diffusers import diffusers
import PIL
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import deprecate, logging from diffusers.utils import deprecate, logging
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
try: try:
......
from typing import Union from typing import Union
import torch import torch
from PIL import Image
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
...@@ -10,10 +14,6 @@ from diffusers import ( ...@@ -10,10 +14,6 @@ from diffusers import (
PNDMScheduler, PNDMScheduler,
UNet2DConditionModel, UNet2DConditionModel,
) )
from PIL import Image
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
class MagicMixPipeline(DiffusionPipeline): class MagicMixPipeline(DiffusionPipeline):
......
...@@ -2,14 +2,6 @@ import inspect ...@@ -2,14 +2,6 @@ import inspect
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging
from transformers import ( from transformers import (
CLIPFeatureExtractor, CLIPFeatureExtractor,
CLIPTextModel, CLIPTextModel,
...@@ -19,6 +11,14 @@ from transformers import ( ...@@ -19,6 +11,14 @@ from transformers import (
pipeline, pipeline,
) )
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -17,11 +17,11 @@ import warnings ...@@ -17,11 +17,11 @@ import warnings
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from diffusers import DiffusionPipeline, LMSDiscreteScheduler from diffusers import DiffusionPipeline, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import is_accelerate_available, logging from diffusers.utils import is_accelerate_available, logging
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -5,6 +5,7 @@ import inspect ...@@ -5,6 +5,7 @@ import inspect
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
...@@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput ...@@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging from diffusers.utils import logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -2,6 +2,13 @@ import inspect ...@@ -2,6 +2,13 @@ import inspect
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
WhisperForConditionalGeneration,
WhisperProcessor,
)
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
...@@ -14,13 +21,6 @@ from diffusers import ( ...@@ -14,13 +21,6 @@ from diffusers import (
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import logging from diffusers.utils import logging
from transformers import (
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
WhisperForConditionalGeneration,
WhisperProcessor,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment