Unverified Commit a7ca03aa authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Replace flake8 with ruff and update black (#2279)

* before running make style

* remove left overs from flake8

* finish

* make fix-copies

* final fix

* more fixes
parent f5ccffec
import random import random
import torch import torch
from huggingface_hub import HfApi
from diffusers import UNet2DModel from diffusers import UNet2DModel
from huggingface_hub import HfApi
api = HfApi() api = HfApi()
......
...@@ -80,10 +80,9 @@ from setuptools import find_packages, setup ...@@ -80,10 +80,9 @@ from setuptools import find_packages, setup
_deps = [ _deps = [
"Pillow", # keep the PIL.Image.Resampling deprecation away "Pillow", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.11.0", "accelerate>=0.11.0",
"black==22.12", "black~=23.1",
"datasets", "datasets",
"filelock", "filelock",
"flake8>=3.8.3",
"flax>=0.4.1", "flax>=0.4.1",
"hf-doc-builder>=0.3.0", "hf-doc-builder>=0.3.0",
"huggingface-hub>=0.10.0", "huggingface-hub>=0.10.0",
...@@ -99,6 +98,7 @@ _deps = [ ...@@ -99,6 +98,7 @@ _deps = [
"pytest", "pytest",
"pytest-timeout", "pytest-timeout",
"pytest-xdist", "pytest-xdist",
"ruff>=0.0.241",
"safetensors", "safetensors",
"sentencepiece>=0.1.91,!=0.1.92", "sentencepiece>=0.1.91,!=0.1.92",
"scipy", "scipy",
...@@ -178,7 +178,7 @@ extras = {} ...@@ -178,7 +178,7 @@ extras = {}
extras = {} extras = {}
extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder") extras["quality"] = deps_list("black", "isort", "ruff", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder") extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2") extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2")
extras["test"] = deps_list( extras["test"] = deps_list(
......
...@@ -26,7 +26,6 @@ from pathlib import PosixPath ...@@ -26,7 +26,6 @@ from pathlib import PosixPath
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Tuple, Union
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError from requests import HTTPError
......
...@@ -4,10 +4,9 @@ ...@@ -4,10 +4,9 @@
deps = { deps = {
"Pillow": "Pillow", "Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0", "accelerate": "accelerate>=0.11.0",
"black": "black==22.12", "black": "black~=23.1",
"datasets": "datasets", "datasets": "datasets",
"filelock": "filelock", "filelock": "filelock",
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.4.1", "flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0", "hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.10.0", "huggingface-hub": "huggingface-hub>=0.10.0",
...@@ -23,6 +22,7 @@ deps = { ...@@ -23,6 +22,7 @@ deps = {
"pytest": "pytest", "pytest": "pytest",
"pytest-timeout": "pytest-timeout", "pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist", "pytest-xdist": "pytest-xdist",
"ruff": "ruff>=0.0.241",
"safetensors": "safetensors", "safetensors": "safetensors",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"scipy": "scipy", "scipy": "scipy",
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import numpy as np import numpy as np
import torch import torch
import tqdm import tqdm
from ...models.unet_1d import UNet1DModel from ...models.unet_1d import UNet1DModel
...@@ -57,13 +56,13 @@ class ValueGuidedRLPipeline(DiffusionPipeline): ...@@ -57,13 +56,13 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
for key in self.data.keys(): for key in self.data.keys():
try: try:
self.means[key] = self.data[key].mean() self.means[key] = self.data[key].mean()
except: except: # noqa: E722
pass pass
self.stds = dict() self.stds = dict()
for key in self.data.keys(): for key in self.data.keys():
try: try:
self.stds[key] = self.data[key].std() self.stds[key] = self.data[key].std()
except: except: # noqa: E722
pass pass
self.state_dim = env.observation_space.shape[0] self.state_dim = env.observation_space.shape[0]
self.action_dim = env.action_space.shape[0] self.action_dim = env.action_space.shape[0]
......
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
from pickle import UnpicklingError from pickle import UnpicklingError
import numpy as np
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
from flax.serialization import from_bytes from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict from flax.traverse_util import flatten_dict
......
...@@ -20,11 +20,10 @@ from functools import partial ...@@ -20,11 +20,10 @@ from functools import partial
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor, device
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError from requests import HTTPError
from torch import Tensor, device
from .. import __version__ from .. import __version__
from ..utils import ( from ..utils import (
...@@ -500,7 +499,7 @@ class ModelMixin(torch.nn.Module): ...@@ -500,7 +499,7 @@ class ModelMixin(torch.nn.Module):
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
) )
except: except: # noqa: E722
pass pass
if model_file is None: if model_file is None:
model_file = _get_model_file( model_file = _get_model_file(
......
...@@ -2,7 +2,6 @@ from dataclasses import dataclass ...@@ -2,7 +2,6 @@ from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
from PIL import Image from PIL import Image
......
...@@ -3,7 +3,6 @@ from typing import Optional, Tuple ...@@ -3,7 +3,6 @@ from typing import Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
from transformers.utils import ModelOutput from transformers.utils import ModelOutput
......
...@@ -16,11 +16,11 @@ import inspect ...@@ -16,11 +16,11 @@ import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from diffusers.utils import is_accelerate_available
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
......
...@@ -16,13 +16,13 @@ import inspect ...@@ -16,13 +16,13 @@ import inspect
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch
import PIL import PIL
from diffusers.utils import is_accelerate_available import torch
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
......
# flake8: noqa
from .mel import Mel from .mel import Mel
from .pipeline_audio_diffusion import AudioDiffusionPipeline from .pipeline_audio_diffusion import AudioDiffusionPipeline
...@@ -18,7 +18,6 @@ from typing import List, Tuple, Union ...@@ -18,7 +18,6 @@ from typing import List, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
......
# flake8: noqa
from .pipeline_dance_diffusion import DanceDiffusionPipeline from .pipeline_dance_diffusion import DanceDiffusionPipeline
# flake8: noqa
from .pipeline_ddim import DDIMPipeline from .pipeline_ddim import DDIMPipeline
# flake8: noqa
from .pipeline_ddpm import DDPMPipeline from .pipeline_ddpm import DDPMPipeline
# flake8: noqa
from ...utils import is_transformers_available from ...utils import is_transformers_available
from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
......
...@@ -18,7 +18,6 @@ from typing import List, Optional, Tuple, Union ...@@ -18,7 +18,6 @@ from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_outputs import BaseModelOutput
......
...@@ -2,11 +2,10 @@ import inspect ...@@ -2,11 +2,10 @@ import inspect
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
import PIL
from ...models import UNet2DModel, VQModel from ...models import UNet2DModel, VQModel
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
......
# flake8: noqa
from .pipeline_latent_diffusion_uncond import LDMPipeline from .pipeline_latent_diffusion_uncond import LDMPipeline
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment