Unverified Commit 89793a97 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Style the `scripts` directory (#250)

Style scripts
parent 365f7523
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src export PYTHONPATH = src
check_dirs := examples tests src utils check_dirs := examples scripts src tests utils
modified_only_fixup: modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
......
...@@ -15,12 +15,15 @@ ...@@ -15,12 +15,15 @@
""" Conversion script for the LDM checkpoints. """ """ Conversion script for the LDM checkpoints. """
import argparse import argparse
import os
import json import json
import os
import torch import torch
from diffusers import UNet2DModel, UNet2DConditionModel
from diffusers import UNet2DConditionModel, UNet2DModel
from transformers.file_utils import has_file from transformers.file_utils import has_file
do_only_config = False do_only_config = False
do_only_weights = True do_only_weights = True
do_only_renaming = False do_only_renaming = False
...@@ -37,9 +40,7 @@ if __name__ == "__main__": ...@@ -37,9 +40,7 @@ if __name__ == "__main__":
help="The config json file corresponding to the architecture.", help="The config json file corresponding to the architecture.",
) )
parser.add_argument( parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
"--dump_path", default=None, type=str, required=True, help="Path to the output model."
)
args = parser.parse_args() args = parser.parse_args()
......
import argparse import argparse
import OmegaConf
import torch import torch
from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler import OmegaConf
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
def convert_ldm_original(checkpoint_path, config_path, output_path): def convert_ldm_original(checkpoint_path, config_path, output_path):
config = OmegaConf.load(config_path) config = OmegaConf.load(config_path)
...@@ -16,14 +17,14 @@ def convert_ldm_original(checkpoint_path, config_path, output_path): ...@@ -16,14 +17,14 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
for key in keys: for key in keys:
if key.startswith(first_stage_key): if key.startswith(first_stage_key):
first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key] first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
# extract state_dict for UNetLDM # extract state_dict for UNetLDM
unet_state_dict = {} unet_state_dict = {}
unet_key = "model.diffusion_model." unet_key = "model.diffusion_model."
for key in keys: for key in keys:
if key.startswith(unet_key): if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = state_dict[key] unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
vqvae_init_args = config.model.params.first_stage_config.params vqvae_init_args = config.model.params.first_stage_config.params
unet_init_args = config.model.params.unet_config.params unet_init_args = config.model.params.unet_config.params
...@@ -53,4 +54,3 @@ if __name__ == "__main__": ...@@ -53,4 +54,3 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path) convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
import argparse import argparse
import json import json
import torch import torch
from diffusers import UNet2DModel
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
def convert_ncsnpp_checkpoint(checkpoint, config): def convert_ncsnpp_checkpoint(checkpoint, config):
......
from huggingface_hub import HfApi
from transformers.file_utils import has_file
from diffusers import UNet2DModel
import random import random
import torch import torch
from diffusers import UNet2DModel
from huggingface_hub import HfApi
api = HfApi() api = HfApi()
results = {} results = {}
results["google_ddpm_cifar10_32"] = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467, # fmt: off
1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189, results["google_ddpm_cifar10_32"] = torch.tensor([
-1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839, -0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557]) 1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
results["google_ddpm_ema_bedroom_256"] = torch.tensor([-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436, -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208, 0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557
-2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948, ])
2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365]) results["google_ddpm_ema_bedroom_256"] = torch.tensor([
results["CompVis_ldm_celebahq_256"] = torch.tensor([-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869, -2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
-0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304, 1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
-0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925, -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943]) 2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365
results["google_ncsnpp_ffhq_1024"] = torch.tensor([ 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172, ])
-0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309, results["CompVis_ldm_celebahq_256"] = torch.tensor([
0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805, -0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
-0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505]) -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
results["google_ncsnpp_bedroom_256"] = torch.tensor([ 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133, -0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
-0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395, 0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943
0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559, ])
-0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386]) results["google_ncsnpp_ffhq_1024"] = torch.tensor([
results["google_ncsnpp_celebahq_256"] = torch.tensor([ 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078, 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
-0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330, -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683, 0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
-0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431]) -0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505
results["google_ncsnpp_church_256"] = torch.tensor([ 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042, ])
-0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398, results["google_ncsnpp_bedroom_256"] = torch.tensor([
0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574, 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
-0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390]) -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
results["google_ncsnpp_ffhq_256"] = torch.tensor([ 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042, 0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
-0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290, -0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386
0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746, ])
-0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473]) results["google_ncsnpp_celebahq_256"] = torch.tensor([
results["google_ddpm_cat_256"] = torch.tensor([-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330, 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243, -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
-2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810, 0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251]) -0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431
results["google_ddpm_celebahq_256"] = torch.tensor([-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324, ])
0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181, results["google_ncsnpp_church_256"] = torch.tensor([
-2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259, 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266]) -0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
results["google_ddpm_ema_celebahq_256"] = torch.tensor([-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212, 0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027, -0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390
-2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131, ])
1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355]) results["google_ncsnpp_ffhq_256"] = torch.tensor([
results["google_ddpm_church_256"] = torch.tensor([-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959, 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351, -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
-3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341, 0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066]) -0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473
results["google_ddpm_bedroom_256"] = torch.tensor([-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740, ])
1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398, results["google_ddpm_cat_256"] = torch.tensor([
-2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395, -1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243]) 1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
results["google_ddpm_ema_church_256"] = torch.tensor([-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336, -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908, 1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251])
-3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560, results["google_ddpm_celebahq_256"] = torch.tensor([
3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343]) -1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344, 0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391, -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
-2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439, 1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266
1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219]) ])
results["google_ddpm_ema_celebahq_256"] = torch.tensor([
-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
-2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355
])
results["google_ddpm_church_256"] = torch.tensor([
-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
-3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066
])
results["google_ddpm_bedroom_256"] = torch.tensor([
-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
-2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243
])
results["google_ddpm_ema_church_256"] = torch.tensor([
-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
-3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343
])
results["google_ddpm_ema_cat_256"] = torch.tensor([
-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
-2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219
])
# fmt: on
models = api.list_models(filter="diffusers") models = api.list_models(filter="diffusers")
for mod in models: for mod in models:
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256": if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1] local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
print(f"Started running {mod.modelId}!!!") print(f"Started running {mod.modelId}!!!")
if mod.modelId.startswith("CompVis"): if mod.modelId.startswith("CompVis"):
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet") model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
else: else:
model = UNet2DModel.from_pretrained(local_checkpoint) model = UNet2DModel.from_pretrained(local_checkpoint)
torch.manual_seed(0) torch.manual_seed(0)
random.seed(0) random.seed(0)
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0]) time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad(): with torch.no_grad():
logits = model(noise, time_step)['sample'] logits = model(noise, time_step)["sample"]
assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3) assert torch.allclose(
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
)
print(f"{mod.modelId} has passed succesfully!!!") print(f"{mod.modelId} has passed succesfully!!!")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment