Unverified Commit 6519150c authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Add image height and width to ONNX dynamic axes (#18915)

parent 737f6ad1
...@@ -194,7 +194,7 @@ class BeitOnnxConfig(OnnxConfig): ...@@ -194,7 +194,7 @@ class BeitOnnxConfig(OnnxConfig):
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict( return OrderedDict(
[ [
("pixel_values", {0: "batch", 1: "num_channels"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
] ]
) )
......
...@@ -332,7 +332,7 @@ class CLIPOnnxConfig(OnnxConfig): ...@@ -332,7 +332,7 @@ class CLIPOnnxConfig(OnnxConfig):
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", {0: "batch", 1: "sequence"}),
("pixel_values", {0: "batch"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", {0: "batch", 1: "sequence"}),
] ]
) )
......
...@@ -117,7 +117,7 @@ class ConvNextOnnxConfig(OnnxConfig): ...@@ -117,7 +117,7 @@ class ConvNextOnnxConfig(OnnxConfig):
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict( return OrderedDict(
[ [
("pixel_values", {0: "batch", 1: "num_channels"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
] ]
) )
......
...@@ -193,7 +193,7 @@ class Data2VecVisionOnnxConfig(OnnxConfig): ...@@ -193,7 +193,7 @@ class Data2VecVisionOnnxConfig(OnnxConfig):
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict( return OrderedDict(
[ [
("pixel_values", {0: "batch", 1: "num_channels"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
] ]
) )
......
...@@ -137,7 +137,7 @@ class DeiTOnnxConfig(OnnxConfig): ...@@ -137,7 +137,7 @@ class DeiTOnnxConfig(OnnxConfig):
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict( return OrderedDict(
[ [
("pixel_values", {0: "batch", 1: "num_channels"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
] ]
) )
......
...@@ -223,7 +223,7 @@ class DetrOnnxConfig(OnnxConfig): ...@@ -223,7 +223,7 @@ class DetrOnnxConfig(OnnxConfig):
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict( return OrderedDict(
[ [
("pixel_values", {0: "batch", 1: "num_channels"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
("pixel_mask", {0: "batch"}), ("pixel_mask", {0: "batch"}),
] ]
) )
......
...@@ -203,7 +203,7 @@ class LayoutLMv3OnnxConfig(OnnxConfig): ...@@ -203,7 +203,7 @@ class LayoutLMv3OnnxConfig(OnnxConfig):
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", {0: "batch", 1: "sequence"}),
("bbox", {0: "batch", 1: "sequence"}), ("bbox", {0: "batch", 1: "sequence"}),
("pixel_values", {0: "batch", 1: "sequence"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
] ]
) )
else: else:
......
...@@ -137,7 +137,7 @@ class LevitOnnxConfig(OnnxConfig): ...@@ -137,7 +137,7 @@ class LevitOnnxConfig(OnnxConfig):
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict( return OrderedDict(
[ [
("pixel_values", {0: "batch", 1: "num_channels"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
] ]
) )
......
...@@ -171,7 +171,7 @@ class MobileViTOnnxConfig(OnnxConfig): ...@@ -171,7 +171,7 @@ class MobileViTOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict([("pixel_values", {0: "batch", 1: "num_channels"})]) return OrderedDict([("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"})])
@property @property
def outputs(self) -> Mapping[str, Mapping[int, str]]: def outputs(self) -> Mapping[str, Mapping[int, str]]:
......
...@@ -105,7 +105,7 @@ class ResNetOnnxConfig(OnnxConfig): ...@@ -105,7 +105,7 @@ class ResNetOnnxConfig(OnnxConfig):
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict( return OrderedDict(
[ [
("pixel_values", {0: "batch", 1: "num_channels"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
] ]
) )
......
...@@ -135,7 +135,7 @@ class ViTOnnxConfig(OnnxConfig): ...@@ -135,7 +135,7 @@ class ViTOnnxConfig(OnnxConfig):
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict( return OrderedDict(
[ [
("pixel_values", {0: "batch", 1: "num_channels"}), ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
] ]
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment