"src/vscode:/vscode.git/clone" did not exist on "16d500455b7ba6c92e1b7ddfbb6e41edbb3eff5e"
Commit 97ef5e06 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent 31be4220
...@@ -336,10 +336,7 @@ class TextualInversionDataset(Dataset): ...@@ -336,10 +336,7 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
( (h, w,) = (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )
......
...@@ -432,10 +432,7 @@ class TextualInversionDataset(Dataset): ...@@ -432,10 +432,7 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
( (h, w,) = (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )
......
...@@ -306,10 +306,7 @@ class TextualInversionDataset(Dataset): ...@@ -306,10 +306,7 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
( (h, w,) = (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )
......
...@@ -94,10 +94,8 @@ class AttentionBlock(nn.Module): ...@@ -94,10 +94,8 @@ class AttentionBlock(nn.Module):
if use_memory_efficient_attention_xformers: if use_memory_efficient_attention_xformers:
if not is_xformers_available(): if not is_xformers_available():
raise ModuleNotFoundError( raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install" "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers" " xformers",
),
name="xformers", name="xformers",
) )
elif not torch.cuda.is_available(): elif not torch.cuda.is_available():
......
...@@ -111,10 +111,8 @@ class CrossAttention(nn.Module): ...@@ -111,10 +111,8 @@ class CrossAttention(nn.Module):
) )
elif not is_xformers_available(): elif not is_xformers_available():
raise ModuleNotFoundError( raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install" "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers" " xformers",
),
name="xformers", name="xformers",
) )
elif not torch.cuda.is_available(): elif not torch.cuda.is_available():
......
...@@ -189,11 +189,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -189,11 +189,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
or isinstance(timestep, torch.LongTensor) or isinstance(timestep, torch.LongTensor)
): ):
raise ValueError( raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep." " one of the `scheduler.timesteps` as a timestep.",
),
) )
if not self.is_scale_input_called: if not self.is_scale_input_called:
......
...@@ -198,11 +198,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -198,11 +198,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
or isinstance(timestep, torch.LongTensor) or isinstance(timestep, torch.LongTensor)
): ):
raise ValueError( raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep." " one of the `scheduler.timesteps` as a timestep.",
),
) )
if not self.is_scale_input_called: if not self.is_scale_input_called:
......
...@@ -537,10 +537,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -537,10 +537,8 @@ class SchedulerCommonTest(unittest.TestCase):
) )
self.assertTrue( self.assertTrue(
hasattr(scheduler, "scale_model_input"), hasattr(scheduler, "scale_model_input"),
(
f"{scheduler_class} does not implement a required class method `scale_model_input(sample," f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
" timestep)`" " timestep)`",
),
) )
self.assertTrue( self.assertTrue(
hasattr(scheduler, "step"), hasattr(scheduler, "step"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment