"docs/vscode:/vscode.git/clone" did not exist on "e7ed7ffdcb66c78d3437ed4c3a63c3640f50f436"
Unverified Commit 3080bb47 authored by Mohit Sharma, committed by GitHub

Add onnx support for VisionEncoderDecoder (#19254)



* Add onnx support for VisionEncoderDecoder

* Add onnx support for VisionEncoderDecoder

* Removed unused import

* Rename encoder hidden state
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docstrings and removed redundant code

* Added test function for enc-dec models

* Update doc string text
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fixed code style
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent 298f6a98
...@@ -96,6 +96,7 @@ Ready-made configurations include the following architectures:
- SqueezeBERT
- Swin Transformer
- T5
- Vision Encoder decoder
- ViT
- XLM
- XLM-RoBERTa
...@@ -294,6 +295,13 @@ that can be used for fast autoregressive decoding.
</Tip>
<Tip>

For `VisionEncoderDecoder` type models, the encoder and decoder parts are exported separately as two ONNX files,
named `encoder_model.onnx` and `decoder_model.onnx`, respectively.

</Tip>
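Once a checkpoint has been exported (for example with `python -m transformers.onnx --model=nlpconnect/vit-gpt2-image-captioning --feature=vision2seq-lm onnx/`), the two files can be driven directly with ONNX Runtime. The snippet below is a minimal sketch, not part of this diff; it assumes `onnxruntime` is installed, that a local `image.jpg` exists, and that the input/output names match the ONNX configs added in this PR (`pixel_values`, `last_hidden_state`, `input_ids`, `attention_mask`, `encoder_hidden_states`, `logits`).

```python
import numpy as np
import onnxruntime as ort
from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer

model_id = "nlpconnect/vit-gpt2-image-captioning"  # checkpoint used in the new export tests
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

encoder = ort.InferenceSession("onnx/encoder_model.onnx")
decoder = ort.InferenceSession("onnx/decoder_model.onnx")

# Encoder pass: pixel values -> image features ("last_hidden_state")
pixel_values = feature_extractor(Image.open("image.jpg"), return_tensors="np").pixel_values
encoder_hidden_states = encoder.run(None, {"pixel_values": pixel_values})[0]

# A single greedy decoding step: feed the BOS token together with the encoder features
input_ids = np.array([[tokenizer.bos_token_id]], dtype=np.int64)
outputs = decoder.run(
    None,
    {
        "input_ids": input_ids,
        "attention_mask": np.ones_like(input_ids),
        "encoder_hidden_states": encoder_hidden_states,
    },
)
next_token = int(outputs[0][0, -1].argmax())  # outputs[0] holds the logits
print(tokenizer.decode([next_token]))
```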
## Exporting a model for an unsupported architecture
......
...@@ -27,7 +27,9 @@ from ...utils import (
)

_import_structure = {
    "configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig", "VisionEncoderDecoderOnnxConfig"]
}

try:
    if not is_torch_available():

...@@ -54,7 +56,7 @@ else:
    _import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]

if TYPE_CHECKING:
    from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig, VisionEncoderDecoderOnnxConfig

    try:
        if not is_torch_available():
......
...@@ -15,12 +15,19 @@
# limitations under the License.

import copy
from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict

from packaging import version

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig


if TYPE_CHECKING:
    from ... import PreTrainedTokenizerBase, TensorType


logger = logging.get_logger(__name__)

...@@ -119,3 +126,97 @@ class VisionEncoderDecoderConfig(PretrainedConfig):
        output["decoder"] = self.decoder.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig):
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        return 1e-4

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict({"last_hidden_state": {0: "batch", 1: "encoder_sequence"}})


class VisionEncoderDecoderDecoderOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict()
        common_inputs["input_ids"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
        common_inputs["attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
        common_inputs["encoder_hidden_states"] = {0: "batch", 1: "encoder_sequence"}

        return common_inputs

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        import torch

        common_inputs = OrderedDict()

        dummy_input = super().generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        batch, encoder_sequence = dummy_input["input_ids"].shape
        encoder_hidden_states_shape = (batch, encoder_sequence, self._config.encoder_hidden_size)
        common_inputs["input_ids"] = dummy_input.pop("input_ids")
        common_inputs["attention_mask"] = dummy_input.pop("attention_mask")
        common_inputs["encoder_hidden_states"] = torch.zeros(encoder_hidden_states_shape)

        return common_inputs
class VisionEncoderDecoderOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> None:
        pass

    def get_encoder_config(self, encoder_config: PretrainedConfig) -> OnnxConfig:
        r"""
        Returns ONNX encoder config for `VisionEncoderDecoder` model.

        Args:
            encoder_config (`PretrainedConfig`):
                The encoder model's configuration to use when exporting to ONNX.

        Returns:
            [`VisionEncoderDecoderEncoderOnnxConfig`]: An instance of the ONNX configuration object.
        """
        return VisionEncoderDecoderEncoderOnnxConfig(encoder_config)

    def get_decoder_config(
        self, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, feature: str = "default"
    ) -> OnnxConfig:
        r"""
        Returns ONNX decoder config for `VisionEncoderDecoder` model.

        Args:
            encoder_config (`PretrainedConfig`):
                The encoder model's configuration to use when exporting to ONNX.
            decoder_config (`PretrainedConfig`):
                The decoder model's configuration to use when exporting to ONNX.
            feature (`str`, *optional*):
                The type of feature to export the model with.

        Returns:
            [`VisionEncoderDecoderDecoderOnnxConfig`]: An instance of the ONNX configuration object.
        """
        decoder_config.encoder_hidden_size = encoder_config.hidden_size
        return VisionEncoderDecoderDecoderOnnxConfig(decoder_config, feature)
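To make the division of labour concrete, here is a small usage sketch (not part of the diff, and assuming the `nlpconnect/vit-gpt2-image-captioning` checkpoint used in the new tests): `VisionEncoderDecoderOnnxConfig` acts as a factory whose `get_encoder_config`/`get_decoder_config` methods return one ONNX config per exported file, and the decoder config builds a zero-valued `encoder_hidden_states` dummy tensor from the `encoder_hidden_size` that `get_decoder_config` copies onto the decoder config.

```python
from transformers import AutoConfig, AutoTokenizer, TensorType
from transformers.models.vision_encoder_decoder import VisionEncoderDecoderOnnxConfig

checkpoint = "nlpconnect/vit-gpt2-image-captioning"  # checkpoint from the new export tests
config = AutoConfig.from_pretrained(checkpoint)
onnx_config = VisionEncoderDecoderOnnxConfig(config)

# One ONNX config per exported file
encoder_onnx_config = onnx_config.get_encoder_config(config.encoder)
decoder_onnx_config = onnx_config.get_decoder_config(config.encoder, config.decoder)

print(dict(encoder_onnx_config.inputs))   # {'pixel_values': {0: 'batch', 1: 'num_channels', 2: 'height', 3: 'width'}}
print(dict(encoder_onnx_config.outputs))  # {'last_hidden_state': {0: 'batch', 1: 'encoder_sequence'}}
print(dict(decoder_onnx_config.inputs))   # input_ids, attention_mask, encoder_hidden_states

# The decoder's dummy inputs contain a zero tensor standing in for the encoder output
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
dummy = decoder_onnx_config.generate_dummy_inputs(
    tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
)
print({name: tuple(tensor.shape) for name, tensor in dummy.items()})
# input_ids (2, 8), attention_mask (2, 8), encoder_hidden_states (2, 8, config.encoder.hidden_size)
```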
...@@ -22,6 +22,9 @@ from .convert import export, validate_model_outputs
from .features import FeaturesManager


ENCODER_DECODER_MODELS = ["vision-encoder-decoder"]


def main():
    parser = ArgumentParser("Hugging Face Transformers ONNX exporter")
    parser.add_argument(

...@@ -65,6 +68,75 @@ def main():
    if not args.output.parent.exists():
        args.output.parent.mkdir(parents=True)
    # Allocate the model
    model = FeaturesManager.get_model_from_feature(
        args.feature, args.model, framework=args.framework, cache_dir=args.cache_dir
    )

    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
    onnx_config = model_onnx_config(model.config)

    if model_kind in ENCODER_DECODER_MODELS:
        encoder_model = model.get_encoder()
        decoder_model = model.get_decoder()

        encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
        decoder_onnx_config = onnx_config.get_decoder_config(
            encoder_model.config, decoder_model.config, feature=args.feature
        )

        if args.opset is None:
            args.opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)

        if args.opset < min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset):
            raise ValueError(
                f"Opset {args.opset} is not sufficient to export {model_kind}. At least "
                f" {min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)} is required."
            )

        preprocessor = AutoFeatureExtractor.from_pretrained(args.model)

        onnx_inputs, onnx_outputs = export(
            preprocessor,
            encoder_model,
            encoder_onnx_config,
            args.opset,
            args.output.parent.joinpath("encoder_model.onnx"),
        )

        validate_model_outputs(
            encoder_onnx_config,
            preprocessor,
            encoder_model,
            args.output.parent.joinpath("encoder_model.onnx"),
            onnx_outputs,
            args.atol if args.atol else encoder_onnx_config.atol_for_validation,
        )

        preprocessor = AutoTokenizer.from_pretrained(args.model)

        onnx_inputs, onnx_outputs = export(
            preprocessor,
            decoder_model,
            decoder_onnx_config,
            args.opset,
            args.output.parent.joinpath("decoder_model.onnx"),
        )

        validate_model_outputs(
            decoder_onnx_config,
            preprocessor,
            decoder_model,
            args.output.parent.joinpath("decoder_model.onnx"),
            onnx_outputs,
            args.atol if args.atol else decoder_onnx_config.atol_for_validation,
        )

        logger.info(
            f"All good, model saved at: {args.output.parent.joinpath('encoder_model.onnx').as_posix()},"
            f" {args.output.parent.joinpath('decoder_model.onnx').as_posix()}"
        )
    else:
        # Instantiate the appropriate preprocessor
        if args.preprocessor == "auto":
            preprocessor = get_preprocessor(args.model)

...@@ -77,13 +149,6 @@ def main():
        else:
            raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")

    # Allocate the model
    model = FeaturesManager.get_model_from_feature(
        args.feature, args.model, framework=args.framework, cache_dir=args.cache_dir
    )

    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
    onnx_config = model_onnx_config(model.config)

        # Ensure the requested opset is sufficient
        if args.opset is None:
            args.opset = onnx_config.default_onnx_opset
......
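The same two-file export can also be driven from a script instead of the CLI. The following is a condensed sketch of what the new branch above does (not part of the diff; it assumes the captioning checkpoint from the tests and writes to a local `onnx/` directory):

```python
from pathlib import Path

from transformers import AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel
from transformers.models.vision_encoder_decoder import VisionEncoderDecoderOnnxConfig
from transformers.onnx import export, validate_model_outputs

model_id = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(model_id)
onnx_config = VisionEncoderDecoderOnnxConfig(model.config)

encoder_onnx_config = onnx_config.get_encoder_config(model.get_encoder().config)
decoder_onnx_config = onnx_config.get_decoder_config(
    model.get_encoder().config, model.get_decoder().config, feature="vision2seq-lm"
)
opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)

output = Path("onnx")
output.mkdir(exist_ok=True)

# Encoder part: traced with dummy pixel values produced via the feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
_, encoder_outputs = export(
    feature_extractor, model.get_encoder(), encoder_onnx_config, opset, output / "encoder_model.onnx"
)
validate_model_outputs(
    encoder_onnx_config, feature_extractor, model.get_encoder(),
    output / "encoder_model.onnx", encoder_outputs, encoder_onnx_config.atol_for_validation,
)

# Decoder part: traced with dummy token ids plus a zero encoder_hidden_states tensor
tokenizer = AutoTokenizer.from_pretrained(model_id)
_, decoder_outputs = export(
    tokenizer, model.get_decoder(), decoder_onnx_config, opset, output / "decoder_model.onnx"
)
validate_model_outputs(
    decoder_onnx_config, tokenizer, model.get_decoder(),
    output / "decoder_model.onnx", decoder_outputs, decoder_onnx_config.atol_for_validation,
)
```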
...@@ -103,6 +103,7 @@ class OnnxConfig(ABC):
        "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
        "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
        "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
    }

    def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):

...@@ -451,7 +452,6 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # TODO: should we set seq_length = 1 when self.use_past = True?
        common_inputs = super().generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework

...@@ -577,7 +577,6 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )
......
...@@ -30,6 +30,7 @@ if is_torch_available():
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForTokenClassification,
        AutoModelForVision2Seq,
    )
if is_tf_available():
    from transformers.models.auto import (

...@@ -98,6 +99,7 @@ class FeaturesManager:
            "image-segmentation": AutoModelForImageSegmentation,
            "masked-im": AutoModelForMaskedImageModeling,
            "semantic-segmentation": AutoModelForSemanticSegmentation,
            "vision2seq-lm": AutoModelForVision2Seq,
        }
    if is_tf_available():
        _TASKS_TO_TF_AUTOMODELS = {

...@@ -481,6 +483,9 @@ class FeaturesManager:
            "seq2seq-lm-with-past",
            onnx_config_cls="models.t5.T5OnnxConfig",
        ),
        "vision-encoder-decoder": supported_features_mapping(
            "vision2seq-lm", onnx_config_cls="models.vision_encoder_decoder.VisionEncoderDecoderOnnxConfig"
        ),
        "vit": supported_features_mapping(
            "default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
        ),

...@@ -582,6 +587,7 @@ class FeaturesManager:
            raise KeyError(
                f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
            )

        return task_to_automodel[task]

    @staticmethod
......
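As with any other supported architecture, the new entry is discoverable through `FeaturesManager`. A small sketch (not from the diff) mirroring the first half of the CLI flow:

```python
from transformers.onnx.features import FeaturesManager

# "vision2seq-lm" now resolves to AutoModelForVision2Seq, so this loads a VisionEncoderDecoderModel
model = FeaturesManager.get_model_from_feature("vision2seq-lm", "nlpconnect/vit-gpt2-image-captioning")

# model_kind comes from config.model_type, and the factory builds a VisionEncoderDecoderOnnxConfig
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(model, feature="vision2seq-lm")
onnx_config = onnx_config_factory(model.config)
print(model_kind, type(onnx_config).__name__)  # vision-encoder-decoder VisionEncoderDecoderOnnxConfig
```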
...@@ -161,7 +161,6 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                # without past
                onnx_config_default = OnnxConfigWithPast.from_model_config(config())
                self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")

...@@ -220,6 +219,10 @@ PYTORCH_EXPORT_MODELS = {
    ("swin", "microsoft/swin-tiny-patch4-window7-224"),
}

PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
    ("vision-encoder-decoder", "nlpconnect/vit-gpt2-image-captioning"),
}

PYTORCH_EXPORT_WITH_PAST_MODELS = {
    ("bloom", "bigscience/bloom-560m"),
    ("gpt2", "gpt2"),

...@@ -347,6 +350,70 @@ class OnnxExportTestCaseV2(TestCase):
        except (RuntimeError, ValueError) as e:
            self.fail(f"{name}, {feature} -> {e}")
    def _onnx_export_encoder_decoder_models(
        self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"
    ):
        from transformers import AutoFeatureExtractor, AutoTokenizer
        from transformers.onnx import export

        model_class = FeaturesManager.get_model_class_for_feature(feature)
        config = AutoConfig.from_pretrained(model_name)
        model = model_class.from_config(config)

        onnx_config = onnx_config_class_constructor(model.config)

        if is_torch_available():
            from transformers.utils import torch_version

            if torch_version < onnx_config.torch_onnx_minimum_version:
                pytest.skip(
                    "Skipping due to incompatible PyTorch version. Minimum required is"
                    f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
                )

        encoder_model = model.get_encoder()
        decoder_model = model.get_decoder()

        encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
        decoder_onnx_config = onnx_config.get_decoder_config(encoder_model.config, decoder_model.config, feature)

        preprocessor = AutoFeatureExtractor.from_pretrained(model_name)

        onnx_opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)

        with NamedTemporaryFile("w") as encoder_output:
            onnx_inputs, onnx_outputs = export(
                preprocessor, encoder_model, encoder_onnx_config, onnx_opset, Path(encoder_output.name), device=device
            )
            validate_model_outputs(
                encoder_onnx_config,
                preprocessor,
                encoder_model,
                Path(encoder_output.name),
                onnx_outputs,
                encoder_onnx_config.atol_for_validation,
            )

        preprocessor = AutoTokenizer.from_pretrained(model_name)

        with NamedTemporaryFile("w") as decoder_output:
            onnx_inputs, onnx_outputs = export(
                preprocessor,
                decoder_model,
                decoder_onnx_config,
                onnx_config.default_onnx_opset,
                Path(decoder_output.name),
                device=device,
            )
            validate_model_outputs(
                decoder_onnx_config,
                preprocessor,
                decoder_model,
                Path(decoder_output.name),
                onnx_outputs,
                decoder_onnx_config.atol_for_validation,
            )
    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
    @slow
    @require_torch

...@@ -363,6 +430,28 @@ class OnnxExportTestCaseV2(TestCase):
    def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")

    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
    @slow
    @require_torch
    @require_vision
    @require_rjieba
    def test_pytorch_export_encoder_decoder_models(
        self, test_name, name, model_name, feature, onnx_config_class_constructor
    ):
        self._onnx_export_encoder_decoder_models(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
    @slow
    @require_torch
    @require_vision
    @require_rjieba
    def test_pytorch_export_encoder_decoder_models_on_cuda(
        self, test_name, name, model_name, feature, onnx_config_class_constructor
    ):
        self._onnx_export_encoder_decoder_models(
            test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda"
        )

    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
    @slow
    @require_torch
......