Unverified Commit 6b04d61c authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Styling] stylify using ruff (#5841)

* ruff format

* not need to use doc-builder's black styling as the doc is styled in ruff

* make fix-copies

* comment

* use run_ruff
parent 9c7f7fc4
...@@ -27,9 +27,8 @@ jobs: ...@@ -27,9 +27,8 @@ jobs:
pip install .[quality] pip install .[quality]
- name: Check quality - name: Check quality
run: | run: |
black --check examples tests src utils scripts ruff check examples tests src utils scripts
ruff examples tests src utils scripts ruff format examples tests src utils scripts --check
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
check_repository_consistency: check_repository_consistency:
runs-on: ubuntu-latest runs-on: ubuntu-latest
......
...@@ -410,7 +410,7 @@ Diffusers has grown a lot. Here is the command for it: ...@@ -410,7 +410,7 @@ Diffusers has grown a lot. Here is the command for it:
$ make test $ make test
``` ```
🧨 Diffusers relies on `black` and `isort` to format its source code 🧨 Diffusers relies on `ruff` and `isort` to format its source code
consistently. After you make changes, apply automatic style corrections and code verifications consistently. After you make changes, apply automatic style corrections and code verifications
that can't be automated in one go with: that can't be automated in one go with:
......
...@@ -9,8 +9,8 @@ modified_only_fixup: ...@@ -9,8 +9,8 @@ modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@if test -n "$(modified_py_files)"; then \ @if test -n "$(modified_py_files)"; then \
echo "Checking/fixing $(modified_py_files)"; \ echo "Checking/fixing $(modified_py_files)"; \
black $(modified_py_files); \ ruff check $(modified_py_files) --fix; \
ruff $(modified_py_files); \ ruff format $(modified_py_files);\
else \ else \
echo "No library .py files were modified"; \ echo "No library .py files were modified"; \
fi fi
...@@ -40,23 +40,21 @@ repo-consistency: ...@@ -40,23 +40,21 @@ repo-consistency:
# this target runs checks on all files # this target runs checks on all files
quality: quality:
black --check $(check_dirs) ruff check $(check_dirs) setup.py
ruff $(check_dirs) ruff format --check $(check_dirs) setup.py
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
python utils/check_doc_toc.py python utils/check_doc_toc.py
# Format source code automatically and check is there are any problems left that need manual fixing # Format source code automatically and check is there are any problems left that need manual fixing
extra_style_checks: extra_style_checks:
python utils/custom_init_isort.py python utils/custom_init_isort.py
doc-builder style src/diffusers docs/source --max_len 119 --path_to_docs docs/source
python utils/check_doc_toc.py --fix_and_overwrite python utils/check_doc_toc.py --fix_and_overwrite
# this target runs checks on all files and potentially modifies some of them # this target runs checks on all files and potentially modifies some of them
style: style:
black $(check_dirs) ruff check $(check_dirs) setup.py --fix
ruff $(check_dirs) --fix ruff format $(check_dirs) setup.py
${MAKE} autogenerate_code ${MAKE} autogenerate_code
${MAKE} extra_style_checks ${MAKE} extra_style_checks
......
...@@ -65,6 +65,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): ...@@ -65,6 +65,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
......
...@@ -564,9 +564,7 @@ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin): ...@@ -564,9 +564,7 @@ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear": elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model. # this schedule is very specific to the latent diffusion model.
self.betas = ( self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule # Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps) self.betas = betas_for_alpha_bar(num_train_timesteps)
......
...@@ -469,9 +469,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): ...@@ -469,9 +469,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear": elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model. # this schedule is very specific to the latent diffusion model.
self.betas = ( self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule # Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps) self.betas = betas_for_alpha_bar(num_train_timesteps)
......
...@@ -56,10 +56,10 @@ def parse_prompt_attention(text): ...@@ -56,10 +56,10 @@ def parse_prompt_attention(text):
(abc) - increases attention to abc by a multiplier of 1.1 (abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12 (abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1 [abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '(' \\( - literal character '('
\[ - literal character '[' \\[ - literal character '['
\) - literal character ')' \\) - literal character ')'
\] - literal character ']' \\] - literal character ']'
\\ - literal character '\' \\ - literal character '\'
anything else - just text anything else - just text
>>> parse_prompt_attention('normal text') >>> parse_prompt_attention('normal text')
...@@ -68,7 +68,7 @@ def parse_prompt_attention(text): ...@@ -68,7 +68,7 @@ def parse_prompt_attention(text):
[['an ', 1.0], ['important', 1.1], [' word', 1.0]] [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced') >>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]] [['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]') >>> parse_prompt_attention('\\(literal\\]')
[['(literal]', 1.0]] [['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)') >>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]] [['unnecessaryparens', 1.1]]
......
...@@ -82,10 +82,10 @@ def parse_prompt_attention(text): ...@@ -82,10 +82,10 @@ def parse_prompt_attention(text):
(abc) - increases attention to abc by a multiplier of 1.1 (abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12 (abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1 [abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '(' \\( - literal character '('
\[ - literal character '[' \\[ - literal character '['
\) - literal character ')' \\) - literal character ')'
\] - literal character ']' \\] - literal character ']'
\\ - literal character '\' \\ - literal character '\'
anything else - just text anything else - just text
>>> parse_prompt_attention('normal text') >>> parse_prompt_attention('normal text')
...@@ -94,7 +94,7 @@ def parse_prompt_attention(text): ...@@ -94,7 +94,7 @@ def parse_prompt_attention(text):
[['an ', 1.0], ['important', 1.1], [' word', 1.0]] [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced') >>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]] [['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]') >>> parse_prompt_attention('\\(literal\\]')
[['(literal]', 1.0]] [['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)') >>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]] [['unnecessaryparens', 1.1]]
...@@ -433,6 +433,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline ...@@ -433,6 +433,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
""" """
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
def __init__( def __init__(
......
...@@ -46,10 +46,10 @@ def parse_prompt_attention(text): ...@@ -46,10 +46,10 @@ def parse_prompt_attention(text):
(abc) - increases attention to abc by a multiplier of 1.1 (abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12 (abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1 [abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '(' \\( - literal character '('
\[ - literal character '[' \\[ - literal character '['
\) - literal character ')' \\) - literal character ')'
\] - literal character ']' \\] - literal character ']'
\\ - literal character '\' \\ - literal character '\'
anything else - just text anything else - just text
...@@ -59,7 +59,7 @@ def parse_prompt_attention(text): ...@@ -59,7 +59,7 @@ def parse_prompt_attention(text):
[['an ', 1.0], ['important', 1.1], [' word', 1.0]] [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced') >>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]] [['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]') >>> parse_prompt_attention('\\(literal\\]')
[['(literal]', 1.0]] [['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)') >>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]] [['unnecessaryparens', 1.1]]
......
...@@ -127,9 +127,9 @@ class MagicMixPipeline(DiffusionPipeline): ...@@ -127,9 +127,9 @@ class MagicMixPipeline(DiffusionPipeline):
timesteps=t, timesteps=t,
) )
input = (mix_factor * latents) + ( input = (
1 - mix_factor (mix_factor * latents) + (1 - mix_factor) * orig_latents
) * orig_latents # interpolating between layout noise and conditionally generated noise to preserve layout sematics ) # interpolating between layout noise and conditionally generated noise to preserve layout sematics
input = torch.cat([input] * 2) input = torch.cat([input] * 2)
else: # content generation phase else: # content generation phase
......
...@@ -453,9 +453,7 @@ class StableDiffusionCanvasPipeline(DiffusionPipeline): ...@@ -453,9 +453,7 @@ class StableDiffusionCanvasPipeline(DiffusionPipeline):
:, :,
region.latent_row_init : region.latent_row_end, region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end, region.latent_col_init : region.latent_col_end,
] += ( ] += noise_pred_region * mask_weights_region
noise_pred_region * mask_weights_region
)
contributors[ contributors[
:, :,
:, :,
......
...@@ -65,6 +65,7 @@ class Prompt2PromptPipeline(StableDiffusionPipeline): ...@@ -65,6 +65,7 @@ class Prompt2PromptPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
@torch.no_grad() @torch.no_grad()
......
...@@ -94,6 +94,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline): ...@@ -94,6 +94,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline):
cc_projection ([`CCProjection`]): cc_projection ([`CCProjection`]):
Projection layer to project the concated CLIP features and pose embeddings to the original CLIP feature size. Projection layer to project the concated CLIP features and pose embeddings to the original CLIP feature size.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
...@@ -658,7 +659,8 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline): ...@@ -658,7 +659,8 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline):
if isinstance(generator, list): if isinstance(generator, list):
init_latents = [ init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.mode(generator[i]) for i in range(batch_size) # sample self.vae.encode(image[i : i + 1]).latent_dist.mode(generator[i])
for i in range(batch_size) # sample
] ]
init_latents = torch.cat(init_latents, dim=0) init_latents = torch.cat(init_latents, dim=0)
else: else:
......
...@@ -651,9 +651,10 @@ class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -651,9 +651,10 @@ class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
control_guidance_end = len(control_guidance_start) * [control_guidance_end] control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = num_controlnet mult = num_controlnet
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_start, control_guidance_end = (
control_guidance_end mult * [control_guidance_start],
] mult * [control_guidance_end],
)
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
......
...@@ -755,9 +755,10 @@ class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -755,9 +755,10 @@ class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
control_guidance_end = len(control_guidance_start) * [control_guidance_end] control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = num_controlnet mult = num_controlnet
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_start, control_guidance_end = (
control_guidance_end mult * [control_guidance_start],
] mult * [control_guidance_end],
)
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
......
...@@ -68,6 +68,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -68,6 +68,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
......
...@@ -89,6 +89,7 @@ class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin ...@@ -89,6 +89,7 @@ class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
......
...@@ -50,6 +50,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline): ...@@ -50,6 +50,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
......
...@@ -170,6 +170,7 @@ class StableDiffusionRepaintPipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -170,6 +170,7 @@ class StableDiffusionRepaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
feature_extractor ([`CLIPImageProcessor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
......
...@@ -464,9 +464,7 @@ def main(args): ...@@ -464,9 +464,7 @@ def main(args):
unet = gemini_zero_dpp(unet, args.placement) unet = gemini_zero_dpp(unet, args.placement)
# config optimizer for colossalai zero # config optimizer for colossalai zero
optimizer = GeminiAdamOptimizer( optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
)
# load noise_scheduler # load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment