Unverified Commit ddd4d02f authored by Nishant Prabhu, committed by GitHub

LayoutLM ONNX support (Issue #13300) (#13562)



* Add support for exporting PyTorch LayoutLM to ONNX

* Added tests for converting LayoutLM to ONNX

* Cleanup

* Removed regression/ folder

* Fixed import error

* Removed unnecessary import statements

* Changed max_2d_positions from a class variable to an instance variable of the config class

* Use the superclass generate_dummy_inputs method
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Add support for Masked LM, sequence classification and token classification
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Removed unnecessary import and method

* Fixed code styling

* Raise error if PyTorch is not installed

* Removed unnecessary import statement
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
parent b7d264be
@@ -24,7 +24,7 @@ from .tokenization_layoutlm import LayoutLMTokenizer
 _import_structure = {
-    "configuration_layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig"],
+    "configuration_layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMOnnxConfig"],
     "tokenization_layoutlm": ["LayoutLMTokenizer"],
 }
@@ -54,7 +54,7 @@ if is_tf_available():
 if TYPE_CHECKING:
-    from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
+    from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMOnnxConfig
     from .tokenization_layoutlm import LayoutLMTokenizer

     if is_tokenizers_available():
......
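With the import structure updated, the new config class resolves through the usual lazy-module path alongside the existing symbols. A minimal sketch of the import and of the input signature it declares (the config class itself is added in the next file):

from transformers.models.layoutlm import LayoutLMConfig, LayoutLMOnnxConfig

# Build an ONNX config from a default LayoutLM configuration and inspect its declared inputs
onnx_config = LayoutLMOnnxConfig(LayoutLMConfig())
print(list(onnx_config.inputs.keys()))  # ['input_ids', 'bbox', 'attention_mask', 'token_type_ids']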
@@ -13,8 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ LayoutLM model configuration """
+from collections import OrderedDict
+from typing import Any, List, Mapping, Optional
+from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
+from ... import is_torch_available
+from ...onnx import OnnxConfig, PatchingSpec
 from ...utils import logging
 from ..bert.configuration_bert import BertConfig
@@ -125,3 +130,68 @@ class LayoutLMConfig(BertConfig):
             **kwargs,
         )
         self.max_2d_position_embeddings = max_2d_position_embeddings
+
+
+class LayoutLMOnnxConfig(OnnxConfig):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "default",
+        patching_specs: List[PatchingSpec] = None,
+    ):
+        super().__init__(config, task=task, patching_specs=patching_specs)
+        self.max_2d_positions = config.max_2d_position_embeddings - 1
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("bbox", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("token_type_ids", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        """
+        Generate inputs to provide to the ONNX exporter for the specific framework
+
+        Args:
+            tokenizer: The tokenizer associated with this model configuration
+            batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
+            seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
+            is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
+            framework: The framework (optional) the tokenizer will generate tensors for
+
+        Returns:
+            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
+        """
+        input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+
+        # Generate a dummy bbox
+        box = [48, 84, 73, 128]
+
+        if not framework == TensorType.PYTORCH:
+            raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")
+
+        if not is_torch_available():
+            raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
+        import torch
+
+        input_dict["bbox"] = torch.tensor(
+            [
+                [0] * 4,
+                *[box] * seq_length,
+                [self.max_2d_positions] * 4,
+            ]
+        ).tile(batch_size, 1, 1)
+        return input_dict
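To make the dummy bbox construction above concrete, here is a standalone sketch of the tensor it builds (the sizes are illustrative, and 1023 is assumed from the default max_2d_position_embeddings of 1024 minus one):

import torch

box = [48, 84, 73, 128]        # dummy word box (x0, y0, x1, y1), as in the code above
seq_length, batch_size = 4, 2  # example sizes
max_2d_positions = 1023        # config.max_2d_position_embeddings - 1 for the default config

bbox = torch.tensor(
    [
        [0] * 4,                  # box for the leading special token
        *[box] * seq_length,      # one identical box per token
        [max_2d_positions] * 4,   # box for the trailing special token
    ]
).tile(batch_size, 1, 1)

print(bbox.shape)  # torch.Size([2, 6, 4]) -> (batch_size, seq_length + 2, 4)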
@@ -8,6 +8,7 @@ from ..models.bert import BertOnnxConfig
 from ..models.distilbert import DistilBertOnnxConfig
 from ..models.gpt2 import GPT2OnnxConfig
 from ..models.gpt_neo import GPTNeoOnnxConfig
+from ..models.layoutlm import LayoutLMOnnxConfig
 from ..models.longformer import LongformerOnnxConfig
 from ..models.mbart import MBartOnnxConfig
 from ..models.roberta import RobertaOnnxConfig
@@ -78,6 +79,13 @@ class FeaturesManager:
             "sequence-classification-with-past",
             onnx_config_cls=GPTNeoOnnxConfig,
         ),
+        "layoutlm": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "token-classification",
+            onnx_config_cls=LayoutLMOnnxConfig,
+        ),
     }

     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
......
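With "layoutlm" registered above, each listed feature can be exported by pairing the matching head model with the corresponding task. A rough sketch, assuming the transformers.onnx.export helper and its (tokenizer, model, config, opset, output) argument order from this version of the exporter:

from pathlib import Path

from transformers import LayoutLMForTokenClassification, LayoutLMTokenizer
from transformers.models.layoutlm import LayoutLMOnnxConfig
from transformers.onnx import export

tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")

# "token-classification" is one of the features registered for layoutlm above; opset 11 is assumed
onnx_config = LayoutLMOnnxConfig(model.config, task="token-classification")
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, 11, Path("layoutlm-ner.onnx"))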
@@ -10,6 +10,7 @@ from transformers import ( # LongformerConfig,; T5Config,
     DistilBertConfig,
     GPT2Config,
     GPTNeoConfig,
+    LayoutLMConfig,
     MBartConfig,
     RobertaConfig,
     XLMRobertaConfig,
@@ -23,6 +24,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
 # from transformers.models.longformer import LongformerOnnxConfig
 from transformers.models.gpt2 import GPT2OnnxConfig
 from transformers.models.gpt_neo import GPTNeoOnnxConfig
+from transformers.models.layoutlm import LayoutLMOnnxConfig
 from transformers.models.mbart import MBartOnnxConfig
 from transformers.models.roberta import RobertaOnnxConfig
@@ -193,6 +195,7 @@ if is_torch_available():
         DistilBertModel,
         GPT2Model,
         GPTNeoModel,
+        LayoutLMModel,
         MBartModel,
         RobertaModel,
         XLMRobertaModel,
@@ -208,6 +211,7 @@ if is_torch_available():
         # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
         ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
         ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
+        ("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
         ("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
         # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
     }
......
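Beyond the parametrized export tests above, a quick way to sanity-check an exported LayoutLM graph is to run it through onnxruntime. A rough sketch; the file name, sizes, and dummy values are placeholders:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("layoutlm.onnx")

seq_len = 8
feed = {
    "input_ids": np.ones((1, seq_len), dtype=np.int64),
    "bbox": np.tile(np.array([48, 84, 73, 128], dtype=np.int64), (1, seq_len, 1)),
    "attention_mask": np.ones((1, seq_len), dtype=np.int64),
    "token_type_ids": np.zeros((1, seq_len), dtype=np.int64),
}

# Run all graph outputs; their names depend on the task the model was exported with
outputs = session.run(None, feed)
print([o.shape for o in outputs])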