Unverified Commit 09030c02 authored by Andrey Talman's avatar Andrey Talman Committed by GitHub
Browse files

Add cuda resnet50 test to smoke test (#7020)

* Add cuda resnet50 test

* Fix path

* Tune vision smoke test
parent 23d3f78a
...@@ -17,7 +17,6 @@ def smoke_test_torchvision() -> None: ...@@ -17,7 +17,6 @@ def smoke_test_torchvision() -> None:
all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]), all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
) )
def smoke_test_torchvision_read_decode() -> None: def smoke_test_torchvision_read_decode() -> None:
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")) img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
if img_jpg.ndim != 3 or img_jpg.numel() < 100: if img_jpg.ndim != 3 or img_jpg.numel() < 100:
...@@ -26,13 +25,12 @@ def smoke_test_torchvision_read_decode() -> None: ...@@ -26,13 +25,12 @@ def smoke_test_torchvision_read_decode() -> None:
if img_png.ndim != 3 or img_png.numel() < 100: if img_png.ndim != 3 or img_png.numel() < 100:
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
def smoke_test_torchvision_resnet50_classify() -> None: img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg"))
# Step 1: Initialize model with the best available weights # Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights) model = resnet50(weights=weights).to(device)
model.eval() model.eval()
# Step 2: Initialize the inference transforms # Step 2: Initialize the inference transforms
...@@ -47,17 +45,19 @@ def smoke_test_torchvision_resnet50_classify() -> None: ...@@ -47,17 +45,19 @@ def smoke_test_torchvision_resnet50_classify() -> None:
score = prediction[class_id].item() score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id] category_name = weights.meta["categories"][class_id]
expected_category = "German shepherd" expected_category = "German shepherd"
print(f"{category_name}: {100 * score:.1f}%") print(f"{category_name} ({device}): {100 * score:.1f}%")
if category_name != expected_category: if category_name != expected_category:
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}") raise RuntimeError(
f"Failed ResNet50 classify {category_name} Expected: {expected_category}"
)
def main() -> None: def main() -> None:
print(f"torchvision: {torchvision.__version__}") print(f"torchvision: {torchvision.__version__}")
smoke_test_torchvision() smoke_test_torchvision()
smoke_test_torchvision_read_decode() smoke_test_torchvision_read_decode()
smoke_test_torchvision_resnet50_classify() smoke_test_torchvision_resnet50_classify()
if torch.cuda.is_available():
smoke_test_torchvision_resnet50_classify("cuda")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment