Unverified Commit 81f0c0c9 authored by Andrey Talman's avatar Andrey Talman Committed by GitHub
Browse files

Execute compile smoke test only on linux, check that windows throws an exception (#7386)

parent 0e62c34f
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import os import os
from pathlib import Path from pathlib import Path
from sys import platform
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -29,11 +30,17 @@ def smoke_test_torchvision_read_decode() -> None: ...@@ -29,11 +30,17 @@ def smoke_test_torchvision_read_decode() -> None:
def smoke_test_compile() -> None: def smoke_test_compile() -> None:
try:
model = resnet50().cuda() model = resnet50().cuda()
model = torch.compile(model) model = torch.compile(model)
x = torch.randn(1, 3, 224, 224, device="cuda") x = torch.randn(1, 3, 224, 224, device="cuda")
out = model(x) out = model(x)
print(f"torch.compile model output: {out.shape}") print(f"torch.compile model output: {out.shape}")
except RuntimeError:
if platform == "win32":
print("Successfully caught torch.compile RuntimeError on win")
else:
raise
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment