Unverified Commit 41c186d2 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Replace torch.set_grad_enabled by torch.no_grad (#13703)

parent f888e5c3
...@@ -90,44 +90,43 @@ def export( ...@@ -90,44 +90,43 @@ def export(
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}") raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
logger.info(f"Using framework PyTorch: {torch.__version__}") logger.info(f"Using framework PyTorch: {torch.__version__}")
torch.set_grad_enabled(False) with torch.no_grad():
model.config.return_dict = True model.config.return_dict = True
model.eval() model.eval()
# Check if we need to override certain configuration item # Check if we need to override certain configuration item
if config.values_override is not None: if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)") logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items(): for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}") logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value) setattr(model.config, override_config_key, override_config_value)
# Ensure inputs match # Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True" # TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH) model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys()) inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys()) onnx_outputs = list(config.outputs.keys())
if not inputs_match: if not inputs_match:
raise ValueError("Model and config inputs doesn't match") raise ValueError("Model and config inputs doesn't match")
config.patch_ops() config.patch_ops()
# export can works with named args but the dict containing named args as to be last element of the args tuple # export can works with named args but the dict containing named args as to be last element of the args tuple
export( export(
model, model,
(model_inputs,), (model_inputs,),
f=output.as_posix(), f=output.as_posix(),
input_names=list(config.inputs.keys()), input_names=list(config.inputs.keys()),
output_names=onnx_outputs, output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())}, dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True, do_constant_folding=True,
use_external_data_format=config.use_external_data_format(model.num_parameters()), use_external_data_format=config.use_external_data_format(model.num_parameters()),
enable_onnx_checker=True, enable_onnx_checker=True,
opset_version=opset, opset_version=opset,
) )
config.restore_ops() config.restore_ops()
torch.set_grad_enabled(True)
return matched_inputs, onnx_outputs return matched_inputs, onnx_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment