Unverified Commit c8cd3ff9 authored by Andrey Talman's avatar Andrey Talman Committed by GitHub
Browse files

Add cpu smoke_test_torchvision_decode_jpeg test (#7583)

parent e012579d
...@@ -26,10 +26,9 @@ def smoke_test_torchvision_read_decode() -> None: ...@@ -26,10 +26,9 @@ def smoke_test_torchvision_read_decode() -> None:
if img_png.shape != (4, 471, 354): if img_png.shape != (4, 471, 354):
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
def smoke_test_torchvision_decode_jpeg(device: str = "cpu"):
def smoke_test_torchvision_decode_jpeg_cuda():
img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")) img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
img_jpg = decode_jpeg(img_jpg_data, device="cuda") img_jpg = decode_jpeg(img_jpg_data, device=device)
if img_jpg.shape != (3, 606, 517): if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}") raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
...@@ -81,8 +80,9 @@ def main() -> None: ...@@ -81,8 +80,9 @@ def main() -> None:
smoke_test_torchvision() smoke_test_torchvision()
smoke_test_torchvision_read_decode() smoke_test_torchvision_read_decode()
smoke_test_torchvision_resnet50_classify() smoke_test_torchvision_resnet50_classify()
smoke_test_torchvision_decode_jpeg()
if torch.cuda.is_available(): if torch.cuda.is_available():
smoke_test_torchvision_decode_jpeg_cuda() smoke_test_torchvision_decode_jpeg("cuda")
smoke_test_torchvision_resnet50_classify("cuda") smoke_test_torchvision_resnet50_classify("cuda")
smoke_test_compile() smoke_test_compile()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment