Unverified commit 68818a33, authored by Bill Wu, committed by GitHub
Browse files

[Model Compression] Expand export_model arguments: dummy input and onnx opset_version (#3968)

parent deef0c42
...@@ -375,7 +375,8 @@ class Pruner(Compressor): ...@@ -375,7 +375,8 @@ class Pruner(Compressor):
wrapper.to(layer.module.weight.device) wrapper.to(layer.module.weight.device)
return wrapper return wrapper
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None): def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None,
dummy_input=None, opset_version=None):
""" """
Export pruned model weights, masks and onnx model(optional) Export pruned model weights, masks and onnx model(optional)
...@@ -388,10 +389,21 @@ class Pruner(Compressor): ...@@ -388,10 +389,21 @@ class Pruner(Compressor):
onnx_path : str onnx_path : str
(optional) path to save onnx model (optional) path to save onnx model
input_shape : list or tuple input_shape : list or tuple
input shape to onnx model input shape to onnx model, used for creating a dummy input tensor for torch.onnx.export
if the input has a complex structure (e.g., a tuple), please directly create the input and
pass it to dummy_input instead
note: this argument is deprecated and will be removed; please use dummy_input instead
device : torch.device device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file. device of the model, where to place the dummy input tensor for exporting onnx file;
the tensor is placed on cpu if ```device``` is None the tensor is placed on cpu if ```device``` is None
only useful when both onnx_path and input_shape are passed
note: this argument is deprecated and will be removed; please use dummy_input instead
dummy_input: torch.Tensor or tuple
dummy input to the onnx model; used when input_shape is not enough to specify dummy input
user should ensure that the dummy_input is on the same device as the model
opset_version: int
opset_version parameter for torch.onnx.export; only useful when onnx_path is not None
if not passed, torch.onnx.export will use its default opset_version
""" """
assert model_path is not None, 'model_path must be specified' assert model_path is not None, 'model_path must be specified'
mask_dict = {} mask_dict = {}
...@@ -412,17 +424,31 @@ class Pruner(Compressor): ...@@ -412,17 +424,31 @@ class Pruner(Compressor):
torch.save(self.bound_model.state_dict(), model_path) torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path) _logger.info('Model state_dict saved to %s', model_path)
if mask_path is not None: if mask_path is not None:
torch.save(mask_dict, mask_path) torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path) _logger.info('Mask dict saved to %s', mask_path)
if onnx_path is not None: if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model' assert input_shape is not None or dummy_input is not None,\
# input info needed 'input_shape or dummy_input must be specified to export onnx model'
if device is None: # create dummy_input using input_shape if input_shape is not passed
device = torch.device('cpu') if dummy_input is None:
input_data = torch.Tensor(*input_shape) _logger.warning("""The argument input_shape and device will be removed in the future.
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path) Please create a dummy input and pass it to dummy_input instead.""")
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path) if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape).to(device)
else:
input_data = dummy_input
if opset_version is not None:
torch.onnx.export(self.bound_model, input_data, onnx_path, opset_version=opset_version)
else:
torch.onnx.export(self.bound_model, input_data, onnx_path)
if dummy_input is None:
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
else:
_logger.info('Model in onnx saved to %s', onnx_path)
self._wrap_model() self._wrap_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment