Unverified Commit 6b04d61c authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub

[Styling] stylify using ruff (#5841)

* ruff format

* no need to use doc-builder's black styling as the doc is styled in ruff

* make fix-copies

* comment

* use run_ruff
parent 9c7f7fc4
-[tool.black]
-line-length = 119
-target-version = ['py37']
-
 [tool.ruff]
 # Never enforce `E501` (line length violations).
-ignore = ["C901", "E501", "E741", "W605"]
+ignore = ["C901", "E501", "E741", "F402", "F823"]
 select = ["C", "E", "F", "I", "W"]
 line-length = 119
@@ -16,3 +12,16 @@ line-length = 119
 [tool.ruff.isort]
 lines-after-imports = 2
 known-first-party = ["diffusers"]
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
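For context on the "use run_ruff" note in the commit message, a minimal sketch of how this configuration gets exercised locally (the helper name and the paths are assumptions, not shown in this diff; ruff>=0.1.5 ships both subcommands):

import subprocess

def run_ruff(*paths: str) -> None:
    # lint with the [tool.ruff] rules above, applying safe autofixes
    subprocess.run(["ruff", "check", "--fix", *paths], check=True)
    # reformat with the [tool.ruff.format] options above
    subprocess.run(["ruff", "format", *paths], check=True)

run_ruff("examples", "scripts", "src", "tests", "utils")  # illustrative paths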
@@ -11,7 +11,7 @@ from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
 from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
 
-"""
+r"""
 Example - From the diffusers root directory:
 
 Download weights:
...
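Why the docstring above gained an `r` prefix: the new `ignore` list in pyproject.toml no longer suppresses W605 (invalid escape sequence), so strings containing literal backslashes must be raw. A tiny illustration (the string itself is hypothetical, not from the diff):

path = "models\weights"   # W605: "\w" is an invalid escape sequence
path = r"models\weights"  # raw string keeps the backslash literally, no warning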
-[isort]
-default_section = FIRSTPARTY
-ensure_newline_before_comments = True
-force_grid_wrap = 0
-include_trailing_comma = True
-known_first_party = accelerate
-known_third_party =
-    numpy
-    torch
-    torch_xla
-line_length = 119
-lines_after_imports = 2
-multi_line_output = 3
-use_parentheses = True
-
-[flake8]
-ignore = E203, E722, E501, E741, W503, W605
-max-line-length = 119
-per-file-ignores = __init__.py:F401
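With setup.cfg deleted, import sorting moves from isort's own config to the [tool.ruff.isort] table in pyproject.toml. A sketch of the layout those settings enforce (module choice illustrative; `diffusers` is the configured first-party package):

import os
import sys

import numpy
import torch

from diffusers import __version__


# lines-after-imports = 2: exactly two blank lines between imports and code
print(os.getpid(), sys.version, numpy.__version__, torch.__version__, __version__)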
@@ -78,9 +78,9 @@ To create the package for PyPI.
 you need to go back to main before executing this.
 """
 
-import sys
 import os
 import re
+import sys
 
 from distutils.core import Command
 from setuptools import find_packages, setup
@@ -93,7 +93,6 @@ _deps = [
     "Pillow",  # keep the PIL.Image.Resampling deprecation away
     "accelerate>=0.11.0",
     "compel==0.1.8",
-    "black~=23.1",
     "datasets",
     "filelock",
     "flax>=0.4.1",
@@ -119,7 +118,7 @@ _deps = [
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.8.0",
-    "ruff==0.0.280",
+    "ruff>=0.1.5,<=0.2",
     "safetensors>=0.3.1",
     "sentencepiece>=0.1.91,!=0.1.92",
     "scipy",
@@ -171,7 +170,11 @@ class DepsTableUpdateCommand(Command):
     description = "build runtime dependency table"
     user_options = [
         # format: (long option, short option, description).
-        ("dep-table-update", None, "updates src/diffusers/dependency_versions_table.py"),
+        (
+            "dep-table-update",
+            None,
+            "updates src/diffusers/dependency_versions_table.py",
+        ),
     ]
 
     def initialize_options(self):
@@ -197,10 +200,8 @@ class DepsTableUpdateCommand(Command):
             f.write("\n".join(content))
 
 
 extras = {}
-extras["quality"] = deps_list("urllib3", "black", "isort", "ruff", "hf-doc-builder")
+extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder")
 extras["docs"] = deps_list("hf-doc-builder")
 extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
 extras["test"] = deps_list(
@@ -275,10 +276,7 @@ setup(
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "Programming Language :: Python :: 3",
     ]
-    + [
-        f"Programming Language :: Python :: 3.{i}"
-        for i in range(8, version_range_max)
-    ],
+    + [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)],
     cmdclass={"deps_table_update": DepsTableUpdateCommand},
 )
...
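The last two hunks above show the formatter's magic-trailing-comma rule from both sides: with skip-magic-trailing-comma = false, a trailing comma inside brackets pins the exploded layout (the `user_options` tuple), while brackets whose contents fit within `line-length` and carry no such comma are collapsed onto one line (the classifiers comprehension). A minimal illustration (names hypothetical):

pinned = [
    "stays",
    "multi-line",  # trailing comma after the last element keeps this expanded
]
collapsed = ["joined", "onto", "one", "line"]  # no magic comma and fits in 119 chars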
@@ -95,6 +95,7 @@ class ConfigMixin:
     should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
     subclass).
     """
+
     config_name = None
     ignore_for_config = []
     has_compatibles = False
...
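This one-line hunk is the first of many identical ones below: ruff's formatter inserts an empty line between a class docstring and the first class-level statement, which the previous black setup did not. On a toy class (name hypothetical):

class ToyConfig:
    """A stand-in for the classes touched below."""

    config_name = None  # ruff format inserted the blank line above this attribute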
@@ -5,7 +5,6 @@ deps = {
     "Pillow": "Pillow",
     "accelerate": "accelerate>=0.11.0",
     "compel": "compel==0.1.8",
-    "black": "black~=23.1",
     "datasets": "datasets",
     "filelock": "filelock",
     "flax": "flax>=0.4.1",
@@ -31,7 +30,7 @@ deps = {
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.8.0",
-    "ruff": "ruff==0.0.280",
+    "ruff": "ruff>=0.1.5,<=0.2",
     "safetensors": "safetensors>=0.3.1",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
     "scipy": "scipy",
...
@@ -71,6 +71,7 @@ class LoraLoaderMixin:
     Load LoRA layers into [`UNet2DConditionModel`] and
     [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
     """
+
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
     num_fused_loras = 0
...
@@ -110,7 +110,10 @@ def jax_memory_efficient_attention(
         )
 
     _, res = jax.lax.scan(
-        f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)  # start counter  # stop counter
+        f=chunk_scanner,
+        init=0,
+        xs=None,
+        length=math.ceil(num_q / query_chunk_size),  # start counter  # stop counter
     )
 
     return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
@@ -138,6 +141,7 @@ class FlaxAttention(nn.Module):
             Parameters `dtype`
     """
+
     query_dim: int
     heads: int = 8
     dim_head: int = 64
@@ -262,6 +266,7 @@ class FlaxBasicTransformerBlock(nn.Module):
             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """
+
     dim: int
     n_heads: int
     d_head: int
@@ -347,6 +352,7 @@ class FlaxTransformer2DModel(nn.Module):
             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """
+
     in_channels: int
     n_heads: int
     d_head: int
@@ -442,6 +448,7 @@ class FlaxFeedForward(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     dim: int
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
@@ -471,6 +478,7 @@ class FlaxGEGLU(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     dim: int
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
...
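The `jax.lax.scan` hunk is the one reflow in this file that isn't about docstrings; the likely driver is the pair of trailing comments, since the formatter keeps calls containing internal comments expanded, one argument per line, rather than re-joining them. The same effect on a self-contained toy (function hypothetical):

def window(start: int, stop: int) -> range:
    return range(start, stop)

w = window(
    start=0,
    stop=10,  # a trailing comment keeps the call expanded, one argument per line
)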
@@ -91,6 +91,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         `force_upcast` can be set to `False` (see this fp16-friendly
         [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
     """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
...
@@ -146,6 +146,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
         conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
             The tuple of output channel for each block in the `conditioning_embedding` layer.
     """
+
     sample_size: int = 32
     in_channels: int = 4
     down_block_types: Tuple[str, ...] = (
...
@@ -65,6 +65,7 @@ class FlaxTimestepEmbedding(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     time_embed_dim: int = 32
     dtype: jnp.dtype = jnp.float32
@@ -84,6 +85,7 @@ class FlaxTimesteps(nn.Module):
         dim (`int`, *optional*, defaults to `32`):
             Time step embedding dimension
     """
+
     dim: int = 32
     flip_sin_to_cos: bool = False
     freq_shift: float = 1
...
@@ -52,6 +52,7 @@ class FlaxModelMixin(PushToHubMixin):
         - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
     """
+
     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _flax_internal_args = ["name", "parent", "dtype"]
...
@@ -193,6 +193,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
     """
+
     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
...
@@ -45,6 +45,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -125,6 +126,7 @@ class FlaxDownBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -190,6 +192,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     prev_output_channel: int
@@ -275,6 +278,7 @@ class FlaxUpBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     prev_output_channel: int
@@ -339,6 +343,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
...
@@ -174,6 +174,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
     """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
...
@@ -214,6 +214,7 @@ class FlaxAttentionBlock(nn.Module):
             Parameters `dtype`
     """
+
     channels: int
     num_head_channels: int = None
     num_groups: int = 32
@@ -291,6 +292,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -347,6 +349,7 @@ class FlaxUpDecoderBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -401,6 +404,7 @@ class FlaxUNetMidBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
@@ -488,6 +492,7 @@ class FlaxEncoder(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int = 3
     out_channels: int = 3
     down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
@@ -600,6 +605,7 @@ class FlaxDecoder(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             parameters `dtype`
     """
+
     in_channels: int = 3
     out_channels: int = 3
     up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
@@ -767,6 +773,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
         dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
             The `dtype` of the parameters.
     """
+
     in_channels: int = 3
     out_channels: int = 3
     down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
...
@@ -243,10 +243,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         lora_scale: Optional[float] = None,
         **kwargs,
     ):
-        deprecation_message = (
-            "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
-            " instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
-        )
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
         deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
 
         prompt_embeds_tuple = self.encode_prompt(
@@ -462,10 +459,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
-        deprecation_message = (
-            "The decode_latents method is deprecated and will be removed in 1.0.0. Please use"
-            " VaeImageProcessor.postprocess(...) instead"
-        )
+        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
@@ -515,8 +509,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
-                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
-                f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
 
         if prompt is not None and prompt_embeds is not None:
@@ -747,15 +740,13 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             deprecate(
                 "callback",
                 "1.0.0",
-                "Passing `callback` as an input argument to `__call__` is deprecated, consider using"
-                " `callback_on_step_end`",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
             )
         if callback_steps is not None:
             deprecate(
                 "callback_steps",
                 "1.0.0",
-                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using"
-                " `callback_on_step_end`",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
             )
 
         # 0. Default height and width to unet
...
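The pattern in this file (and the img2img pipeline below) is the reverse of wrapping: implicitly concatenated message fragments are joined into single long literals, which is acceptable style here because E501 stays in the `ignore` list. Schematically (abridged strings):

# before: adjacent literals inside parentheses, implicitly concatenated
message = (
    "`_encode_prompt()` is deprecated and it will be removed in a future version."
    " Use `encode_prompt()` instead."
)
# after: one literal per message; E501 (line length) is ignored repo-wide
message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead."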
@@ -252,10 +252,7 @@ class AltDiffusionImg2ImgPipeline(
         lora_scale: Optional[float] = None,
         **kwargs,
     ):
-        deprecation_message = (
-            "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
-            " instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
-        )
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
         deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
 
         prompt_embeds_tuple = self.encode_prompt(
@@ -471,10 +468,7 @@ class AltDiffusionImg2ImgPipeline(
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
-        deprecation_message = (
-            "The decode_latents method is deprecated and will be removed in 1.0.0. Please use"
-            " VaeImageProcessor.postprocess(...) instead"
-        )
+        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
@@ -524,8 +518,7 @@ class AltDiffusionImg2ImgPipeline(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
-                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
-                f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
 
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
@@ -578,8 +571,8 @@ class AltDiffusionImg2ImgPipeline(
         else:
             if isinstance(generator, list) and len(generator) != batch_size:
                 raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective"
-                    f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                 )
 
         elif isinstance(generator, list):
@@ -798,15 +791,13 @@ class AltDiffusionImg2ImgPipeline(
             deprecate(
                 "callback",
                 "1.0.0",
-                "Passing `callback` as an input argument to `__call__` is deprecated, consider use"
-                " `callback_on_step_end`",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
         if callback_steps is not None:
             deprecate(
                 "callback_steps",
                 "1.0.0",
-                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use"
-                " `callback_on_step_end`",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
             )
 
         # 1. Check inputs. Raise error if not correct
...
@@ -99,6 +99,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
     """
+
     model_cpu_offload_seq = "text_encoder->unet->vae"
 
     def __init__(
...
@@ -72,6 +72,7 @@ class AudioLDMPipeline(DiffusionPipeline):
         vocoder ([`~transformers.SpeechT5HifiGan`]):
             Vocoder of class `SpeechT5HifiGan`.
     """
+
     model_cpu_offload_seq = "text_encoder->unet->vae"
 
     def __init__(
...