Unverified Commit 41c186d2 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Replace torch.set_grad_enabled by torch.no_grad (#13703)

parent f888e5c3
......@@ -90,7 +90,7 @@ def export(
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
logger.info(f"Using framework PyTorch: {torch.__version__}")
torch.set_grad_enabled(False)
with torch.no_grad():
model.config.return_dict = True
model.eval()
......@@ -127,7 +127,6 @@ def export(
)
config.restore_ops()
torch.set_grad_enabled(True)
return matched_inputs, onnx_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment