Unverified Commit 4684bfc7 authored by Amir Tahmasbi's avatar Amir Tahmasbi Committed by GitHub
Browse files

Layout lm tf 2 (#10636)



* Added embeddings layer

* Added layoutlm layers, main model, maskedlm and token classification classes

* Added model classes to tf auto models

* Added model to PT to TF conversion script

* Added model to doc README

* Added tests

* Removed unused imports

* Added layoutlm model, test, and doc for sequence classification, and fix imports in __init__.py

* Made tests pass!

* Fixed typos in imports and docs

* Fixed a typo in embeddings layer

* Removed imports

* Fixed formatting issues, imports, tests

* Added layoutlm layers, main model, maskedlm and token classification classes

* Added model classes to tf auto models

* Added model to PT to TF conversion script

* Removed unused imports

* Added layoutlm model, test, and doc for sequence classification, and fix imports in __init__.py

* Made tests pass!

* Fixed typos in imports and docs

* Removed imports

* Fixed small formatting issues

* Removed duplicates import from main __init__.py

* Chnaged deafult arg to true for adding  pooling layer to tf layoutlm

* Fixed formatting issues

* Style

* Added copied from to classes copied from bert

* Fixed doc strings examples to work with layoutlm inputs

* Removed PyTorch reference in doc strings example

* Added integration tests

* Cleaned up initialization file

* Updated model checkpoint identifiers

* Fixed imports
Co-authored-by: default avatarAmir Tahmasbi <amir@ehsai.ca>
Co-authored-by: default avatarLysandre <lysandre.debut@reseau.eseo.fr>
parent 1a3e0c4f
......@@ -281,7 +281,7 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LayoutLM | ✅ | ✅ | ✅ | | ❌ |
| LayoutLM | ✅ | ✅ | ✅ | | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
......@@ -130,3 +130,31 @@ LayoutLMForTokenClassification
.. autoclass:: transformers.LayoutLMForTokenClassification
:members:
TFLayoutLMModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLayoutLMModel
:members:
TFLayoutLMForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLayoutLMForMaskedLM
:members:
TFLayoutLMForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLayoutLMForSequenceClassification
:members:
TFLayoutLMForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLayoutLMForTokenClassification
:members:
......@@ -1214,6 +1214,17 @@ if is_tf_available():
"TFXLMRobertaModel",
]
)
_import_structure["models.layoutlm"].extend(
[
"TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLayoutLMForMaskedLM",
"TFLayoutLMForSequenceClassification",
"TFLayoutLMForTokenClassification",
"TFLaoutLMMainLayer",
"TFLayoutLMModel",
"TFLayoutLMPreTrainedModel",
]
)
_import_structure["models.xlnet"].extend(
[
"TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -2010,6 +2021,15 @@ if TYPE_CHECKING:
# Benchmarks
from .benchmark.benchmark_tf import TensorFlowBenchmark
from .generation_tf_utils import tf_top_k_top_p_filtering
from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMMainLayer,
TFLayoutLMModel,
TFLayoutLMPreTrainedModel,
)
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, shape_list
from .models.albert import (
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
......
......@@ -31,6 +31,7 @@ from . import (
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -50,6 +51,7 @@ from . import (
ElectraConfig,
FlaubertConfig,
GPT2Config,
LayoutLMConfig,
LxmertConfig,
OpenAIGPTConfig,
RobertaConfig,
......@@ -69,6 +71,7 @@ from . import (
TFElectraForPreTraining,
TFFlaubertWithLMHeadModel,
TFGPT2LMHeadModel,
TFLayoutLMForMaskedLM,
TFLxmertForPreTraining,
TFLxmertVisualFeatureEncoder,
TFOpenAIGPTLMHeadModel,
......@@ -111,6 +114,7 @@ if is_torch_available():
ElectraForPreTraining,
FlaubertWithLMHeadModel,
GPT2LMHeadModel,
LayoutLMForMaskedLM,
LxmertForPreTraining,
LxmertVisualFeatureEncoder,
OpenAIGPTLMHeadModel,
......@@ -211,6 +215,12 @@ MODEL_CLASSES = {
RobertaForMaskedLM,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"layoutlm": (
LayoutLMConfig,
TFLayoutLMForMaskedLM,
LayoutLMForMaskedLM,
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"roberta-large-mnli": (
RobertaConfig,
TFRobertaForSequenceClassification,
......
......@@ -333,6 +333,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
all_tf_weights = set(list(tf_weights_map.keys()))
loaded_pt_weights_data_ptr = {}
missing_keys_pt = []
for pt_weight_name, pt_weight in current_pt_params_dict.items():
# Handle PyTorch shared weight ()not duplicated in TF 2.0
if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
......
......@@ -102,6 +102,12 @@ from ..funnel.modeling_tf_funnel import (
TFFunnelModel,
)
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
from ..layoutlm.modeling_tf_layoutlm import (
TFLayoutLMForMaskedLM,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMModel,
)
from ..led.modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel
from ..longformer.modeling_tf_longformer import (
TFLongformerForMaskedLM,
......@@ -189,6 +195,7 @@ from .configuration_auto import (
FlaubertConfig,
FunnelConfig,
GPT2Config,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LxmertConfig,
......@@ -227,6 +234,7 @@ TF_MODEL_MAPPING = OrderedDict(
(XLMRobertaConfig, TFXLMRobertaModel),
(LongformerConfig, TFLongformerModel),
(RobertaConfig, TFRobertaModel),
(LayoutLMConfig, TFLayoutLMModel),
(BertConfig, TFBertModel),
(OpenAIGPTConfig, TFOpenAIGPTModel),
(GPT2Config, TFGPT2Model),
......@@ -260,6 +268,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForPreTraining),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
......@@ -289,6 +298,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
......@@ -330,6 +340,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(MobileBertConfig, TFMobileBertForMaskedLM),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
......@@ -366,6 +377,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
(LongformerConfig, TFLongformerForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(LayoutLMConfig, TFLayoutLMForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
(MobileBertConfig, TFMobileBertForSequenceClassification),
......@@ -414,6 +426,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
(LongformerConfig, TFLongformerForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(LayoutLMConfig, TFLayoutLMForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(MobileBertConfig, TFMobileBertForTokenClassification),
(XLNetConfig, TFXLNetForTokenClassification),
......
......@@ -18,7 +18,9 @@
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from .tokenization_layoutlm import LayoutLMTokenizer
_import_structure = {
......@@ -38,6 +40,17 @@ if is_torch_available():
"LayoutLMModel",
]
if is_tf_available():
_import_structure["modeling_tf_layoutlm"] = [
"TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLayoutLMForMaskedLM",
"TFLayoutLMForTokenClassification",
"TFLayoutLMForSequenceClassification",
"TFLayoutLMMainLayer",
"TFLayoutLMModel",
"TFLayoutLMPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
......@@ -54,6 +67,16 @@ if TYPE_CHECKING:
LayoutLMForTokenClassification,
LayoutLMModel,
)
if is_tf_available():
from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMMainLayer,
TFLayoutLMModel,
TFLayoutLMPreTrainedModel,
)
else:
import importlib
......
# coding=utf-8
# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 LayoutLM model. """
import math
import warnings
from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFMaskedLMOutput,
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...utils import logging
from .configuration_layoutlm import LayoutLMConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LayoutLMConfig"
_TOKENIZER_FOR_DOC = "LayoutLMTokenizer"
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/layoutlm-base-uncased",
"microsoft/layoutlm-large-uncased",
]
class TFLayoutLMEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.type_vocab_size = config.type_vocab_size
self.hidden_size = config.hidden_size
self.max_position_embeddings = config.max_position_embeddings
self.max_2d_position_embeddings = config.max_2d_position_embeddings
self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("x_position_embeddings"):
self.x_position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_2d_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("y_position_embeddings"):
self.y_position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_2d_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("h_position_embeddings"):
self.h_position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_2d_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("w_position_embeddings"):
self.w_position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_2d_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def call(
self,
input_ids: tf.Tensor = None,
bbox: tf.Tensor = None,
position_ids: tf.Tensor = None,
token_type_ids: tf.Tensor = None,
inputs_embeds: tf.Tensor = None,
training: bool = False,
) -> tf.Tensor:
"""
Applies embedding based on inputs tensor.
Returns:
final_embeddings (:obj:`tf.Tensor`): output embedding tensor.
"""
assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None:
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None:
token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None:
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
if position_ids is None:
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
if bbox is None:
bbox = bbox = tf.fill(input_shape + [4], value=0)
try:
left_position_embeddings = tf.gather(self.x_position_embeddings, bbox[:, :, 0])
upper_position_embeddings = tf.gather(self.y_position_embeddings, bbox[:, :, 1])
right_position_embeddings = tf.gather(self.x_position_embeddings, bbox[:, :, 2])
lower_position_embeddings = tf.gather(self.y_position_embeddings, bbox[:, :, 3])
except IndexError as e:
raise IndexError("The :obj:`bbox`coordinate values should be within 0-1000 range.") from e
h_position_embeddings = tf.gather(self.h_position_embeddings, bbox[:, :, 3] - bbox[:, :, 1])
w_position_embeddings = tf.gather(self.w_position_embeddings, bbox[:, :, 2] - bbox[:, :, 0])
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(
inputs=[
inputs_embeds,
position_embeds,
token_type_embeds,
left_position_embeddings,
upper_position_embeddings,
right_position_embeddings,
lower_position_embeddings,
h_position_embeddings,
w_position_embeddings,
]
)
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
return final_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->LayoutLM
class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(inputs=hidden_states)
mixed_key_layer = self.key(inputs=hidden_states)
mixed_value_layer = self.value(inputs=hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, dk)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFLayoutLMModel call() function)
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(inputs=attention_probs, training=training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask)
attention_output = tf.matmul(attention_probs, value_layer)
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
# (batch_size, seq_len_q, all_head_size)
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->LayoutLM
class TFLayoutLMSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->LayoutLM
class TFLayoutLMAttention(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFLayoutLMSelfAttention(config, name="self")
self.dense_output = TFLayoutLMSelfOutput(config, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(
self,
input_tensor: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
self_outputs = self.self_attention(
hidden_states=input_tensor,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
training=training,
)
attention_output = self.dense_output(
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->LayoutLM
class TFLayoutLMIntermediate(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->LayoutLM
class TFLayoutLMOutput(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->LayoutLM
class TFLayoutLMLayer(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.attention = TFLayoutLMAttention(config, name="attention")
self.intermediate = TFLayoutLMIntermediate(config, name="intermediate")
self.bert_output = TFLayoutLMOutput(config, name="output")
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
attention_outputs = self.attention(
input_tensor=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
training=training,
)
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(hidden_states=attention_output)
layer_output = self.bert_output(
hidden_states=intermediate_output, input_tensor=attention_output, training=training
)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->LayoutLM
class TFLayoutLMEncoder(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.layer = [TFLayoutLMLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
output_hidden_states: bool,
return_dict: bool,
training: bool = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask[i],
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->LayoutLM
class TFLayoutLMPooler(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(inputs=first_token_tensor)
return pooled_output
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->LayoutLM
class TFLayoutLMPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
if isinstance(config.hidden_act, str):
self.transform_act_fn = get_tf_activation(config.hidden_act)
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(inputs=hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->LayoutLM
class TFLayoutLMLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, input_embeddings: tf.keras.layers.Layer, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.transform = TFLayoutLMPredictionHeadTransform(config, name="transform")
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape: tf.TensorShape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.input_embeddings
def set_output_embeddings(self, value: tf.Variable):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self) -> Dict[str, tf.Variable]:
return {"bias": self.bias}
def set_bias(self, value: tf.Variable):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.transform(hidden_states=hidden_states)
seq_length = shape_list(hidden_states)[1]
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.vocab_size])
hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->LayoutLM
class TFLayoutLMMLMHead(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, input_embeddings: tf.keras.layers.Layer, **kwargs):
super().__init__(**kwargs)
self.predictions = TFLayoutLMLMPredictionHead(config, input_embeddings, name="predictions")
def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
prediction_scores = self.predictions(hidden_states=sequence_output)
return prediction_scores
@keras_serializable
class TFLayoutLMMainLayer(tf.keras.layers.Layer):
config_class = LayoutLMConfig
def __init__(self, config: LayoutLMConfig, add_pooling_layer: bool = True, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embeddings = TFLayoutLMEmbeddings(config, name="embeddings")
self.encoder = TFLayoutLMEncoder(config, name="encoder")
self.pooler = TFLayoutLMPooler(config, name="pooler") if add_pooling_layer else None
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.embeddings
def set_input_embeddings(self, value: tf.Variable):
self.embeddings.weight = value
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
raise NotImplementedError
def call(
self,
input_ids: Optional[TFModelInputType] = None,
bbox: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(dims=input_shape, value=1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
if inputs["bbox"] is None:
inputs["bbox"] = tf.fill(dims=input_shape + [4], value=0)
embedding_output = self.embeddings(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
position_ids=inputs["position_ids"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
inputs["head_mask"] = [None] * self.config.num_hidden_layers
encoder_outputs = self.encoder(
hidden_states=embedding_output,
attention_mask=extended_attention_mask,
head_mask=inputs["head_mask"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
if not inputs["return_dict"]:
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class TFLayoutLMPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = LayoutLMConfig
base_model_prefix = "layoutlm"
LAYOUTLM_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass. Use
it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage
and behavior.
.. note::
TF 2.0 models accepts two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is useful when using :meth:`tf.keras.Model.fit` method which currently requires having all
the tensors in the first argument of the model call function: :obj:`model(inputs)`.
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in
the first positional argument :
- a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(inputs_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
:obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
:obj:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Args:
config (:class:`~transformers.LayoutLMConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the
model weights.
"""
LAYOUTLM_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.LayoutLMTokenizer`. See
:func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
bbox (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0}, 4)`, `optional`):
Bounding Boxes of each input sequence tokens. Selected in the range ``[0,
config.max_2d_position_embeddings- 1]``.
attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`__
position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`__
head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
@add_start_docstrings(
"The bare LayoutLM Model transformer outputting raw hidden-states without any specific head on top.",
LAYOUTLM_START_DOCSTRING,
)
class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids: Optional[TFModelInputType] = None,
bbox: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
r"""
Returns:
Examples::
>>> from transformers import LayoutLMTokenizer, TFLayoutLMModel
>>> import tensorflow as tf
>>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
>>> model = TFLayoutLMModel.from_pretrained('microsoft/layoutlm-base-uncased')
>>> words = ["Hello", "world"]
>>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
>>> token_boxes = []
>>> for word, box in zip(words, normalized_word_boxes):
... word_tokens = tokenizer.tokenize(word)
... token_boxes.extend([box] * len(word_tokens))
>>> # add bounding boxes of cls + sep tokens
>>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
>>> encoding = tokenizer(' '.join(words), return_tensors="tf")
>>> input_ids = encoding["input_ids"]
>>> attention_mask = encoding["attention_mask"]
>>> token_type_ids = encoding["token_type_ids"]
>>> bbox = tf.convert_to_tensor([token_boxes])
>>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.layoutlm(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hs,
attentions=attns,
)
@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top. """, LAYOUTLM_START_DOCSTRING)
class TFLayoutLMForMaskedLM(TFLayoutLMPreTrainedModel, TFMaskedLanguageModelingLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"cls.seq_relationship",
r"cls.predictions.decoder.weight",
r"nsp___cls",
]
def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
if config.is_decoder:
logger.warning(
"If you want to use `TFLayoutLMForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)
self.layoutlm = TFLayoutLMMainLayer(config, add_pooling_layer=True, name="layoutlm")
self.mlm = TFLayoutLMMLMHead(config, input_embeddings=self.layoutlm.embeddings, name="mlm___cls")
def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions
def get_prefix_bias_name(self) -> str:
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids: Optional[TFModelInputType] = None,
bbox: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: Optional[bool] = False,
**kwargs,
) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
r"""
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
Returns:
Examples::
>>> from transformers import LayoutLMTokenizer, TFLayoutLMForMaskedLM
>>> import tensorflow as tf
>>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
>>> model = TFLayoutLMForMaskedLM.from_pretrained('microsoft/layoutlm-base-uncased')
>>> words = ["Hello", "[MASK]"]
>>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
>>> token_boxes = []
>>> for word, box in zip(words, normalized_word_boxes):
... word_tokens = tokenizer.tokenize(word)
... token_boxes.extend([box] * len(word_tokens))
>>> # add bounding boxes of cls + sep tokens
>>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
>>> encoding = tokenizer(' '.join(words), return_tensors="tf")
>>> input_ids = encoding["input_ids"]
>>> attention_mask = encoding["attention_mask"]
>>> token_type_ids = encoding["token_type_ids"]
>>> bbox = tf.convert_to_tensor([token_boxes])
>>> labels = tokenizer("Hello world", return_tensors="tf")["input_ids"]
>>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
... labels=labels)
>>> loss = outputs.loss
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.layoutlm(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
loss = (
None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput(
loss=loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
"""
LayoutLM Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
LAYOUTLM_START_DOCSTRING,
)
class TFLayoutLMForSequenceClassification(TFLayoutLMPreTrainedModel, TFSequenceClassificationLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
units=config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids: Optional[TFModelInputType] = None,
bbox: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: Optional[bool] = False,
**kwargs,
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Examples::
>>> from transformers import LayoutLMTokenizer, TFLayoutLMForSequenceClassification
>>> import tensorflow as tf
>>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
>>> model = TFLayoutLMForSequenceClassification.from_pretrained('microsoft/layoutlm-base-uncased')
>>> words = ["Hello", "world"]
>>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
>>> token_boxes = []
>>> for word, box in zip(words, normalized_word_boxes):
... word_tokens = tokenizer.tokenize(word)
... token_boxes.extend([box] * len(word_tokens))
>>> # add bounding boxes of cls + sep tokens
>>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
>>> encoding = tokenizer(' '.join(words), return_tensors="tf")
>>> input_ids = encoding["input_ids"]
>>> attention_mask = encoding["attention_mask"]
>>> token_type_ids = encoding["token_type_ids"]
>>> bbox = tf.convert_to_tensor([token_boxes])
>>> sequence_label = tf.convert_to_tensor([1])
>>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
... labels=sequence_label)
>>> loss = outputs.loss
>>> logits = outputs.logits
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.layoutlm(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
logits = self.classifier(inputs=pooled_output)
loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
"""
LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
LAYOUTLM_START_DOCSTRING,
)
class TFLayoutLMForTokenClassification(TFLayoutLMPreTrainedModel, TFTokenClassificationLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"mlm___cls",
r"nsp___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.layoutlm = TFLayoutLMMainLayer(config, add_pooling_layer=True, name="layoutlm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
units=config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFTokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids: Optional[TFModelInputType] = None,
bbox: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: Optional[bool] = False,
**kwargs,
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
Returns:
Examples::
>>> from transformers import LayoutLMTokenizer, TFLayoutLMForTokenClassification
>>> import torch
>>> tokenizer = LayoutLMTokenizer.from_pretrained('microsoft/layoutlm-base-uncased')
>>> model = TFLayoutLMForTokenClassification.from_pretrained('microsoft/layoutlm-base-uncased')
>>> words = ["Hello", "world"]
>>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
>>> token_boxes = []
>>> for word, box in zip(words, normalized_word_boxes):
... word_tokens = tokenizer.tokenize(word)
... token_boxes.extend([box] * len(word_tokens))
>>> # add bounding boxes of cls + sep tokens
>>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
>>> encoding = tokenizer(' '.join(words), return_tensors="tf")
>>> input_ids = encoding["input_ids"]
>>> attention_mask = encoding["attention_mask"]
>>> token_type_ids = encoding["token_type_ids"]
>>> bbox = tf.convert_to_tensor([token_boxes])
>>> token_labels = tf.convert_to_tensor([1,1,0,0])
>>> outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
... labels=token_labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.layoutlm(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
logits = self.classifier(inputs=sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
......@@ -16,6 +16,59 @@ def tf_top_k_top_p_filtering(*args, **kwargs):
requires_tf(tf_top_k_top_p_filtering)
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFLayoutLMForMaskedLM:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLayoutLMForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLayoutLMForTokenClassification:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLayoutLMMainLayer:
def __init__(self, *args, **kwargs):
requires_tf(self)
class TFLayoutLMModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLayoutLMPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
......
# coding=utf-8
# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors, The Hugging Face Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import LayoutLMConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers.models.layoutlm.modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMModel,
)
class TFLayoutLMModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
range_bbox=1000,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.range_bbox = range_bbox
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
# convert bbox to numpy since TF does not support item assignment
bbox = ids_tensor([self.batch_size, self.seq_length, 4], self.range_bbox).numpy()
# Ensure that bbox is legal
for i in range(bbox.shape[0]):
for j in range(bbox.shape[1]):
if bbox[i, j, 3] < bbox[i, j, 1]:
t = bbox[i, j, 3]
bbox[i, j, 3] = bbox[i, j, 1]
bbox[i, j, 1] = t
if bbox[i, j, 2] < bbox[i, j, 0]:
t = bbox[i, j, 2]
bbox[i, j, 2] = bbox[i, j, 0]
bbox[i, j, 0] = t
bbox = tf.convert_to_tensor(bbox)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = LayoutLMConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
)
return config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_model(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFLayoutLMModel(config=config)
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, bbox, token_type_ids=token_type_ids)
result = model(input_ids, bbox)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_for_masked_lm(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFLayoutLMForMaskedLM(config=config)
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_sequence_classification(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = TFLayoutLMForSequenceClassification(config=config)
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_token_classification(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = TFLayoutLMForTokenClassification(config=config)
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
bbox,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"bbox": bbox,
"token_type_ids": token_type_ids,
"attention_mask": input_mask,
}
return config, inputs_dict
@require_tf
class LayoutLMModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(TFLayoutLMModel, TFLayoutLMForMaskedLM, TFLayoutLMForTokenClassification, TFLayoutLMForSequenceClassification)
if is_tf_available()
else ()
)
test_head_masking = False
test_onnx = True
onnx_min_opset = 10
def setUp(self):
self.model_tester = TFLayoutLMModelTester(self)
self.config_tester = ConfigTester(self, config_class=LayoutLMConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFLayoutLMModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def prepare_layoutlm_batch_inputs():
# Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on:
# fmt: off
input_ids = tf.convert_to_tensor([[101,1019,1014,1016,1037,12849,4747,1004,14246,2278,5439,4524,5002,2930,2193,2930,4341,3208,1005,1055,2171,2848,11300,3531,102],[101,4070,4034,7020,1024,3058,1015,1013,2861,1013,6070,19274,2772,6205,27814,16147,16147,4343,2047,10283,10969,14389,1012,2338,102]]) # noqa: E231
attention_mask = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],]) # noqa: E231
bbox = tf.convert_to_tensor([[[0,0,0,0],[423,237,440,251],[427,272,441,287],[419,115,437,129],[961,885,992,912],[256,38,330,58],[256,38,330,58],[336,42,353,57],[360,39,401,56],[360,39,401,56],[411,39,471,59],[479,41,528,59],[533,39,630,60],[67,113,134,131],[141,115,209,132],[68,149,133,166],[141,149,187,164],[195,148,287,165],[195,148,287,165],[195,148,287,165],[295,148,349,165],[441,149,492,166],[497,149,546,164],[64,201,125,218],[1000,1000,1000,1000]],[[0,0,0,0],[662,150,754,166],[665,199,742,211],[519,213,554,228],[519,213,554,228],[134,433,187,454],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[314,469,376,482],[504,684,582,706],[941,825,973,900],[941,825,973,900],[941,825,973,900],[941,825,973,900],[610,749,652,765],[130,659,168,672],[176,657,237,672],[238,657,312,672],[443,653,628,672],[443,653,628,672],[716,301,825,317],[1000,1000,1000,1000]]]) # noqa: E231
token_type_ids = tf.convert_to_tensor([[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]]) # noqa: E231
# these are sequence labels (i.e. at the token level)
labels = tf.convert_to_tensor([[-100,10,10,10,9,1,-100,7,7,-100,7,7,4,2,5,2,8,8,-100,-100,5,0,3,2,-100],[-100,12,12,12,-100,12,10,-100,-100,-100,-100,10,12,9,-100,-100,-100,10,10,10,9,12,-100,10,-100]]) # noqa: E231
# fmt: on
return input_ids, attention_mask, bbox, token_type_ids, labels
@require_tf
class TFLayoutLMModelIntegrationTest(unittest.TestCase):
@slow
def test_forward_pass_no_head(self):
model = TFLayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")
input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs()
# forward pass
outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
# test the sequence output on [0, :3, :3]
expected_slice = tf.convert_to_tensor(
[[0.1785, -0.1947, -0.0425], [-0.3254, -0.2807, 0.2553], [-0.5391, -0.3322, 0.3364]],
)
self.assertTrue(np.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-3))
# test the pooled output on [1, :3]
expected_slice = tf.convert_to_tensor([-0.6580, -0.0214, 0.8552])
self.assertTrue(np.allclose(outputs.pooler_output[1, :3], expected_slice, atol=1e-3))
@slow
def test_forward_pass_sequence_classification(self):
# initialize model with randomly initialized sequence classification head
model = TFLayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=2)
input_ids, attention_mask, bbox, token_type_ids, _ = prepare_layoutlm_batch_inputs()
# forward pass
outputs = model(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
labels=tf.convert_to_tensor([1, 1]),
)
# test whether we get a loss as a scalar
loss = outputs.loss
expected_shape = (2,)
self.assertEqual(loss.shape, expected_shape)
# test the shape of the logits
logits = outputs.logits
expected_shape = (2, 2)
self.assertEqual(logits.shape, expected_shape)
@slow
def test_forward_pass_token_classification(self):
# initialize model with randomly initialized token classification head
model = TFLayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=13)
input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs()
# forward pass
outputs = model(
input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels
)
# test the shape of the logits
logits = outputs.logits
expected_shape = tf.convert_to_tensor((2, 25, 13))
self.assertEqual(logits.shape, expected_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment