Unverified Commit 5c4c8690 authored by Ankur Goyal, committed by GitHub

Add LayoutLMForQuestionAnswering model (#18407)



* Add LayoutLMForQuestionAnswering model

* Fix output

* Remove TF TODOs

* Add test cases

* Add docs

* TF implementation

* Fix PT/TF equivalence

* Fix loss

* make fixup

* Fix up documentation code examples

* Fix up documentation examples + test them

* Remove LayoutLMForQuestionAnswering from the auto mapping

* Docstrings

* Add better docstrings

* Undo whitespace changes

* Update tokenizers in comments

* Fixup code and remove `from_pt=True`

* Fix tests

* Revert some unexpected docstring changes

* Fix tests by overriding _prepare_for_class
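For reviewers, here is a minimal usage sketch of the new head, condensed from the documentation example added in this PR (it uses the same `impira/layoutlm-document-qa` checkpoint and `nielsr/funsd` dataset as that example, and should print the same answer span):

```python
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, LayoutLMForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
model = LayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa")

example = load_dataset("nielsr/funsd", split="train")[0]
question, words, boxes = "what's his name?", example["words"], example["bboxes"]

encoding = tokenizer(
    question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="pt"
)
# One bounding box per token: the word box for context tokens, [1000]*4 for [SEP], [0]*4 otherwise.
bbox = []
for input_id, seq_id, word_id in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
    if seq_id == 1:
        bbox.append(boxes[word_id])
    elif input_id == tokenizer.sep_token_id:
        bbox.append([1000] * 4)
    else:
        bbox.append([0] * 4)
encoding["bbox"] = torch.tensor([bbox])

# The argmax of the start/end logits is mapped back to words via word_ids.
outputs = model(**encoding)
word_ids = encoding.word_ids(0)
start = word_ids[outputs.start_logits.argmax(-1).item()]
end = word_ids[outputs.end_logits.argmax(-1).item()]
print(" ".join(words[start : end + 1]))  # M. Hamann P. Harper, P. Martinez
```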
Co-authored-by: Ankur Goyal <ankur@impira.com>
parent e88e9ff0
@@ -107,6 +107,10 @@ This model was contributed by [liminghao1630](https://huggingface.co/liminghao16
[[autodoc]] LayoutLMForTokenClassification
## LayoutLMForQuestionAnswering
[[autodoc]] LayoutLMForQuestionAnswering
## TFLayoutLMModel
[[autodoc]] TFLayoutLMModel
@@ -122,3 +126,7 @@ This model was contributed by [liminghao1630](https://huggingface.co/liminghao16
## TFLayoutLMForTokenClassification
[[autodoc]] TFLayoutLMForTokenClassification
## TFLayoutLMForQuestionAnswering
[[autodoc]] TFLayoutLMForQuestionAnswering
@@ -1305,6 +1305,7 @@ else:
"LayoutLMForMaskedLM",
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMForQuestionAnswering",
"LayoutLMModel", "LayoutLMModel",
"LayoutLMPreTrainedModel", "LayoutLMPreTrainedModel",
] ]
...@@ -2337,6 +2338,7 @@ else: ...@@ -2337,6 +2338,7 @@ else:
"TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLayoutLMForMaskedLM", "TFLayoutLMForMaskedLM",
"TFLayoutLMForSequenceClassification", "TFLayoutLMForSequenceClassification",
"TFLayoutLMForQuestionAnswering",
"TFLayoutLMForTokenClassification", "TFLayoutLMForTokenClassification",
"TFLayoutLMMainLayer", "TFLayoutLMMainLayer",
"TFLayoutLMModel", "TFLayoutLMModel",
...@@ -3945,6 +3947,7 @@ if TYPE_CHECKING: ...@@ -3945,6 +3947,7 @@ if TYPE_CHECKING:
from .models.layoutlm import ( from .models.layoutlm import (
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST, LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMForMaskedLM, LayoutLMForMaskedLM,
LayoutLMForQuestionAnswering,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
@@ -4583,6 +4586,7 @@ if TYPE_CHECKING:
from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
TFLayoutLMForQuestionAnswering,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMMainLayer,
...
@@ -51,6 +51,7 @@ else:
"LayoutLMForMaskedLM",
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMForQuestionAnswering",
"LayoutLMModel", "LayoutLMModel",
"LayoutLMPreTrainedModel", "LayoutLMPreTrainedModel",
] ]
...@@ -66,6 +67,7 @@ else: ...@@ -66,6 +67,7 @@ else:
"TFLayoutLMForMaskedLM", "TFLayoutLMForMaskedLM",
"TFLayoutLMForSequenceClassification", "TFLayoutLMForSequenceClassification",
"TFLayoutLMForTokenClassification", "TFLayoutLMForTokenClassification",
"TFLayoutLMForQuestionAnswering",
"TFLayoutLMMainLayer", "TFLayoutLMMainLayer",
"TFLayoutLMModel", "TFLayoutLMModel",
"TFLayoutLMPreTrainedModel", "TFLayoutLMPreTrainedModel",
...@@ -93,6 +95,7 @@ if TYPE_CHECKING: ...@@ -93,6 +95,7 @@ if TYPE_CHECKING:
from .modeling_layoutlm import ( from .modeling_layoutlm import (
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST, LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMForMaskedLM, LayoutLMForMaskedLM,
LayoutLMForQuestionAnswering,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
@@ -107,6 +110,7 @@ if TYPE_CHECKING:
from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
TFLayoutLMForQuestionAnswering,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMMainLayer,
...
@@ -28,6 +28,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
@@ -40,7 +41,6 @@ from .configuration_layoutlm import LayoutLMConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LayoutLMConfig"
_TOKENIZER_FOR_DOC = "LayoutLMTokenizer"
_CHECKPOINT_FOR_DOC = "microsoft/layoutlm-base-uncased"
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -749,10 +749,10 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
Examples:
```python
>>> from transformers import AutoTokenizer, LayoutLMModel
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")
>>> words = ["Hello", "world"]
@@ -896,10 +896,10 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
Examples:
```python
>>> from transformers import AutoTokenizer, LayoutLMForMaskedLM
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = LayoutLMForMaskedLM.from_pretrained("microsoft/layoutlm-base-uncased")
>>> words = ["Hello", "[MASK]"]
@@ -1018,10 +1018,10 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
Examples:
```python
>>> from transformers import AutoTokenizer, LayoutLMForSequenceClassification
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased")
>>> words = ["Hello", "world"]
@@ -1153,10 +1153,10 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
Examples:
```python
>>> from transformers import AutoTokenizer, LayoutLMForTokenClassification
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")
>>> words = ["Hello", "world"]
@@ -1222,3 +1222,147 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
LayoutLM Model with a span classification head on top for extractive question-answering tasks such as
[DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the final hidden-states output to compute `span
start logits` and `span end logits`).
""",
LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
def __init__(self, config, has_visual_segment_embedding=True):
super().__init__(config)
self.num_labels = config.num_labels
self.layoutlm = LayoutLMModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings
@replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
bbox: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
Returns:
Example:
In the example below, we prepare a question + context pair for the LayoutLM model. It will give us a prediction
of what it thinks the answer is (the span of the answer within the texts parsed from the image).
```python
>>> from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
>>> from datasets import load_dataset
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
>>> model = LayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa")
>>> dataset = load_dataset("nielsr/funsd", split="train")
>>> example = dataset[0]
>>> question = "what's his name?"
>>> words = example["words"]
>>> boxes = example["bboxes"]
>>> encoding = tokenizer(
... question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="pt"
... )
>>> bbox = []
>>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
... if s == 1:
... bbox.append(boxes[w])
... elif i == tokenizer.sep_token_id:
... bbox.append([1000] * 4)
... else:
... bbox.append([0] * 4)
>>> encoding["bbox"] = torch.tensor([bbox])
>>> word_ids = encoding.word_ids(0)
>>> outputs = model(**encoding)
>>> loss = outputs.loss
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
>>> start, end = word_ids[start_scores.argmax(-1)], word_ids[end_scores.argmax(-1)]
>>> print(" ".join(words[start : end + 1]))
M. Hamann P. Harper, P. Martinez
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@@ -26,6 +26,7 @@ from ...modeling_tf_outputs import (
TFBaseModelOutputWithPastAndCrossAttentions,
TFBaseModelOutputWithPoolingAndCrossAttentions,
TFMaskedLMOutput,
TFQuestionAnsweringModelOutput,
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
@@ -33,6 +34,7 @@ from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
@@ -47,7 +49,6 @@ from .configuration_layoutlm import LayoutLMConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LayoutLMConfig"
_TOKENIZER_FOR_DOC = "LayoutLMTokenizer"
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/layoutlm-base-uncased",
@@ -934,10 +935,10 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
Examples:
```python
>>> from transformers import AutoTokenizer, TFLayoutLMModel
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = TFLayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")
>>> words = ["Hello", "world"]
@@ -1058,10 +1059,10 @@ class TFLayoutLMForMaskedLM(TFLayoutLMPreTrainedModel, TFMaskedLanguageModelingL
Examples:
```python
>>> from transformers import AutoTokenizer, TFLayoutLMForMaskedLM
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = TFLayoutLMForMaskedLM.from_pretrained("microsoft/layoutlm-base-uncased")
>>> words = ["Hello", "[MASK]"]
@@ -1181,10 +1182,10 @@ class TFLayoutLMForSequenceClassification(TFLayoutLMPreTrainedModel, TFSequenceC
Examples:
```python
>>> from transformers import AutoTokenizer, TFLayoutLMForSequenceClassification
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = TFLayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased")
>>> words = ["Hello", "world"]
@@ -1310,9 +1311,9 @@ class TFLayoutLMForTokenClassification(TFLayoutLMPreTrainedModel, TFTokenClassif
```python
>>> import tensorflow as tf
>>> from transformers import AutoTokenizer, TFLayoutLMForTokenClassification
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = TFLayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")
>>> words = ["Hello", "world"]
@@ -1377,3 +1378,150 @@ class TFLayoutLMForTokenClassification(TFLayoutLMPreTrainedModel, TFTokenClassif
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
"""
LayoutLM Model with a span classification head on top for extractive question-answering tasks such as
[DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the final hidden-states output to compute `span
start logits` and `span end logits`).
""",
LAYOUTLM_START_DOCSTRING,
)
class TFLayoutLMForQuestionAnswering(TFLayoutLMPreTrainedModel, TFQuestionAnsweringLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"mlm___cls",
r"nsp___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.layoutlm = TFLayoutLMMainLayer(config, add_pooling_layer=True, name="layoutlm")
self.qa_outputs = tf.keras.layers.Dense(
units=config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name="qa_outputs",
)
@unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids: Optional[TFModelInputType] = None,
bbox: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: Optional[bool] = False,
) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
r"""
start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
Returns:
Examples:
```python
>>> import tensorflow as tf
>>> from transformers import AutoTokenizer, TFLayoutLMForQuestionAnswering
>>> from datasets import load_dataset
>>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
>>> model = TFLayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa")
>>> dataset = load_dataset("nielsr/funsd", split="train")
>>> example = dataset[0]
>>> question = "what's his name?"
>>> words = example["words"]
>>> boxes = example["bboxes"]
>>> encoding = tokenizer(
... question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="tf"
... )
>>> bbox = []
>>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
... if s == 1:
... bbox.append(boxes[w])
... elif i == tokenizer.sep_token_id:
... bbox.append([1000] * 4)
... else:
... bbox.append([0] * 4)
>>> encoding["bbox"] = tf.convert_to_tensor([bbox])
>>> word_ids = encoding.word_ids(0)
>>> outputs = model(**encoding)
>>> loss = outputs.loss
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
>>> start, end = word_ids[tf.math.argmax(start_scores, -1)[0]], word_ids[tf.math.argmax(end_scores, -1)[0]]
>>> print(" ".join(words[start : end + 1]))
M. Hamann P. Harper, P. Martinez
```"""
outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = outputs[0]
logits = self.qa_outputs(inputs=sequence_output)
start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
start_logits = tf.squeeze(input=start_logits, axis=-1)
end_logits = tf.squeeze(input=end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFQuestionAnsweringModelOutput(
loss=loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFQuestionAnsweringModelOutput(
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
)
@@ -2469,6 +2469,13 @@ class LayoutLMForMaskedLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class LayoutLMForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LayoutLMForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
...
@@ -129,6 +129,13 @@ class TFLayoutLMForMaskedLM(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFLayoutLMForQuestionAnswering(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFLayoutLMForSequenceClassification(metaclass=DummyObject):
_backends = ["tf"]
...
@@ -147,6 +147,7 @@ _SPECIAL_SUPPORTED_MODELS = [
"GPT2DoubleHeadsModel",
"Speech2Text2Decoder",
"TrOCRDecoder",
"LayoutLMForQuestionAnswering",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# XLNetForQuestionAnswering,
]
@@ -690,6 +691,7 @@ class HFTracer(Tracer):
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class_name in [
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
"LayoutLMForQuestionAnswering",
"XLNetForQuestionAnswering", "XLNetForQuestionAnswering",
]: ]:
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
......
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
from transformers import LayoutLMConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
@@ -27,7 +28,11 @@ if is_torch_available():
import torch
from transformers import (
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
LayoutLMForMaskedLM,
LayoutLMForQuestionAnswering,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
@@ -181,6 +186,23 @@ class LayoutLMModelTester:
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_for_question_answering(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LayoutLMForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
bbox=bbox,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -211,6 +233,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMForQuestionAnswering,
)
if is_torch_available()
else None
@@ -246,6 +269,34 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
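# LayoutLMForQuestionAnswering is deliberately kept out of the auto question-answering mapping,
# so its start/end position labels are added here by class name instead of via the mapping.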
inputs_dict = copy.deepcopy(inputs_dict)
if return_labels:
if model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
elif model_class in [
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
]:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)
elif model_class.__name__ == "LayoutLMForQuestionAnswering":
inputs_dict["start_positions"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
inputs_dict["end_positions"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
return inputs_dict
def prepare_layoutlm_batch_inputs():
# Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on:
@@ -337,3 +388,18 @@ class LayoutLMModelIntegrationTest(unittest.TestCase):
logits = outputs.logits
expected_shape = torch.Size((2, 25, 13))
self.assertEqual(logits.shape, expected_shape)
@slow
def test_forward_pass_question_answering(self):
# initialize model with randomly initialized question answering head
model = LayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased").to(torch_device)
input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs()
# forward pass
outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
# test the shape of the logits
expected_shape = torch.Size((2, 25))
self.assertEqual(outputs.start_logits.shape, expected_shape)
self.assertEqual(outputs.end_logits.shape, expected_shape)
@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import numpy as np
from transformers import LayoutLMConfig, is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
@@ -27,9 +29,15 @@ from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_at
if is_tf_available():
import tensorflow as tf
from transformers import (
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
)
from transformers.models.layoutlm.modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
TFLayoutLMForQuestionAnswering,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMModel,
@@ -174,6 +182,15 @@ class TFLayoutLMModelTester:
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_for_question_answering(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFLayoutLMForQuestionAnswering(config=config)
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -199,7 +216,13 @@ class TFLayoutLMModelTester:
class TFLayoutLMModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
TFLayoutLMModel,
TFLayoutLMForMaskedLM,
TFLayoutLMForTokenClassification,
TFLayoutLMForSequenceClassification,
TFLayoutLMForQuestionAnswering,
)
if is_tf_available()
else ()
)
@@ -230,12 +253,34 @@ class TFLayoutLMModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFLayoutLMModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict)
if return_labels:
if model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in [
*get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
]:
inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
)
elif model_class.__name__ == "TFLayoutLMForQuestionAnswering":
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
return inputs_dict
def prepare_layoutlm_batch_inputs():
# Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on:
@@ -316,3 +361,18 @@ class TFLayoutLMModelIntegrationTest(unittest.TestCase):
logits = outputs.logits
expected_shape = tf.convert_to_tensor((2, 25, 13))
self.assertEqual(logits.shape, expected_shape)
@slow
def test_forward_pass_question_answering(self):
# initialize model with randomly initialized question answering head
model = TFLayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased")
input_ids, attention_mask, bbox, token_type_ids, labels = prepare_layoutlm_batch_inputs()
# forward pass
outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
# test the shape of the logits
expected_shape = tf.convert_to_tensor((2, 25))
self.assertEqual(outputs.start_logits.shape, expected_shape)
self.assertEqual(outputs.end_logits.shape, expected_shape)
@@ -161,6 +161,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"FlavaImageModel",
"FlavaMultimodalModel",
"GPT2DoubleHeadsModel",
"LayoutLMForQuestionAnswering",
"LukeForMaskedLM", "LukeForMaskedLM",
"LukeForEntityClassification", "LukeForEntityClassification",
"LukeForEntityPairClassification", "LukeForEntityPairClassification",
...@@ -178,6 +179,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ ...@@ -178,6 +179,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"RealmReader", "RealmReader",
"TFDPRReader", "TFDPRReader",
"TFGPT2DoubleHeadsModel", "TFGPT2DoubleHeadsModel",
"TFLayoutLMForQuestionAnswering",
"TFOpenAIGPTDoubleHeadsModel", "TFOpenAIGPTDoubleHeadsModel",
"TFRagModel", "TFRagModel",
"TFRagSequenceForGeneration", "TFRagSequenceForGeneration",
......
@@ -38,6 +38,8 @@ src/transformers/models/glpn/modeling_glpn.py
src/transformers/models/gpt2/modeling_gpt2.py
src/transformers/models/gptj/modeling_gptj.py
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/layoutlm/modeling_layoutlm.py
src/transformers/models/layoutlm/modeling_tf_layoutlm.py
src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
src/transformers/models/layoutlmv3/modeling_tf_layoutlmv3.py
...