Unverified commit ddd4d02f, authored by Nishant Prabhu and committed by GitHub

LayoutLM ONNX support (Issue #13300) (#13562)



* Add support for exporting PyTorch LayoutLM to ONNX

* Added tests for converting LayoutLM to ONNX

* Add support for exporting PyTorch LayoutLM to ONNX

* Added tests for converting LayoutLM to ONNX

* cleanup

* Removed regression/ folder

* Add support for exporting PyTorch LayoutLM to ONNX

* Added tests for converting LayoutLM to ONNX

* cleanup

* Fixed import error

* Remove unnecessary import statements

* Changed max_2d_positions from class variable to instance variable of the config class

* Add support for exporting PyTorch LayoutLM to ONNX

* Added tests for converting LayoutLM to ONNX

* cleanup

* Add support for exporting PyTorch LayoutLM to ONNX

* cleanup

* Fixed import error

* Changed max_2d_positions from class variable to instance variable of the config class

* Use super class generate_dummy_inputs method
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Add support for Masked LM, sequence classification and token classification
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Removed unnecessary import and method

* Fixed code styling

* Raise error if PyTorch is not installed

* Remove unnecessary import statement
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
parent b7d264be
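
Taken together, the changes below register LayoutLM with the ONNX export machinery: the config subpackage gains a LayoutLMOnnxConfig, and FeaturesManager learns which features it supports. A minimal end-to-end sketch of how the new pieces are meant to be used (assuming the transformers.onnx export() helper and default_onnx_opset of this release, which are not part of this diff):

# Sketch: exporting LayoutLM to ONNX with the config added in this commit.
# export() and default_onnx_opset are assumed from the surrounding transformers.onnx API.
from pathlib import Path

from transformers import LayoutLMModel, LayoutLMTokenizer
from transformers.models.layoutlm import LayoutLMOnnxConfig
from transformers.onnx import export

tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")

onnx_config = LayoutLMOnnxConfig(model.config)  # "default" task
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("layoutlm.onnx")
)
print(onnx_inputs)   # matched model inputs, e.g. ['input_ids', 'bbox', 'attention_mask', 'token_type_ids']
print(onnx_outputs)  # output names declared by the ONNX config

The transformers.onnx command-line entry point drives the same code path, picking one of the features registered with FeaturesManager below.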
@@ -24,7 +24,7 @@ from .tokenization_layoutlm import LayoutLMTokenizer
 _import_structure = {
-    "configuration_layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig"],
+    "configuration_layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMOnnxConfig"],
     "tokenization_layoutlm": ["LayoutLMTokenizer"],
 }
@@ -54,7 +54,7 @@ if is_tf_available():
 if TYPE_CHECKING:
-    from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
+    from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMOnnxConfig
     from .tokenization_layoutlm import LayoutLMTokenizer

     if is_tokenizers_available():
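
With both the lazy _import_structure entry and the TYPE_CHECKING branch updated, the new config is importable from the model subpackage, which the ONNX feature registry and the tests below rely on:

# The new public import surface for the LayoutLM ONNX config.
from transformers.models.layoutlm import LayoutLMConfig, LayoutLMOnnxConfig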
@@ -13,8 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ LayoutLM model configuration """
+from collections import OrderedDict
+from typing import Any, List, Mapping, Optional
+from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
+from ... import is_torch_available
+from ...onnx import OnnxConfig, PatchingSpec
 from ...utils import logging
 from ..bert.configuration_bert import BertConfig
@@ -125,3 +130,68 @@ class LayoutLMConfig(BertConfig):
             **kwargs,
         )
         self.max_2d_position_embeddings = max_2d_position_embeddings
+
+
+class LayoutLMOnnxConfig(OnnxConfig):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "default",
+        patching_specs: List[PatchingSpec] = None,
+    ):
+        super().__init__(config, task=task, patching_specs=patching_specs)
+        self.max_2d_positions = config.max_2d_position_embeddings - 1
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("bbox", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("token_type_ids", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        """
+        Generate inputs to provide to the ONNX exporter for the specific framework
+
+        Args:
+            tokenizer: The tokenizer associated with this model configuration
+            batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
+            seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
+            is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
+            framework: The framework (optional) the tokenizer will generate tensor for
+
+        Returns:
+            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
+        """
+        input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+
+        # Generate a dummy bbox
+        box = [48, 84, 73, 128]
+
+        if not framework == TensorType.PYTORCH:
+            raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")
+
+        if not is_torch_available():
+            raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
+        import torch
+
+        input_dict["bbox"] = torch.tensor(
+            [
+                [0] * 4,
+                *[box] * seq_length,
+                [self.max_2d_positions] * 4,
+            ]
+        ).tile(batch_size, 1, 1)
+        return input_dict
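
The bbox tensor built above contains one 4-coordinate box per token: an all-zero box first, the repeated dummy box in the middle, and an all-max box last, lining up with the special tokens the tokenizer adds around the dummy sequence. A small sketch of calling it directly (PyTorch must be installed; explicit sizes are passed because the bbox construction uses the seq_length argument as given; the checkpoint name comes from the test matrix below):

# Sketch: generating ONNX dummy inputs for LayoutLM by hand.
from transformers import LayoutLMConfig, LayoutLMTokenizer, TensorType
from transformers.models.layoutlm import LayoutLMOnnxConfig

tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
onnx_config = LayoutLMOnnxConfig(LayoutLMConfig())

# framework must be TensorType.PYTORCH, otherwise NotImplementedError is raised above.
dummy = onnx_config.generate_dummy_inputs(
    tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
)
print(sorted(dummy.keys()))  # ['attention_mask', 'bbox', 'input_ids', 'token_type_ids']
print(dummy["bbox"].shape)   # (batch, tokens, 4): one [x0, y0, x1, y1] box per token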
@@ -8,6 +8,7 @@ from ..models.bert import BertOnnxConfig
 from ..models.distilbert import DistilBertOnnxConfig
 from ..models.gpt2 import GPT2OnnxConfig
 from ..models.gpt_neo import GPTNeoOnnxConfig
+from ..models.layoutlm import LayoutLMOnnxConfig
 from ..models.longformer import LongformerOnnxConfig
 from ..models.mbart import MBartOnnxConfig
 from ..models.roberta import RobertaOnnxConfig
@@ -78,6 +79,13 @@ class FeaturesManager:
             "sequence-classification-with-past",
             onnx_config_cls=GPTNeoOnnxConfig,
         ),
+        "layoutlm": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "token-classification",
+            onnx_config_cls=LayoutLMOnnxConfig,
+        ),
     }

     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
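
The entry above registers four exportable features for LayoutLM (default, masked-lm, sequence-classification and token-classification), each resolving to a LayoutLMOnnxConfig built for that task. A rough sketch of the task-specific config FeaturesManager hands out, constructed by hand for illustration:

# Sketch: a per-task ONNX config like the ones produced by the mapping above.
from transformers import LayoutLMConfig
from transformers.models.layoutlm import LayoutLMOnnxConfig

onnx_config = LayoutLMOnnxConfig(LayoutLMConfig(), task="token-classification")
print(onnx_config.task)                 # 'token-classification'
print(list(onnx_config.inputs.keys()))  # ['input_ids', 'bbox', 'attention_mask', 'token_type_ids']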
@@ -10,6 +10,7 @@ from transformers import (  # LongformerConfig,; T5Config,
     DistilBertConfig,
     GPT2Config,
     GPTNeoConfig,
+    LayoutLMConfig,
     MBartConfig,
     RobertaConfig,
     XLMRobertaConfig,
@@ -23,6 +24,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
 # from transformers.models.longformer import LongformerOnnxConfig
 from transformers.models.gpt2 import GPT2OnnxConfig
 from transformers.models.gpt_neo import GPTNeoOnnxConfig
+from transformers.models.layoutlm import LayoutLMOnnxConfig
 from transformers.models.mbart import MBartOnnxConfig
 from transformers.models.roberta import RobertaOnnxConfig
@@ -193,6 +195,7 @@ if is_torch_available():
         DistilBertModel,
         GPT2Model,
         GPTNeoModel,
+        LayoutLMModel,
         MBartModel,
         RobertaModel,
         XLMRobertaModel,
@@ -208,6 +211,7 @@ if is_torch_available():
         # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
         ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
         ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
+        ("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
         ("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
         # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
     }