Unverified Commit 0e62c34f authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix code format (#7377)

parent caf12f84
...@@ -27,6 +27,7 @@ def smoke_test_torchvision_read_decode() -> None: ...@@ -27,6 +27,7 @@ def smoke_test_torchvision_read_decode() -> None:
if img_png.ndim != 3 or img_png.numel() < 100: if img_png.ndim != 3 or img_png.numel() < 100:
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
def smoke_test_compile() -> None: def smoke_test_compile() -> None:
model = resnet50().cuda() model = resnet50().cuda()
model = torch.compile(model) model = torch.compile(model)
...@@ -34,6 +35,7 @@ def smoke_test_compile() -> None: ...@@ -34,6 +35,7 @@ def smoke_test_compile() -> None:
out = model(x) out = model(x)
print(f"torch.compile model output: {out.shape}") print(f"torch.compile model output: {out.shape}")
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device) img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
...@@ -73,6 +75,5 @@ def main() -> None: ...@@ -73,6 +75,5 @@ def main() -> None:
smoke_test_torchvision_resnet50_classify("mps") smoke_test_torchvision_resnet50_classify("mps")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment