Unverified Commit 522f8aa7 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Black] Update black library (#2007)

parent 8a3f0c1f
...@@ -336,7 +336,7 @@ class DreamBoothDataset(Dataset): ...@@ -336,7 +336,7 @@ class DreamBoothDataset(Dataset):
self.instance_data_root = Path(instance_data_root) self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists(): if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.") raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")
self.instance_images_path = list(Path(instance_data_root).iterdir()) self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path) self.num_instance_images = len(self.instance_images_path)
......
...@@ -336,7 +336,10 @@ class TextualInversionDataset(Dataset): ...@@ -336,7 +336,10 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
(h, w,) = ( (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )
......
...@@ -381,7 +381,10 @@ class TextualInversionDataset(Dataset): ...@@ -381,7 +381,10 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
(h, w,) = ( (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )
......
...@@ -306,7 +306,10 @@ class TextualInversionDataset(Dataset): ...@@ -306,7 +306,10 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
(h, w,) = ( (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )
......
...@@ -80,7 +80,7 @@ from setuptools import find_packages, setup ...@@ -80,7 +80,7 @@ from setuptools import find_packages, setup
_deps = [ _deps = [
"Pillow", # keep the PIL.Image.Resampling deprecation away "Pillow", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.11.0", "accelerate>=0.11.0",
"black==22.8", "black==22.12",
"datasets", "datasets",
"filelock", "filelock",
"flake8>=3.8.3", "flake8>=3.8.3",
......
...@@ -90,8 +90,10 @@ class AttentionBlock(nn.Module): ...@@ -90,8 +90,10 @@ class AttentionBlock(nn.Module):
if use_memory_efficient_attention_xformers: if use_memory_efficient_attention_xformers:
if not is_xformers_available(): if not is_xformers_available():
raise ModuleNotFoundError( raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install" (
" xformers", "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers", name="xformers",
) )
elif not torch.cuda.is_available(): elif not torch.cuda.is_available():
......
...@@ -105,8 +105,10 @@ class CrossAttention(nn.Module): ...@@ -105,8 +105,10 @@ class CrossAttention(nn.Module):
) )
elif not is_xformers_available(): elif not is_xformers_available():
raise ModuleNotFoundError( raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install" (
" xformers", "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers", name="xformers",
) )
elif not torch.cuda.is_available(): elif not torch.cuda.is_available():
......
...@@ -189,9 +189,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -189,9 +189,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
or isinstance(timestep, torch.LongTensor) or isinstance(timestep, torch.LongTensor)
): ):
raise ValueError( raise ValueError(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" (
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" one of the `scheduler.timesteps` as a timestep.", " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
) )
if not self.is_scale_input_called: if not self.is_scale_input_called:
......
...@@ -198,9 +198,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -198,9 +198,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
or isinstance(timestep, torch.LongTensor) or isinstance(timestep, torch.LongTensor)
): ):
raise ValueError( raise ValueError(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" (
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" one of the `scheduler.timesteps` as a timestep.", " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
) )
if not self.is_scale_input_called: if not self.is_scale_input_called:
......
...@@ -537,8 +537,10 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -537,8 +537,10 @@ class SchedulerCommonTest(unittest.TestCase):
) )
self.assertTrue( self.assertTrue(
hasattr(scheduler, "scale_model_input"), hasattr(scheduler, "scale_model_input"),
f"{scheduler_class} does not implement a required class method `scale_model_input(sample," (
" timestep)`", f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
" timestep)`"
),
) )
self.assertTrue( self.assertTrue(
hasattr(scheduler, "step"), hasattr(scheduler, "step"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment