smoke_test.py 2.2 KB
Newer Older
1
2
"""Run smoke tests"""

3
import os
4
from pathlib import Path
5

6
import torch
soumith's avatar
soumith committed
7
import torchvision
8
from torchvision.io import read_image
9
from torchvision.models import resnet50, ResNet50_Weights
10

11
12
13
14
15
# Directory containing this script; the smoke tests below build their image
# asset paths relative to it.
SCRIPT_DIR = Path(__file__).parent


def smoke_test_torchvision() -> None:
    """Print whether two torchvision custom ops resolve through ``torch.ops``.

    The ops looked up are ``torch.ops.image.decode_png`` and
    ``torch.ops.torchvision.roi_align``; the printed flag is True when
    neither lookup yields ``None``.
    """
    required_ops = [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]
    usable = all(op is not None for op in required_ops)
    print("Is torchvision usable?", usable)

20

21
22
23
24
25
26
27
28
def smoke_test_torchvision_read_decode() -> None:
    """Decode one JPEG and one PNG asset and sanity-check the resulting tensors.

    Raises:
        RuntimeError: If a decoded image is not 3-dimensional or holds fewer
            than 100 elements.
    """
    assets_dir = SCRIPT_DIR / "assets"

    # JPEG decoding path.
    jpg_path = assets_dir / "encode_jpeg" / "grace_hopper_517x606.jpg"
    decoded_jpg = read_image(str(jpg_path))
    jpg_looks_valid = decoded_jpg.ndim == 3 and decoded_jpg.numel() >= 100
    if not jpg_looks_valid:
        raise RuntimeError(f"Unexpected shape of img_jpg: {decoded_jpg.shape}")

    # PNG decoding path.
    png_path = assets_dir / "interlaced_png" / "wizard_low.png"
    decoded_png = read_image(str(png_path))
    png_looks_valid = decoded_png.ndim == 3 and decoded_png.numel() >= 100
    if not png_looks_valid:
        raise RuntimeError(f"Unexpected shape of img_png: {decoded_png.shape}")

29

30
31
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
    """Classify a known image with ResNet50 and verify the predicted category.

    Loads ``dog2.jpg`` from the gallery assets, runs it through a ResNet50
    initialized with ``ResNet50_Weights.DEFAULT``, prints the top category
    with its score, and checks the category against the expected label.

    Args:
        device: Device string handed to ``.to()`` for both the image and the
            model (e.g. ``"cpu"`` or ``"cuda"``).

    Raises:
        RuntimeError: If the top-scoring category is not ``"German shepherd"``.
    """
    img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights).to(device)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # This is inference only, so disable autograd: no graph is recorded for
    # the forward pass and activations are not kept alive for a backward pass
    # that never happens.
    with torch.no_grad():
        # Step 3: Apply inference preprocessing transforms
        batch = preprocess(img).unsqueeze(0)

        # Step 4: Use the model and print the predicted category
        prediction = model(batch).squeeze(0).softmax(0)

    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    expected_category = "German shepherd"
    print(f"{category_name} ({device}): {100 * score:.1f}%")
    if category_name != expected_category:
        raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")

54
55
56
57
58
59

def main() -> None:
    """Print the torchvision version and run every smoke test in order.

    The ResNet50 classification check runs once with its default device and
    a second time on ``"cuda"`` when ``torch.cuda.is_available()`` is true.
    """
    print(f"torchvision: {torchvision.__version__}")

    default_device_checks = (
        smoke_test_torchvision,
        smoke_test_torchvision_read_decode,
        smoke_test_torchvision_resnet50_classify,
    )
    for run_check in default_device_checks:
        run_check()

    if not torch.cuda.is_available():
        return
    smoke_test_torchvision_resnet50_classify("cuda")
62

63

64
65
# Run the smoke tests only when this file is executed directly as a script.
if __name__ == "__main__":
    main()