"...googletest-release-1.12.1/googletest/samples/sample1.h" did not exist on "0a21fff9619a39d3b8105c065edb061f1f3b305d"
modeling_layoutxlm.py 5.12 KB
Newer Older
HHL's avatar
v  
HHL committed
# coding=utf-8
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils import logging

from ..layoutlmv2 import LayoutLMv2ForRelationExtraction, LayoutLMv2ForTokenClassification, LayoutLMv2Model
from .configuration_layoutxlm import LayoutXLMConfig

logger = logging.get_logger(__name__)

LAYOUTXLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "layoutxlm-base",
    "layoutxlm-large",
]


class LayoutXLMForPretrain(LayoutLMv2ForTokenClassification):
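    """Pretraining head for LayoutXLM combining the three LayoutLMv2-style
    objectives: masked visual-language modeling (MVLM), text-image alignment
    (TIA) and text-image matching (TIM), weighted by the normalized
    ``mvlm_alpha`` / ``tia_alpha`` / ``tim_alpha`` values from the config."""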
    config_class = LayoutXLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_tokens = config.num_tokens
        self.mvlm_cls = nn.Linear(config.hidden_size, config.num_tokens)
        self.tia_cls = nn.Linear(config.hidden_size, 2)
        self.tim_cls = nn.Linear(config.hidden_size, 2)

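        # Normalize the per-task loss weights so the three coefficients sum to 1.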
        total_alpha = config.mvlm_alpha + config.tia_alpha + config.tim_alpha
        self.mvlm_alpha = config.mvlm_alpha / total_alpha
        self.tia_alpha = config.tia_alpha / total_alpha
        self.tim_alpha = config.tim_alpha / total_alpha

    def forward(
        self,
        input_ids=None,
        bbox=None,
        image=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mvlm_labels=None,
        tia_labels=None,
        tim_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode text, layout and image with the LayoutLMv2 backbone
        # (wrap this call in torch.no_grad() to keep the backbone frozen).
        outputs = self.layoutlmv2(
            input_ids=input_ids,
            bbox=bbox,
            image=image,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
            
        seq_length = input_ids.size(1)
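        # outputs[0] concatenates text-token states with visual-patch states;
        # keep only the text part for the token-level heads below.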
        sequence_output, image_output = outputs[0][:, :seq_length], outputs[0][:, seq_length:]
        sequence_output = self.dropout(sequence_output)

        loss = None
        mvlm_logits = None
        tia_logits = None
        tim_logits = None

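        # Masked visual-language modeling (MVLM): predict the original token id at
        # masked positions; positions labelled -100 are excluded from the loss.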
        if mvlm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            mvlm_logits = self.mvlm_cls(sequence_output)
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = mvlm_logits.view(-1, self.num_tokens)[active_loss]
                active_labels = mvlm_labels.view(-1)[active_loss]
                mvlm_loss = loss_fct(active_logits, active_labels)
            else:
                mvlm_loss = loss_fct(mvlm_logits.view(-1, self.num_tokens), mvlm_labels.view(-1))
            mvlm_loss = mvlm_loss.sum() / ((mvlm_labels != -100).sum() + 1e-5)
            if loss is not None:
                loss += self.mvlm_alpha * mvlm_loss
            else:
                loss = self.mvlm_alpha * mvlm_loss

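        # Text-image alignment (TIA): per-token binary classification (roughly,
        # whether the token's image region was covered), again masked with -100 labels.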
        if tia_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            tia_logits = self.tia_cls(sequence_output)
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = tia_logits.view(-1, 2)[active_loss]
                active_labels = tia_labels.view(-1)[active_loss]
                tia_loss = loss_fct(active_logits, active_labels)
            else:
                tia_loss = loss_fct(tia_logits.view(-1, 2), tia_labels.view(-1))
            tia_loss = tia_loss.sum() / ((tia_labels != -100).sum() + 1e-5)
            if loss is not None:
                loss += self.tia_alpha * tia_loss
            else:
                loss = self.tia_alpha * tia_loss

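        # Text-image matching (TIM): sequence-level binary classification from the
        # first ([CLS]) token representation.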
        if tim_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            tim_logits = self.tim_cls(sequence_output[:, 0])
            tim_loss = loss_fct(tim_logits.view(-1, 2), tim_labels.view(-1))
            tim_loss = tim_loss.sum() / ((tim_labels != -100).sum() + 1e-5)
            if loss is not None:
                loss += self.tim_alpha * tim_loss
            else:
                loss = self.tim_alpha * tim_loss

        if not return_dict:
            # Heads that were not run (no labels provided) keep their logits as None.
            preds = tuple(x.argmax(-1) if x is not None else None
                          for x in (mvlm_logits, tia_logits, tim_logits))
            output = preds + outputs[2:]
            return ((loss,) + output) if loss is not None else output

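        # Note: the ``logits`` field below carries the (dropout-regularized) text-token
        # hidden states; per-task predictions are only exposed via the tuple path above.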
        return TokenClassifierOutput(
            loss=loss,
            logits=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class LayoutXLMModel(LayoutLMv2Model):
    config_class = LayoutXLMConfig


class LayoutXLMForTokenClassification(LayoutLMv2ForTokenClassification):
    config_class = LayoutXLMConfig

    
class LayoutXLMForRelationExtraction(LayoutLMv2ForRelationExtraction):
    config_class = LayoutXLMConfig
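

# Minimal usage sketch (illustrative only): assumes a checkpoint/config directory
# named "layoutxlm-base" and that the config provides num_tokens plus the three
# *_alpha loss weights used by LayoutXLMForPretrain; the inputs are the usual
# LayoutLMv2-style batch (token ids, word bounding boxes, page image) with
# pretraining labels set to -100 wherever they should be ignored.
#
#     config = LayoutXLMConfig.from_pretrained("layoutxlm-base")
#     model = LayoutXLMForPretrain(config)
#     outputs = model(
#         input_ids=input_ids,
#         bbox=bbox,
#         image=image,
#         attention_mask=attention_mask,
#         mvlm_labels=mvlm_labels,
#         tia_labels=tia_labels,
#         tim_labels=tim_labels,
#     )
#     loss = outputs.loss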