"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "16a32c9dab03d41204120b63cdae71c40b279bdf"
Unverified Commit 56e772ab authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Use model_info.id instead of model_info.modelId (#8912)

Mention model_info.id instead of model_info.modelId
parent fe794894
...@@ -103,12 +103,12 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([ ...@@ -103,12 +103,12 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([
models = api.list_models(filter="diffusers") models = api.list_models(filter="diffusers")
for mod in models: for mod in models:
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256": if "google" in mod.author or mod.id == "CompVis/ldm-celebahq-256":
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1] local_checkpoint = "/home/patrick/google_checkpoints/" + mod.id.split("/")[-1]
print(f"Started running {mod.modelId}!!!") print(f"Started running {mod.id}!!!")
if mod.modelId.startswith("CompVis"): if mod.id.startswith("CompVis"):
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet") model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
else: else:
model = UNet2DModel.from_pretrained(local_checkpoint) model = UNet2DModel.from_pretrained(local_checkpoint)
...@@ -122,6 +122,6 @@ for mod in models: ...@@ -122,6 +122,6 @@ for mod in models:
logits = model(noise, time_step).sample logits = model(noise, time_step).sample
assert torch.allclose( assert torch.allclose(
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3 logits[0, 0, 0, :30], results["_".join("_".join(mod.id.split("/")).split("-"))], atol=1e-3
) )
print(f"{mod.modelId} has passed successfully!!!") print(f"{mod.id} has passed successfully!!!")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment