Unverified Commit 2afea72d authored by Sai-Suraj-27, committed by GitHub

refactor: Refactored code by Merging `isinstance` calls (#7710)



* Merged isinstance calls to make the code simpler.

* Corrected formatting errors using ruff.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 0f111ab7
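Every hunk below applies the same pattern: Python's `isinstance` accepts a tuple of types, so a chain of calls joined with `or` collapses into a single call with identical behavior. A minimal sketch of the before/after forms, assuming NumPy is available; the `image` value and the prints are illustrative and not taken from the diff:

```python
# Minimal sketch of the pattern this commit applies (not taken from the diff itself).
# isinstance accepts a tuple of types, so a chain of or-ed calls collapses into one call.
import numpy as np

image = np.zeros((8, 8))  # illustrative value, assuming NumPy is installed

# Before: repeated isinstance calls joined with `or`
if isinstance(image, list) or isinstance(image, np.ndarray):
    print("chained check matched")

# After: a single call with a tuple of candidate types
if isinstance(image, (list, np.ndarray)):
    print("merged check matched")
```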
@@ -460,7 +460,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
             )
         # verify batch size of prompt and image are same if image is a list or tensor or numpy array
-        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
+        if isinstance(image, (list, np.ndarray, torch.Tensor)):
             if prompt is not None and isinstance(prompt, str):
                 batch_size = 1
             elif prompt is not None and isinstance(prompt, list):
...
@@ -685,7 +685,7 @@ class UNet2DConditionModel(
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
-            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+            elif isinstance(cross_attention_dim, (list, tuple)):
                 positive_len = cross_attention_dim[0]
             feature_type = "text-only" if attention_type == "gated" else "text-image"
...
@@ -817,7 +817,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
-            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+            elif isinstance(cross_attention_dim, (list, tuple)):
                 positive_len = cross_attention_dim[0]
             feature_type = "text-only" if attention_type == "gated" else "text-image"
...
@@ -197,7 +197,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
             )
         # verify batch size of prompt and image are same if image is a list or tensor or numpy array
-        if isinstance(image, list) or isinstance(image, np.ndarray):
+        if isinstance(image, (list, np.ndarray)):
             if prompt is not None and isinstance(prompt, str):
                 batch_size = 1
             elif prompt is not None and isinstance(prompt, list):
...
@@ -221,7 +221,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             )
         # verify batch size of prompt and image are same if image is a list or tensor
-        if isinstance(image, list) or isinstance(image, torch.Tensor):
+        if isinstance(image, (list, torch.Tensor)):
             if isinstance(prompt, str):
                 batch_size = 1
             else:
...
@@ -468,7 +468,7 @@ class StableDiffusionUpscalePipeline(
             )
         # verify batch size of prompt and image are same if image is a list or tensor or numpy array
-        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
+        if isinstance(image, (list, np.ndarray, torch.Tensor)):
             if prompt is not None and isinstance(prompt, str):
                 batch_size = 1
             elif prompt is not None and isinstance(prompt, list):
...
@@ -185,7 +185,7 @@ def preprocess(image):
 def preprocess_mask(mask, batch_size: int = 1):
     if not isinstance(mask, torch.Tensor):
         # preprocess mask
-        if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray):
+        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
             mask = [mask]
         if isinstance(mask, list):
...
@@ -347,11 +347,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
                 otherwise a tuple is returned where the first element is the sample tensor.
         """
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
...
@@ -310,11 +310,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
                 returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
...
@@ -375,11 +375,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
...
@@ -530,11 +530,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
...
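In the scheduler hunks above, the multi-line `if (...)` block and the one-line tuple form are behaviorally identical: `isinstance` simply tries each member of the tuple in turn. A small equivalence check, assuming `torch` is installed; the candidate values are illustrative and not part of the commit:

```python
# Quick equivalence check for the timestep guard (illustrative, not part of the commit).
import torch

candidates = [
    3,                                   # plain Python int
    torch.tensor(3),                     # CPU int64 tensor (a torch.LongTensor on default setups)
    torch.tensor(3, dtype=torch.int32),  # CPU int32 tensor (a torch.IntTensor)
    torch.tensor(0.5),                   # float tensor: should not match either form
]

for timestep in candidates:
    chained = (
        isinstance(timestep, int)
        or isinstance(timestep, torch.IntTensor)
        or isinstance(timestep, torch.LongTensor)
    )
    merged = isinstance(timestep, (int, torch.IntTensor, torch.LongTensor))
    assert chained == merged  # the refactor does not change behavior
    print(type(timestep).__name__, merged)
```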