Unverified Commit 45d9a304 authored by Jay Zhang's avatar Jay Zhang Committed by GitHub
Browse files

[ONNX] Fix ShuffleNetV2 model export issue. (#3158)



* Fix an issue that ShuffleNetV2 model is exported to a wrong ONNX file if dynamic_axes field was provided.

* Add a ut for the bug fix.

* Fix flake8 issue.

* Don't access each element in x.shape, use x.size() instead.
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 2e5e058c
...@@ -482,6 +482,17 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -482,6 +482,17 @@ class ONNXExporterTester(unittest.TestCase):
dynamic_axes={"images_tensors": [0, 1, 2]}, dynamic_axes={"images_tensors": [0, 1, 2]},
tolerate_small_mismatch=True) tolerate_small_mismatch=True)
def test_shufflenet_v2_dynamic_axes(self):
model = models.shufflenet_v2_x0_5(pretrained=True)
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0)
self.run_model(model, [(dummy_input,), (test_inputs,)],
input_names=["input_images"],
output_names=["output"],
dynamic_axes={"input_images": {0: 'batch_size'}, "output": {0: 'batch_size'}},
tolerate_small_mismatch=True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -19,7 +19,7 @@ model_urls = { ...@@ -19,7 +19,7 @@ model_urls = {
def channel_shuffle(x: Tensor, groups: int) -> Tensor: def channel_shuffle(x: Tensor, groups: int) -> Tensor:
batchsize, num_channels, height, width = x.data.size() batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // groups channels_per_group = num_channels // groups
# reshape # reshape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment