Commit f9b1a89a authored by HHL

v

parent 60e27226
# coding=utf-8
from transformers.utils import logging
from ..layoutlmv2 import LayoutLMv2Config
logger = logging.get_logger(__name__)
LAYOUTXLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"layoutxlm-base": "https://huggingface.co/layoutxlm-base/resolve/main/config.json",
"layoutxlm-large": "https://huggingface.co/layoutxlm-large/resolve/main/config.json",
}
class LayoutXLMConfig(LayoutLMv2Config):
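    """
    Configuration class for LayoutXLM. It inherits every option from LayoutLMv2Config and
    adds the pre-training specific settings: num_tokens (the output size of the MVLM head)
    and the mvlm_alpha, tia_alpha and tim_alpha weights used to combine the pre-training losses.
    """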
model_type = "layoutxlm"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
max_2d_position_embeddings=1024,
max_rel_pos=128,
rel_pos_bins=32,
fast_qkv=True,
max_rel_2d_pos=256,
rel_2d_pos_bins=64,
convert_sync_batchnorm=True,
image_feature_pool_shape=[7, 7, 256],
coordinate_size=128,
shape_size=128,
has_relative_attention_bias=True,
has_spatial_attention_bias=True,
has_visual_segment_embedding=False,
num_tokens=2,
mvlm_alpha=4,
tia_alpha=3,
tim_alpha=3,
**kwargs
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
max_position_embeddings=max_position_embeddings,
type_vocab_size=type_vocab_size,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
pad_token_id=pad_token_id,
gradient_checkpointing=gradient_checkpointing,
**kwargs,
)
self.max_2d_position_embeddings = max_2d_position_embeddings
self.max_rel_pos = max_rel_pos
self.rel_pos_bins = rel_pos_bins
self.fast_qkv = fast_qkv
self.max_rel_2d_pos = max_rel_2d_pos
self.rel_2d_pos_bins = rel_2d_pos_bins
self.convert_sync_batchnorm = convert_sync_batchnorm
self.image_feature_pool_shape = image_feature_pool_shape
self.coordinate_size = coordinate_size
self.shape_size = shape_size
self.has_relative_attention_bias = has_relative_attention_bias
self.has_spatial_attention_bias = has_spatial_attention_bias
self.has_visual_segment_embedding = has_visual_segment_embedding
self.num_tokens = num_tokens
self.mvlm_alpha = mvlm_alpha
self.tia_alpha = tia_alpha
self.tim_alpha = tim_alpha
# coding=utf-8
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.utils import logging
from ..layoutlmv2 import LayoutLMv2ForRelationExtraction, LayoutLMv2ForTokenClassification, LayoutLMv2Model
from .configuration_layoutxlm import LayoutXLMConfig
from transformers.modeling_outputs import TokenClassifierOutput
logger = logging.get_logger(__name__)
LAYOUTXLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"layoutxlm-base",
"layoutxlm-large",
]
class LayoutXLMForPretrain(LayoutLMv2ForTokenClassification):
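    """
    Pre-training heads on top of the LayoutXLM encoder. Three objectives are combined,
    presumably following the LayoutLMv2 pre-training recipe: masked visual-language modeling
    (mvlm_cls, a token-level classifier over num_tokens classes), text-image alignment
    (tia_cls, a binary per-token classifier) and text-image matching (tim_cls, a binary
    classifier on the first token). The three losses are weighted by mvlm_alpha, tia_alpha
    and tim_alpha, which are normalized to sum to 1 in __init__.
    """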
config_class = LayoutXLMConfig
def __init__(self, config):
super().__init__(config)
self.num_tokens = config.num_tokens
self.mvlm_cls = nn.Linear(config.hidden_size, config.num_tokens)
self.tia_cls = nn.Linear(config.hidden_size, 2)
self.tim_cls = nn.Linear(config.hidden_size, 2)
total_alpha = config.mvlm_alpha + config.tia_alpha + config.tim_alpha
self.mvlm_alpha = config.mvlm_alpha / total_alpha
self.tia_alpha = config.tia_alpha / total_alpha
self.tim_alpha = config.tim_alpha / total_alpha
def forward(
self,
input_ids=None,
bbox=None,
image=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
mvlm_labels=None,
tia_labels=None,
tim_labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# with torch.no_grad():
outputs = self.layoutlmv2(
input_ids=input_ids,
bbox=bbox,
image=image,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
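        # The encoder output concatenates the text token states with the visual (image patch)
        # token states appended after them; split them back apart here.
        # Note: this assumes input_ids is provided rather than inputs_embeds.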
seq_length = input_ids.size(1)
sequence_output, image_output = outputs[0][:, :seq_length], outputs[0][:, seq_length:]
sequence_output = self.dropout(sequence_output)
loss = None
mvlm_logits = None
tia_logits = None
tim_logits = None
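        # Masked visual-language modeling (MVLM): token-level prediction over the text part;
        # padding positions and -100 labels are excluded from the loss.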
if mvlm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
mvlm_logits = self.mvlm_cls(sequence_output)
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = mvlm_logits.view(-1, self.num_tokens)[active_loss]
active_labels = mvlm_labels.view(-1)[active_loss]
mvlm_loss = loss_fct(active_logits, active_labels)
else:
mvlm_loss = loss_fct(mvlm_logits.view(-1, self.num_tokens), mvlm_labels.view(-1))
mvlm_loss = mvlm_loss.sum() / ((mvlm_labels != -100).sum() + 1e-5)
if loss is not None:
loss += self.mvlm_alpha * mvlm_loss
else:
loss = self.mvlm_alpha * mvlm_loss
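        # Text-image alignment (TIA): per-token binary classification, as in LayoutLMv2
        # (predict whether the image region of each token line was covered).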
if tia_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
tia_logits = self.tia_cls(sequence_output)
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = tia_logits.view(-1, 2)[active_loss]
active_labels = tia_labels.view(-1)[active_loss]
tia_loss = loss_fct(active_logits, active_labels)
else:
tia_loss = loss_fct(tia_logits.view(-1, 2), tia_labels.view(-1))
tia_loss = tia_loss.sum() / ((tia_labels != -100).sum() + 1e-5)
if loss is not None:
loss += self.tia_alpha * tia_loss
else:
loss = self.tia_alpha * tia_loss
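        # Text-image matching (TIM): document-level binary classification on the first
        # ([CLS]) token, predicting whether the image and the text belong together.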
if tim_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
tim_logits = self.tim_cls(sequence_output[:, 0])
tim_loss = loss_fct(tim_logits.view(-1, 2), tim_labels.view(-1))
tim_loss = tim_loss.sum() / ((tim_labels != -100).sum() + 1e-5)
if loss is not None:
loss += self.tim_alpha * tim_loss
else:
loss = self.tim_alpha * tim_loss
if not return_dict:
            output = tuple(
                logits.argmax(-1) if logits is not None else None
                for logits in (mvlm_logits, tia_logits, tim_logits)
            ) + outputs[2:]
return ((loss,) + output) if loss is not None else output
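        # When returning a dict, the token-level hidden states (after dropout) are exposed
        # in the logits field; the individual head logits are not returned.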
return TokenClassifierOutput(
loss=loss,
logits=sequence_output,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class LayoutXLMModel(LayoutLMv2Model):
config_class = LayoutXLMConfig
class LayoutXLMForTokenClassification(LayoutLMv2ForTokenClassification):
config_class = LayoutXLMConfig
class LayoutXLMForRelationExtraction(LayoutLMv2ForRelationExtraction):
config_class = LayoutXLMConfig
# coding=utf-8
from transformers import XLMRobertaTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
SPIECE_UNDERLINE = "▁"
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"layoutxlm-base": "https://huggingface.co/layoutxlm-base/resolve/main/sentencepiece.bpe.model",
"layoutxlm-large": "https://huggingface.co/layoutxlm-large/resolve/main/sentencepiece.bpe.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"layoutxlm-base": 512,
"layoutxlm-large": 512,
}
class LayoutXLMTokenizer(XLMRobertaTokenizer):
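    """
    LayoutXLM reuses the XLM-RoBERTa sentencepiece tokenizer unchanged; only the default
    model_max_length (512) and the pretrained file maps differ.
    """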
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(self, model_max_length=512, **kwargs):
super().__init__(model_max_length=model_max_length, **kwargs)
# coding=utf-8
from transformers import XLMRobertaTokenizerFast
from transformers.file_utils import is_sentencepiece_available
from transformers.utils import logging
if is_sentencepiece_available():
from .tokenization_layoutxlm import LayoutXLMTokenizer
else:
LayoutXLMTokenizer = None
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"layoutxlm-base": "https://huggingface.co/layoutxlm-base/resolve/main/sentencepiece.bpe.model",
"layoutxlm-large": "https://huggingface.co/layoutxlm-large/resolve/main/sentencepiece.bpe.model",
},
"tokenizer_file": {
"layoutxlm-base": "https://huggingface.co/layoutxlm-base/resolve/main/tokenizer.json",
"layoutxlm-large": "https://huggingface.co/layoutxlm-large/resolve/main/tokenizer.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"layoutxlm-base": 512,
"layoutxlm-large": 512,
}
class LayoutXLMTokenizerFast(XLMRobertaTokenizerFast):
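    """
    Fast (tokenizers-backed) counterpart of LayoutXLMTokenizer, likewise a thin wrapper
    around XLMRobertaTokenizerFast with a default model_max_length of 512.
    """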
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
slow_tokenizer_class = LayoutXLMTokenizer
def __init__(self, model_max_length=512, **kwargs):
super().__init__(model_max_length=model_max_length, **kwargs)
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
},
)
import copy
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
class BiaffineAttention(torch.nn.Module):
"""Implements a biaffine attention operator for binary relation classification.
PyTorch implementation of the biaffine attention operator from "End-to-end neural relation
extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used
as a classifier for binary relation classification.
Args:
in_features (int): The size of the feature dimension of the inputs.
out_features (int): The size of the feature dimension of the output.
Shape:
- x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of
          additional dimensions.
- x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of
additional dimensions.
- Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number
of additional dimensions.
Examples:
>>> batch_size, in_features, out_features = 32, 100, 4
>>> biaffine_attention = BiaffineAttention(in_features, out_features)
>>> x_1 = torch.randn(batch_size, in_features)
>>> x_2 = torch.randn(batch_size, in_features)
>>> output = biaffine_attention(x_1, x_2)
>>> print(output.size())
torch.Size([32, 4])
"""
def __init__(self, in_features, out_features):
super(BiaffineAttention, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)
self.reset_parameters()
def forward(self, x_1, x_2):
return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))
def reset_parameters(self):
self.bilinear.reset_parameters()
self.linear.reset_parameters()
class REDecoder(nn.Module):
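    """
    Relation-extraction decoder. For every document it builds candidate (head, tail) entity
    pairs (heads with entity label 1, tails with entity label 2), concatenates each entity's
    start-token hidden state with an entity-type embedding, projects heads and tails through
    separate FFNNs and scores each pair with a biaffine classifier (related / not related).

    Illustrative sketch (not from the original code; a minimal config stand-in with only the
    attributes this module reads is assumed):

        >>> import torch
        >>> from types import SimpleNamespace
        >>> config = SimpleNamespace(hidden_size=768, hidden_dropout_prob=0.1)
        >>> decoder = REDecoder(config)
        >>> hidden_states = torch.randn(1, 512, config.hidden_size)
        >>> entities = [{"start": [0, 5, 10], "end": [4, 9, 14], "label": [0, 1, 2]}]
        >>> relations = [{"head": [1], "tail": [2]}]
        >>> loss, pred_relations = decoder(hidden_states, entities, relations)
    """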
def __init__(self, config):
super().__init__()
self.entity_emb = nn.Embedding(3, config.hidden_size, scale_grad_by_freq=True)
projection = nn.Sequential(
nn.Linear(config.hidden_size * 2, config.hidden_size),
nn.ReLU(),
nn.Dropout(config.hidden_dropout_prob),
nn.Linear(config.hidden_size, config.hidden_size // 2),
nn.ReLU(),
nn.Dropout(config.hidden_dropout_prob),
)
self.ffnn_head = copy.deepcopy(projection)
self.ffnn_tail = copy.deepcopy(projection)
self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2)
self.loss_fct = CrossEntropyLoss()
def build_relation(self, relations, entities):
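        """
        Enumerate candidate (head, tail) index pairs for each document: all pairs where the
        head entity has label 1 and the tail entity has label 2. Gold pairs from `relations`
        are kept as positives (listed first) and the remaining candidates become negatives.
        Documents with too few entities are given dummy entities, and an empty candidate set
        falls back to the single pair (0, 1), so the batch never ends up empty.
        """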
batch_size = len(relations)
new_relations = []
for b in range(batch_size):
if len(entities[b]["start"]) <= 2:
entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
all_possible_relations = set(
[
(i, j)
for i in range(len(entities[b]["label"]))
for j in range(i + 1, len(entities[b]["label"]))
if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
]
)
if len(all_possible_relations) == 0:
all_possible_relations = set([(0, 1)])
positive_relations = set(list(zip(relations[b]["head"], relations[b]["tail"])))
negative_relations = all_possible_relations - positive_relations
positive_relations = set([i for i in positive_relations if i in all_possible_relations])
reordered_relations = list(positive_relations) + list(negative_relations)
relation_per_doc = {"head": [], "tail": [], "label": []}
relation_per_doc["head"] = [i[0] for i in reordered_relations]
relation_per_doc["tail"] = [i[1] for i in reordered_relations]
relation_per_doc["label"] = [1] * len(positive_relations) + [0] * (
len(reordered_relations) - len(positive_relations)
)
assert len(relation_per_doc["head"]) != 0
new_relations.append(relation_per_doc)
return new_relations, entities
def get_predicted_relations(self, logits, relations, entities):
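        """
        Keep only the pairs whose predicted class is 1 (related) and return them as
        dictionaries with the head/tail entity spans, entity types and relation type.
        """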
pred_relations = []
for i, pred_label in enumerate(logits.argmax(-1)):
if pred_label != 1:
continue
rel = {}
rel["head_id"] = relations["head"][i]
rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]])
rel["head_type"] = entities["label"][rel["head_id"]]
rel["tail_id"] = relations["tail"][i]
rel["tail"] = (entities["start"][rel["tail_id"]], entities["end"][rel["tail_id"]])
rel["tail_type"] = entities["label"][rel["tail_id"]]
rel["type"] = 1
pred_relations.append(rel)
return pred_relations
def forward(self, hidden_states, entities, relations):
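        """
        For each document: look up the start-token hidden state and the type embedding of
        every candidate head and tail entity, project them through the head/tail FFNNs,
        score every pair with the biaffine classifier and accumulate the cross-entropy loss
        against the gold relation labels. Returns the summed loss and, per document, the
        list of predicted relations.
        """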
batch_size, max_n_words, context_dim = hidden_states.size()
device = hidden_states.device
relations, entities = self.build_relation(relations, entities)
loss = 0
all_pred_relations = []
for b in range(batch_size):
head_entities = torch.tensor(relations[b]["head"], device=device)
tail_entities = torch.tensor(relations[b]["tail"], device=device)
relation_labels = torch.tensor(relations[b]["label"], device=device)
entities_start_index = torch.tensor(entities[b]["start"], device=device)
entities_labels = torch.tensor(entities[b]["label"], device=device)
head_index = entities_start_index[head_entities]
head_label = entities_labels[head_entities]
head_label_repr = self.entity_emb(head_label)
tail_index = entities_start_index[tail_entities]
tail_label = entities_labels[tail_entities]
tail_label_repr = self.entity_emb(tail_label)
head_repr = torch.cat(
(hidden_states[b][head_index], head_label_repr),
dim=-1,
)
tail_repr = torch.cat(
(hidden_states[b][tail_index], tail_label_repr),
dim=-1,
)
heads = self.ffnn_head(head_repr)
tails = self.ffnn_tail(tail_repr)
logits = self.rel_classifier(heads, tails)
loss += self.loss_fct(logits, relation_labels)
pred_relations = self.get_predicted_relations(logits, relations[b], entities[b])
all_pred_relations.append(pred_relations)
return loss, all_pred_relations
from .huaweikie_trainer import HuaweiKIETrainer
from .funsd_trainer import FunsdTrainer
from .xfun_trainer import XfunReTrainer, XfunSerTrainer
from .pre_trainer import PreTrainer