"references/vscode:/vscode.git/clone" did not exist on "bc3f8f6c134e545edb06dd830bba0683797d1a66"
Unverified commit 2afea72d authored by Sai-Suraj-27, committed by GitHub

refactor: Refactored code by Merging `isinstance` calls (#7710)



* Merged isinstance calls to make the code simpler.

* Corrected formatting errors using ruff.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 0f111ab7
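Every hunk below applies the same pattern: a chain of `isinstance` checks joined with `or` is collapsed into a single `isinstance` call that receives a tuple of types, which Python treats as "instance of any of these types". A minimal, self-contained sketch of the equivalence (the variable names are illustrative and not taken from the diff):

    import numpy as np
    import torch

    image = np.zeros((8, 8))

    # Before: one isinstance call per candidate type, chained with `or`.
    chained = isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray)

    # After: a single call with a tuple of types; isinstance returns True
    # when the object is an instance of any type in the tuple.
    merged = isinstance(image, (list, np.ndarray, torch.Tensor))

    assert chained == merged  # both are True for the ndarray above

Behavior is unchanged; the tuple form is simply shorter.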
@@ -460,7 +460,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
             )
         # verify batch size of prompt and image are same if image is a list or tensor or numpy array
-        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
+        if isinstance(image, (list, np.ndarray, torch.Tensor)):
             if prompt is not None and isinstance(prompt, str):
                 batch_size = 1
             elif prompt is not None and isinstance(prompt, list):
@@ -685,7 +685,7 @@ class UNet2DConditionModel(
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
-            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+            elif isinstance(cross_attention_dim, (list, tuple)):
                 positive_len = cross_attention_dim[0]
             feature_type = "text-only" if attention_type == "gated" else "text-image"
@@ -817,7 +817,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
-            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+            elif isinstance(cross_attention_dim, (list, tuple)):
                 positive_len = cross_attention_dim[0]
             feature_type = "text-only" if attention_type == "gated" else "text-image"
@@ -197,7 +197,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
             )
         # verify batch size of prompt and image are same if image is a list or tensor or numpy array
-        if isinstance(image, list) or isinstance(image, np.ndarray):
+        if isinstance(image, (list, np.ndarray)):
             if prompt is not None and isinstance(prompt, str):
                 batch_size = 1
             elif prompt is not None and isinstance(prompt, list):
@@ -221,7 +221,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             )
         # verify batch size of prompt and image are same if image is a list or tensor
-        if isinstance(image, list) or isinstance(image, torch.Tensor):
+        if isinstance(image, (list, torch.Tensor)):
             if isinstance(prompt, str):
                 batch_size = 1
             else:
@@ -468,7 +468,7 @@ class StableDiffusionUpscalePipeline(
             )
         # verify batch size of prompt and image are same if image is a list or tensor or numpy array
-        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
+        if isinstance(image, (list, np.ndarray, torch.Tensor)):
             if prompt is not None and isinstance(prompt, str):
                 batch_size = 1
             elif prompt is not None and isinstance(prompt, list):
@@ -185,7 +185,7 @@ def preprocess(image):
 def preprocess_mask(mask, batch_size: int = 1):
     if not isinstance(mask, torch.Tensor):
         # preprocess mask
-        if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray):
+        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
             mask = [mask]
         if isinstance(mask, list):
@@ -347,11 +347,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
                 otherwise a tuple is returned where the first element is the sample tensor.
         """
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -310,11 +310,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
                 returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -375,11 +375,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -530,11 +530,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
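The scheduler hunks above all guard `step()` against being called with an integer timestep index rather than an actual timestep value. A small standalone sketch of how the merged check behaves (the helper name validate_timestep is hypothetical and the error text is abbreviated; this is not the library code):

    import torch

    def validate_timestep(timestep):
        # Integer indices (e.g. from `enumerate(scheduler.timesteps)`) are rejected;
        # a Python float or a float tensor passes through unchanged.
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices as timesteps is not supported; pass one of the"
                " values from `scheduler.timesteps` instead."
            )
        return timestep

    validate_timestep(torch.tensor(999.0))  # accepted: 0-d float tensor
    validate_timestep(999.0)                # accepted: Python float
    # validate_timestep(3)                  # would raise ValueError: integer index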