Unverified Commit 0e45022a authored by Andrey Talman's avatar Andrey Talman Committed by GitHub
Browse files

Add smoke test Using a simple RN50 with torch.compile (#7359)

parent 924d373c
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
import torch.nn as nn
import torchvision import torchvision
from torchvision.io import read_image from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights from torchvision.models import resnet50, ResNet50_Weights
...@@ -26,6 +27,12 @@ def smoke_test_torchvision_read_decode() -> None: ...@@ -26,6 +27,12 @@ def smoke_test_torchvision_read_decode() -> None:
if img_png.ndim != 3 or img_png.numel() < 100: if img_png.ndim != 3 or img_png.numel() < 100:
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
def smoke_test_compile() -> None:
model = resnet50().cuda()
model = torch.compile(model)
x = torch.randn(1, 3, 224, 224, device="cuda")
out = model(x)
print(f"torch.compile model output: {out.shape}")
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device) img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
...@@ -54,14 +61,18 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: ...@@ -54,14 +61,18 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
def main() -> None: def main() -> None:
print(f"torchvision: {torchvision.__version__}") print(f"torchvision: {torchvision.__version__}")
print(f"torch.cuda.is_available: {torch.cuda.is_available()}")
smoke_test_torchvision() smoke_test_torchvision()
smoke_test_torchvision_read_decode() smoke_test_torchvision_read_decode()
smoke_test_torchvision_resnet50_classify() smoke_test_torchvision_resnet50_classify()
if torch.cuda.is_available(): if torch.cuda.is_available():
smoke_test_torchvision_resnet50_classify("cuda") smoke_test_torchvision_resnet50_classify("cuda")
smoke_test_compile()
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
smoke_test_torchvision_resnet50_classify("mps") smoke_test_torchvision_resnet50_classify("mps")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment