"vscode:/vscode.git/clone" did not exist on "bb22d546c062ae768a9f54a9eb1675f2a8dcdad9"
Unverified Commit 23d3f78a authored by Andrey Talman's avatar Andrey Talman Committed by GitHub
Browse files

Add more advanced smoke test for project Nova and validation workflows (#7014)



* Add more advanced smoke test

* add torch import

* remove dependency on torch

* Add missing vars

* More code and ufmt
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent f93eb8ff
"""Run smoke tests"""
import os
from pathlib import Path
import torch
import torchvision
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
image_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)
print("torchvision version is ", torchvision.__version__)
img = read_image(image_path)
SCRIPT_DIR = Path(__file__).parent
def smoke_test_torchvision() -> None:
print(
"Is torchvision useable?",
all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
)
def smoke_test_torchvision_read_decode() -> None:
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
if img_png.ndim != 3 or img_png.numel() < 100:
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
def smoke_test_torchvision_resnet50_classify() -> None:
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg"))
# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
expected_category = "German shepherd"
print(f"{category_name}: {100 * score:.1f}%")
if category_name != expected_category:
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
def main() -> None:
print(f"torchvision: {torchvision.__version__}")
smoke_test_torchvision()
smoke_test_torchvision_read_decode()
smoke_test_torchvision_resnet50_classify()
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment