Unverified Commit 6381f7b2 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

improve smoke test (#7550)

parent 9bc094ea
"""Run smoke tests"""
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torchvision
from torchvision.io import read_image
from torchvision.io import decode_jpeg, read_file, read_image
from torchvision.models import resnet50, ResNet50_Weights
SCRIPT_DIR = Path(__file__).parent
......@@ -22,13 +20,20 @@ def smoke_test_torchvision() -> None:
def smoke_test_torchvision_read_decode() -> None:
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
if img_png.ndim != 3 or img_png.numel() < 100:
if img_png.shape != (4, 471, 354):
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
def smoke_test_torchvision_decode_jpeg_cuda():
img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
img_jpg = decode_jpeg(img_jpg_data, device="cuda")
if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
def smoke_test_compile() -> None:
try:
model = resnet50().cuda()
......@@ -77,6 +82,7 @@ def main() -> None:
smoke_test_torchvision_read_decode()
smoke_test_torchvision_resnet50_classify()
if torch.cuda.is_available():
smoke_test_torchvision_decode_jpeg_cuda()
smoke_test_torchvision_resnet50_classify("cuda")
smoke_test_compile()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment