Unverified Commit 23d3f78a authored by Andrey Talman's avatar Andrey Talman Committed by GitHub
Browse files

Add more advanced smoke test for project Nova and validation workflows (#7014)



* Add more advanced smoke test

* add torch import

* remove dependency on torch

* Add missing vars

* More code and ufmt
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent f93eb8ff
"""Run smoke tests""" """Run smoke tests"""
import os import os
from pathlib import Path
import torch
import torchvision import torchvision
from torchvision.io import read_image from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
image_path = os.path.join( SCRIPT_DIR = Path(__file__).parent
os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)
print("torchvision version is ", torchvision.__version__) def smoke_test_torchvision() -> None:
img = read_image(image_path) print(
"Is torchvision useable?",
all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
)
def smoke_test_torchvision_read_decode() -> None:
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
if img_png.ndim != 3 or img_png.numel() < 100:
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
def smoke_test_torchvision_resnet50_classify() -> None:
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg"))
# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
expected_category = "German shepherd"
print(f"{category_name}: {100 * score:.1f}%")
if category_name != expected_category:
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
def main() -> None:
print(f"torchvision: {torchvision.__version__}")
smoke_test_torchvision()
smoke_test_torchvision_read_decode()
smoke_test_torchvision_resnet50_classify()
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment