Unverified Commit 1d7b4b60 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Ruff: apply same rules as in transformers (#2827)

* Apply same ruff settings as in transformers

See https://github.com/huggingface/transformers/blob/main/pyproject.toml

Co-authored-by: default avatarAaron Gokaslan <aaronGokaslan@gmail.com>

* Apply new style rules

* Style
Co-authored-by: default avatarAaron Gokaslan <aaronGokaslan@gmail.com>

* style

* remove list, ruff wouldn't auto fix.

---------
Co-authored-by: default avatarAaron Gokaslan <aaronGokaslan@gmail.com>
parent abb22b4e
...@@ -491,7 +491,7 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -491,7 +491,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
parameters = inspect.signature(obj.__init__).parameters parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - set(["self"]) expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters return expected_modules, optional_parameters
@property @property
......
...@@ -204,11 +204,11 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi ...@@ -204,11 +204,11 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}") non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
if variant is not None: if variant is not None:
variant_filenames = set(f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None) variant_filenames = {f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None}
else: else:
variant_filenames = set() variant_filenames = set()
non_variant_filenames = set(f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None) non_variant_filenames = {f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None}
usable_filenames = set(variant_filenames) usable_filenames = set(variant_filenames)
for f in non_variant_filenames: for f in non_variant_filenames:
...@@ -225,7 +225,7 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, ...@@ -225,7 +225,7 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
revision=None, revision=None,
) )
filenames = set(sibling.rfilename for sibling in info.siblings) filenames = {sibling.rfilename for sibling in info.siblings}
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
...@@ -1115,7 +1115,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1115,7 +1115,7 @@ class DiffusionPipeline(ConfigMixin):
# retrieve all folder_names that contain relevant files # retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
filenames = set(sibling.rfilename for sibling in info.siblings) filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
# if the whole pipeline is cached we don't have to ping the Hub # if the whole pipeline is cached we don't have to ping the Hub
...@@ -1126,7 +1126,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1126,7 +1126,7 @@ class DiffusionPipeline(ConfigMixin):
pretrained_model_name, use_auth_token, variant, revision, model_filenames pretrained_model_name, use_auth_token, variant, revision, model_filenames
) )
model_folder_names = set([os.path.split(f)[0] for f in model_filenames]) model_folder_names = {os.path.split(f)[0] for f in model_filenames}
# all filenames compatible with variant will be added # all filenames compatible with variant will be added
allow_patterns = list(model_filenames) allow_patterns = list(model_filenames)
...@@ -1157,8 +1157,8 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1157,8 +1157,8 @@ class DiffusionPipeline(ConfigMixin):
elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant): elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"] ignore_patterns = ["*.bin", "*.msgpack"]
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")]) safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")]) safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if ( if (
len(safetensors_variant_filenames) > 0 len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames and safetensors_model_filenames != safetensors_variant_filenames
...@@ -1169,8 +1169,8 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1169,8 +1169,8 @@ class DiffusionPipeline(ConfigMixin):
else: else:
ignore_patterns = ["*.safetensors", "*.msgpack"] ignore_patterns = ["*.safetensors", "*.msgpack"]
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")]) bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")]) bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warn( logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
...@@ -1215,7 +1215,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1215,7 +1215,7 @@ class DiffusionPipeline(ConfigMixin):
parameters = inspect.signature(obj.__init__).parameters parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - set(["self"]) expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters return expected_modules, optional_parameters
@property @property
......
...@@ -37,7 +37,7 @@ def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]): ...@@ -37,7 +37,7 @@ def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
...@@ -58,7 +58,7 @@ def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]): ...@@ -58,7 +58,7 @@ def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
if isinstance(mask[0], PIL.Image.Image): if isinstance(mask[0], PIL.Image.Image):
w, h = mask[0].size w, h = mask[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask] mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
mask = np.concatenate(mask, axis=0) mask = np.concatenate(mask, axis=0)
mask = mask.astype(np.float32) / 255.0 mask = mask.astype(np.float32) / 255.0
......
...@@ -166,7 +166,7 @@ class Codec: ...@@ -166,7 +166,7 @@ class Codec:
self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps) self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps)
self._event_ranges = [self._shift_range] + event_ranges self._event_ranges = [self._shift_range] + event_ranges
# Ensure all event types have unique names. # Ensure all event types have unique names.
assert len(self._event_ranges) == len(set([er.type for er in self._event_ranges])) assert len(self._event_ranges) == len({er.type for er in self._event_ranges})
@property @property
def num_classes(self) -> int: def num_classes(self) -> int:
......
...@@ -274,18 +274,18 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -274,18 +274,18 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
else: else:
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
config = dict( config = {
sample_size=image_size // vae_scale_factor, "sample_size": image_size // vae_scale_factor,
in_channels=unet_params.in_channels, "in_channels": unet_params.in_channels,
down_block_types=tuple(down_block_types), "down_block_types": tuple(down_block_types),
block_out_channels=tuple(block_out_channels), "block_out_channels": tuple(block_out_channels),
layers_per_block=unet_params.num_res_blocks, "layers_per_block": unet_params.num_res_blocks,
cross_attention_dim=unet_params.context_dim, "cross_attention_dim": unet_params.context_dim,
attention_head_dim=head_dim, "attention_head_dim": head_dim,
use_linear_projection=use_linear_projection, "use_linear_projection": use_linear_projection,
class_embed_type=class_embed_type, "class_embed_type": class_embed_type,
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
) }
if not controlnet: if not controlnet:
config["out_channels"] = unet_params.out_channels config["out_channels"] = unet_params.out_channels
...@@ -305,16 +305,16 @@ def create_vae_diffusers_config(original_config, image_size: int): ...@@ -305,16 +305,16 @@ def create_vae_diffusers_config(original_config, image_size: int):
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = dict( config = {
sample_size=image_size, "sample_size": image_size,
in_channels=vae_params.in_channels, "in_channels": vae_params.in_channels,
out_channels=vae_params.out_ch, "out_channels": vae_params.out_ch,
down_block_types=tuple(down_block_types), "down_block_types": tuple(down_block_types),
up_block_types=tuple(up_block_types), "up_block_types": tuple(up_block_types),
block_out_channels=tuple(block_out_channels), "block_out_channels": tuple(block_out_channels),
latent_channels=vae_params.z_channels, "latent_channels": vae_params.z_channels,
layers_per_block=vae_params.num_res_blocks, "layers_per_block": vae_params.num_res_blocks,
) }
return config return config
......
...@@ -44,7 +44,7 @@ def preprocess(image): ...@@ -44,7 +44,7 @@ def preprocess(image):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
......
...@@ -530,7 +530,7 @@ def unshard(x: jnp.ndarray): ...@@ -530,7 +530,7 @@ def unshard(x: jnp.ndarray):
def preprocess(image, dtype): def preprocess(image, dtype):
image = image.convert("RGB") image = image.convert("RGB")
w, h = image.size w, h = image.size
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = jnp.array(image).astype(dtype) / 255.0 image = jnp.array(image).astype(dtype) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
......
...@@ -520,7 +520,7 @@ def unshard(x: jnp.ndarray): ...@@ -520,7 +520,7 @@ def unshard(x: jnp.ndarray):
def preprocess(image, dtype): def preprocess(image, dtype):
w, h = image.size w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = jnp.array(image).astype(dtype) / 255.0 image = jnp.array(image).astype(dtype) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
......
...@@ -563,7 +563,7 @@ def unshard(x: jnp.ndarray): ...@@ -563,7 +563,7 @@ def unshard(x: jnp.ndarray):
def preprocess_image(image, dtype): def preprocess_image(image, dtype):
w, h = image.size w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = jnp.array(image).astype(dtype) / 255.0 image = jnp.array(image).astype(dtype) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
...@@ -572,7 +572,7 @@ def preprocess_image(image, dtype): ...@@ -572,7 +572,7 @@ def preprocess_image(image, dtype):
def preprocess_mask(mask, dtype): def preprocess_mask(mask, dtype):
w, h = mask.size w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
mask = mask.resize((w, h)) mask = mask.resize((w, h))
mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0
mask = jnp.expand_dims(mask, axis=(0, 1)) mask = jnp.expand_dims(mask, axis=(0, 1))
......
...@@ -40,7 +40,7 @@ def preprocess(image): ...@@ -40,7 +40,7 @@ def preprocess(image):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
......
...@@ -19,7 +19,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -19,7 +19,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image): def preprocess(image):
w, h = image.size w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
...@@ -29,7 +29,7 @@ def preprocess(image): ...@@ -29,7 +29,7 @@ def preprocess(image):
def preprocess_mask(mask, scale_factor=8): def preprocess_mask(mask, scale_factor=8):
mask = mask.convert("L") mask = mask.convert("L")
w, h = mask.size w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0 mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1)) mask = np.tile(mask, (4, 1, 1))
......
...@@ -31,7 +31,7 @@ def preprocess(image): ...@@ -31,7 +31,7 @@ def preprocess(image):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32 w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 32
image = [np.array(i.resize((w, h)))[None, :] for i in image] image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
......
...@@ -41,7 +41,7 @@ def preprocess(image): ...@@ -41,7 +41,7 @@ def preprocess(image):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
...@@ -442,7 +442,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -442,7 +442,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
image = [image] image = [image]
else: else:
image = [img for img in image] image = list(image)
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
width, height = image[0].size width, height = image[0].size
......
...@@ -78,7 +78,7 @@ def preprocess(image): ...@@ -78,7 +78,7 @@ def preprocess(image):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
......
...@@ -42,7 +42,7 @@ logger = logging.get_logger(__name__) ...@@ -42,7 +42,7 @@ logger = logging.get_logger(__name__)
def preprocess_image(image): def preprocess_image(image):
w, h = image.size w, h = image.size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
...@@ -54,7 +54,7 @@ def preprocess_mask(mask, scale_factor=8): ...@@ -54,7 +54,7 @@ def preprocess_mask(mask, scale_factor=8):
if not isinstance(mask, torch.FloatTensor): if not isinstance(mask, torch.FloatTensor):
mask = mask.convert("L") mask = mask.convert("L")
w, h = mask.size w, h = mask.size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0 mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1)) mask = np.tile(mask, (4, 1, 1))
...@@ -76,7 +76,7 @@ def preprocess_mask(mask, scale_factor=8): ...@@ -76,7 +76,7 @@ def preprocess_mask(mask, scale_factor=8):
# (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
mask = mask.mean(dim=1, keepdim=True) mask = mask.mean(dim=1, keepdim=True)
h, w = mask.shape[-2:] h, w = mask.shape[-2:]
h, w = map(lambda x: x - x % 8, (h, w)) # resize to integer multiple of 8 h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
return mask return mask
......
...@@ -47,7 +47,7 @@ def preprocess(image): ...@@ -47,7 +47,7 @@ def preprocess(image):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
......
...@@ -38,7 +38,7 @@ def preprocess(image): ...@@ -38,7 +38,7 @@ def preprocess(image):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image] image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
......
...@@ -180,7 +180,7 @@ def preprocess(image): ...@@ -180,7 +180,7 @@ def preprocess(image):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
......
...@@ -37,7 +37,7 @@ def preprocess(image): ...@@ -37,7 +37,7 @@ def preprocess(image):
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size w, h = image[0].size
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image] image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
......
...@@ -134,7 +134,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -134,7 +134,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
return embeds return embeds
if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4:
prompt = [p for p in prompt] prompt = list(prompt)
batch_size = len(prompt) if isinstance(prompt, list) else 1 batch_size = len(prompt) if isinstance(prompt, list) else 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment