Unverified Commit 1d7b4b60 authored by Pedro Cuenca, committed by GitHub

Ruff: apply same rules as in transformers (#2827)

* Apply same ruff settings as in transformers

See https://github.com/huggingface/transformers/blob/main/pyproject.toml

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>

* Apply new style rules

* Style
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>

* style

* remove list, ruff wouldn't auto fix.

---------
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
parent abb22b4e
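Most of the changes below are mechanical rewrites that Ruff's newly enabled `C` (flake8-comprehensions) rules report and, in most cases, auto-fix. A minimal sketch of the three recurring families, using placeholder values rather than code from this repository:

# Hypothetical illustration of the rewrite families applied throughout this commit.

# dict() call with keyword arguments -> dict literal (C408)
config = dict(n_samples=64, horizon=32)            # before
config = {"n_samples": 64, "horizon": 32}          # after

# map(lambda ...) over a fixed iterable -> generator expression (C417)
w, h = map(lambda x: x - x % 32, (513, 769))       # before
w, h = (x - x % 32 for x in (513, 769))            # after

# comprehension that only copies -> plain constructor call (C416 and friends)
keys = [k for k in {"a": 1, "b": 2}.keys()]        # before
keys = list({"a": 1, "b": 2}.keys())               # after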
@@ -199,24 +199,20 @@ class CheckpointMergerPipeline(DiffusionPipeline):
             if not attr.startswith("_"):
                 checkpoint_path_1 = os.path.join(cached_folders[1], attr)
                 if os.path.exists(checkpoint_path_1):
-                    files = list(
-                        (
-                            *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
-                            *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
-                        )
-                    )
+                    files = [
+                        *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
+                        *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
+                    ]
                     checkpoint_path_1 = files[0] if len(files) > 0 else None
                 if len(cached_folders) < 3:
                     checkpoint_path_2 = None
                 else:
                     checkpoint_path_2 = os.path.join(cached_folders[2], attr)
                     if os.path.exists(checkpoint_path_2):
-                        files = list(
-                            (
-                                *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
-                                *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
-                            )
-                        )
+                        files = [
+                            *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
+                            *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
+                        ]
                         checkpoint_path_2 = files[0] if len(files) > 0 else None
                 # For an attr if both checkpoint_path_1 and 2 are None, ignore.
                 # If atleast one is present, deal with it according to interp method, of course only if the state_dict keys match.
...
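The `list((...))` removal in the hunk above is the manual change called out in the commit message ("remove list, ruff wouldn't auto fix"): Ruff flags the redundant wrapper, but unpacking two glob results directly into a list literal had to be rewritten by hand. A small self-contained sketch of the equivalence, using a placeholder directory rather than a path from the repository:

import glob
import os

checkpoint_path = "."  # placeholder directory for illustration only

# Before: a tuple built by unpacking, then wrapped in list(...)
files_before = list(
    (
        *glob.glob(os.path.join(checkpoint_path, "*.safetensors")),
        *glob.glob(os.path.join(checkpoint_path, "*.bin")),
    )
)

# After: unpack straight into a list literal; the resulting list is identical
files_after = [
    *glob.glob(os.path.join(checkpoint_path, "*.safetensors")),
    *glob.glob(os.path.join(checkpoint_path, "*.bin")),
]

assert files_before == files_after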
@@ -48,7 +48,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def preprocess(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
...
@@ -376,7 +376,7 @@ def get_weighted_text_embeddings(
 def preprocess_image(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -387,7 +387,7 @@ def preprocess_image(image):
 def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
...
@@ -403,7 +403,7 @@ def get_weighted_text_embeddings(
 def preprocess_image(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -413,7 +413,7 @@ def preprocess_image(image):
 def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
...
@@ -46,7 +46,7 @@ class StableUnCLIPPipeline(DiffusionPipeline):
     ):
         super().__init__()
-        decoder_pipe_kwargs = dict(image_encoder=None) if decoder_pipe_kwargs is None else decoder_pipe_kwargs
+        decoder_pipe_kwargs = {"image_encoder": None} if decoder_pipe_kwargs is None else decoder_pipe_kwargs
         decoder_pipe_kwargs["torch_dtype"] = decoder_pipe_kwargs.get("torch_dtype", None) or prior.dtype
...
@@ -673,7 +673,7 @@ def main():
             examples["edited_pixel_values"] = edited_images
             # Preprocess the captions.
-            captions = [caption for caption in examples[edit_prompt_column]]
+            captions = list(examples[edit_prompt_column])
             examples["input_ids"] = tokenize_captions(captions)
             return examples
...
@@ -4,17 +4,17 @@ import tqdm
 from diffusers.experimental import ValueGuidedRLPipeline
-config = dict(
-    n_samples=64,
-    horizon=32,
-    num_inference_steps=20,
-    n_guide_steps=2,  # can set to 0 for faster sampling, does not use value network
-    scale_grad_by_std=True,
-    scale=0.1,
-    eta=0.0,
-    t_grad_cutoff=2,
-    device="cpu",
-)
+config = {
+    "n_samples": 64,
+    "horizon": 32,
+    "num_inference_steps": 20,
+    "n_guide_steps": 2,  # can set to 0 for faster sampling, does not use value network
+    "scale_grad_by_std": True,
+    "scale": 0.1,
+    "eta": 0.0,
+    "t_grad_cutoff": 2,
+    "device": "cpu",
+}
 if __name__ == "__main__":
...
@@ -4,8 +4,8 @@ target-version = ['py37']
 [tool.ruff]
 # Never enforce `E501` (line length violations).
-ignore = ["E501", "E741", "W605"]
-select = ["E", "F", "I", "W"]
+ignore = ["C901", "E501", "E741", "W605"]
+select = ["C", "E", "F", "I", "W"]
 line-length = 119
 # Ignore import violations in all `__init__.py` files.
...
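For context, selecting the `C` prefix in Ruff enables both the flake8-comprehensions rules (C4xx, responsible for the rewrites in this diff) and the mccabe complexity check (C901, "function is too complex"). C901 is added to `ignore` because it has no auto-fix; it reports functions whose cyclomatic complexity exceeds a threshold rather than anything a style pass can rewrite. A rough sketch (not part of the repository's tooling) of applying the configuration locally; it assumes Ruff is installed and picks up `[tool.ruff]` from `pyproject.toml` in the working directory:

# Hypothetical helper: run Ruff with auto-fixes over the current directory.
import subprocess

# Newer Ruff releases use the `check` subcommand; older ones also accept `ruff --fix .` directly.
subprocess.run(["ruff", "check", ".", "--fix"], check=False)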
@@ -404,7 +404,7 @@ if __name__ == "__main__":
         config = json.loads(f.read())
     # unet case
-    key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
+    key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()}
     if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
         converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
     else:
...
@@ -24,29 +24,29 @@ def unet(hor):
     up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
     model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
     state_dict = model.state_dict()
-    config = dict(
-        down_block_types=down_block_types,
-        block_out_channels=block_out_channels,
-        up_block_types=up_block_types,
-        layers_per_block=1,
-        use_timestep_embedding=True,
-        out_block_type="OutConv1DBlock",
-        norm_num_groups=8,
-        downsample_each_block=False,
-        in_channels=14,
-        out_channels=14,
-        extra_in_channels=0,
-        time_embedding_type="positional",
-        flip_sin_to_cos=False,
-        freq_shift=1,
-        sample_size=65536,
-        mid_block_type="MidResTemporalBlock1D",
-        act_fn="mish",
-    )
+    config = {
+        "down_block_types": down_block_types,
+        "block_out_channels": block_out_channels,
+        "up_block_types": up_block_types,
+        "layers_per_block": 1,
+        "use_timestep_embedding": True,
+        "out_block_type": "OutConv1DBlock",
+        "norm_num_groups": 8,
+        "downsample_each_block": False,
+        "in_channels": 14,
+        "out_channels": 14,
+        "extra_in_channels": 0,
+        "time_embedding_type": "positional",
+        "flip_sin_to_cos": False,
+        "freq_shift": 1,
+        "sample_size": 65536,
+        "mid_block_type": "MidResTemporalBlock1D",
+        "act_fn": "mish",
+    }
     hf_value_function = UNet1DModel(**config)
     print(f"length of state dict: {len(state_dict.keys())}")
     print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
-    mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
+    mapping = dict(zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
     for k, v in mapping.items():
         state_dict[v] = state_dict.pop(k)
     hf_value_function.load_state_dict(state_dict)
@@ -57,25 +57,25 @@ def unet(hor):
 def value_function():
-    config = dict(
-        in_channels=14,
-        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
-        up_block_types=(),
-        out_block_type="ValueFunction",
-        mid_block_type="ValueFunctionMidBlock1D",
-        block_out_channels=(32, 64, 128, 256),
-        layers_per_block=1,
-        downsample_each_block=True,
-        sample_size=65536,
-        out_channels=14,
-        extra_in_channels=0,
-        time_embedding_type="positional",
-        use_timestep_embedding=True,
-        flip_sin_to_cos=False,
-        freq_shift=1,
-        norm_num_groups=8,
-        act_fn="mish",
-    )
+    config = {
+        "in_channels": 14,
+        "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
+        "up_block_types": (),
+        "out_block_type": "ValueFunction",
+        "mid_block_type": "ValueFunctionMidBlock1D",
+        "block_out_channels": (32, 64, 128, 256),
+        "layers_per_block": 1,
+        "downsample_each_block": True,
+        "sample_size": 65536,
+        "out_channels": 14,
+        "extra_in_channels": 0,
+        "time_embedding_type": "positional",
+        "use_timestep_embedding": True,
+        "flip_sin_to_cos": False,
+        "freq_shift": 1,
+        "norm_num_groups": 8,
+        "act_fn": "mish",
+    }
     model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
     state_dict = model
@@ -83,7 +83,7 @@ def value_function():
     print(f"length of state dict: {len(state_dict.keys())}")
     print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
-    mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys()))
+    mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys()))
     for k, v in mapping.items():
         state_dict[v] = state_dict.pop(k)
...
@@ -246,19 +246,19 @@ def create_unet_diffusers_config(original_config, image_size: int):
     )
     class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None
-    config = dict(
-        sample_size=image_size // vae_scale_factor,
-        in_channels=unet_params.in_channels,
-        out_channels=unet_params.out_channels,
-        down_block_types=tuple(down_block_types),
-        up_block_types=tuple(up_block_types),
-        block_out_channels=tuple(block_out_channels),
-        layers_per_block=unet_params.num_res_blocks,
-        cross_attention_dim=cross_attention_dim,
-        class_embed_type=class_embed_type,
-        projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
-        class_embeddings_concat=class_embeddings_concat,
-    )
+    config = {
+        "sample_size": image_size // vae_scale_factor,
+        "in_channels": unet_params.in_channels,
+        "out_channels": unet_params.out_channels,
+        "down_block_types": tuple(down_block_types),
+        "up_block_types": tuple(up_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "layers_per_block": unet_params.num_res_blocks,
+        "cross_attention_dim": cross_attention_dim,
+        "class_embed_type": class_embed_type,
+        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+        "class_embeddings_concat": class_embeddings_concat,
+    }
     return config
@@ -278,17 +278,17 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
     scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
-    config = dict(
-        sample_size=image_size,
-        in_channels=vae_params.in_channels,
-        out_channels=vae_params.out_ch,
-        down_block_types=tuple(down_block_types),
-        up_block_types=tuple(up_block_types),
-        block_out_channels=tuple(block_out_channels),
-        latent_channels=vae_params.z_channels,
-        layers_per_block=vae_params.num_res_blocks,
-        scaling_factor=float(scaling_factor),
-    )
+    config = {
+        "sample_size": image_size,
+        "in_channels": vae_params.in_channels,
+        "out_channels": vae_params.out_ch,
+        "down_block_types": tuple(down_block_types),
+        "up_block_types": tuple(up_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "latent_channels": vae_params.z_channels,
+        "layers_per_block": vae_params.num_res_blocks,
+        "scaling_factor": float(scaling_factor),
+    }
     return config
@@ -670,18 +670,18 @@ def create_transformers_vocoder_config(original_config):
     """
     vocoder_params = original_config.model.params.vocoder_config.params
-    config = dict(
-        model_in_dim=vocoder_params.num_mels,
-        sampling_rate=vocoder_params.sampling_rate,
-        upsample_initial_channel=vocoder_params.upsample_initial_channel,
-        upsample_rates=list(vocoder_params.upsample_rates),
-        upsample_kernel_sizes=list(vocoder_params.upsample_kernel_sizes),
-        resblock_kernel_sizes=list(vocoder_params.resblock_kernel_sizes),
-        resblock_dilation_sizes=[
-            list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
-        ],
-        normalize_before=False,
-    )
+    config = {
+        "model_in_dim": vocoder_params.num_mels,
+        "sampling_rate": vocoder_params.sampling_rate,
+        "upsample_initial_channel": vocoder_params.upsample_initial_channel,
+        "upsample_rates": list(vocoder_params.upsample_rates),
+        "upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
+        "resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
+        "resblock_dilation_sizes": [
+            list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
+        ],
+        "normalize_before": False,
+    }
     return config
...
@@ -280,17 +280,17 @@ def create_image_unet_diffusers_config(unet_params):
     if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
         raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.")
-    config = dict(
-        sample_size=None,
-        in_channels=unet_params.input_channels,
-        out_channels=unet_params.output_channels,
-        down_block_types=tuple(down_block_types),
-        up_block_types=tuple(up_block_types),
-        block_out_channels=tuple(block_out_channels),
-        layers_per_block=unet_params.num_noattn_blocks[0],
-        cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
-    )
+    config = {
+        "sample_size": None,
+        "in_channels": unet_params.input_channels,
+        "out_channels": unet_params.output_channels,
+        "down_block_types": tuple(down_block_types),
+        "up_block_types": tuple(up_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "layers_per_block": unet_params.num_noattn_blocks[0],
+        "cross_attention_dim": unet_params.context_dim,
+        "attention_head_dim": unet_params.num_heads,
+    }
     return config
@@ -319,17 +319,17 @@ def create_text_unet_diffusers_config(unet_params):
     if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
         raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.")
-    config = dict(
-        sample_size=None,
-        in_channels=(unet_params.input_channels, 1, 1),
-        out_channels=(unet_params.output_channels, 1, 1),
-        down_block_types=tuple(down_block_types),
-        up_block_types=tuple(up_block_types),
-        block_out_channels=tuple(block_out_channels),
-        layers_per_block=unet_params.num_noattn_blocks[0],
-        cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
-    )
+    config = {
+        "sample_size": None,
+        "in_channels": (unet_params.input_channels, 1, 1),
+        "out_channels": (unet_params.output_channels, 1, 1),
+        "down_block_types": tuple(down_block_types),
+        "up_block_types": tuple(up_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "layers_per_block": unet_params.num_noattn_blocks[0],
+        "cross_attention_dim": unet_params.context_dim,
+        "attention_head_dim": unet_params.num_heads,
+    }
     return config
@@ -343,16 +343,16 @@ def create_vae_diffusers_config(vae_params):
     down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
     up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-    config = dict(
-        sample_size=vae_params.resolution,
-        in_channels=vae_params.in_channels,
-        out_channels=vae_params.out_ch,
-        down_block_types=tuple(down_block_types),
-        up_block_types=tuple(up_block_types),
-        block_out_channels=tuple(block_out_channels),
-        latent_channels=vae_params.z_channels,
-        layers_per_block=vae_params.num_res_blocks,
-    )
+    config = {
+        "sample_size": vae_params.resolution,
+        "in_channels": vae_params.in_channels,
+        "out_channels": vae_params.out_ch,
+        "down_block_types": tuple(down_block_types),
+        "up_block_types": tuple(up_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "latent_channels": vae_params.z_channels,
+        "layers_per_block": vae_params.num_res_blocks,
+    }
     return config
...
@@ -420,7 +420,7 @@ class ConfigMixin:
     @classmethod
     def extract_init_dict(cls, config_dict, **kwargs):
         # 0. Copy origin config dict
-        original_dict = {k: v for k, v in config_dict.items()}
+        original_dict = dict(config_dict.items())
         # 1. Retrieve expected config attributes from __init__ signature
         expected_keys = cls._get_init_keys(cls)
@@ -610,7 +610,7 @@ def flax_register_to_config(cls):
         )
         # Ignore private kwargs in the init. Retrieve all passed attributes
-        init_kwargs = {k: v for k, v in kwargs.items()}
+        init_kwargs = dict(kwargs.items())
         # Retrieve default values
         fields = dataclasses.fields(self)
...
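Both hunks above replace an identity dict comprehension with `dict(...)` over `.items()`. The forms are interchangeable shallow copies, which is what these helpers need: an independent top-level mapping that can be mutated without touching the caller's dictionary. A quick illustration with placeholder data rather than a real config:

source = {"sample_size": 64, "in_channels": 4}  # placeholder config values

copy_a = {k: v for k, v in source.items()}  # original style
copy_b = dict(source.items())               # style after this commit
copy_c = dict(source)                       # equivalent, and shorter still

assert copy_a == copy_b == copy_c
copy_b.pop("in_channels")
assert "in_channels" in source  # the original mapping is untouched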
@@ -52,13 +52,13 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
         self.scheduler = scheduler
         self.env = env
         self.data = env.get_dataset()
-        self.means = dict()
+        self.means = {}
         for key in self.data.keys():
             try:
                 self.means[key] = self.data[key].mean()
             except:  # noqa: E722
                 pass
-        self.stds = dict()
+        self.stds = {}
         for key in self.data.keys():
             try:
                 self.stds[key] = self.data[key].std()
...
@@ -99,7 +99,7 @@ class VaeImageProcessor(ConfigMixin):
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
         """
         w, h = images.size
-        w, h = map(lambda x: x - x % self.vae_scale_factor, (w, h))  # resize to integer multiple of vae_scale_factor
+        w, h = (x - x % self.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
         images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
         return images
...
@@ -37,7 +37,7 @@ class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
         super().__init__()
         self.layers = torch.nn.ModuleList(state_dict.values())
-        self.mapping = {k: v for k, v in enumerate(state_dict.keys())}
+        self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
         # we add a hook to state_dict() and load_state_dict() so that the
...
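In the hunk above, `dict(enumerate(...))` builds the same index-to-key table as the comprehension it replaces, while `rev_mapping` remains the inverse key-to-index dictionary. A small sketch with hypothetical attention-processor names:

state_dict_keys = ["down.attn1.processor", "up.attn2.processor"]  # hypothetical keys

mapping = dict(enumerate(state_dict_keys))  # {0: "down.attn1.processor", 1: "up.attn2.processor"}
rev_mapping = {v: k for k, v in enumerate(state_dict_keys)}

assert mapping[0] == "down.attn1.processor"
assert rev_mapping["up.attn2.processor"] == 1
assert {v: k for k, v in mapping.items()} == rev_mapping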
@@ -647,7 +647,7 @@ class ModelMixin(torch.nn.Module):
     ):
         # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
-        loaded_keys = [k for k in state_dict.keys()]
+        loaded_keys = list(state_dict.keys())
         expected_keys = list(model_state_dict.keys())
...
@@ -74,7 +74,7 @@ def preprocess(image):
     if isinstance(image[0], PIL.Image.Image):
         w, h = image[0].size
-        w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
+        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
         image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
         image = np.concatenate(image, axis=0)
...
@@ -201,12 +201,12 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         images = images.cpu().permute(0, 2, 3, 1).numpy()
         images = (images * 255).round().astype("uint8")
         images = list(
-            map(lambda _: Image.fromarray(_[:, :, 0]), images)
+            (Image.fromarray(_[:, :, 0]) for _ in images)
             if images.shape[3] == 1
-            else map(lambda _: Image.fromarray(_, mode="RGB").convert("L"), images)
+            else (Image.fromarray(_, mode="RGB").convert("L") for _ in images)
         )
-        audios = list(map(lambda _: self.mel.image_to_audio(_), images))
+        audios = [self.mel.image_to_audio(_) for _ in images]
         if not return_dict:
             return images, (self.mel.get_sample_rate(), audios)
...
@@ -21,7 +21,7 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 def preprocess(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
...