hehl2 / GraphDoc_pytorch / Commits

Commit f9b1a89a, authored Dec 27, 2023 by HHL
Commit message: v
Parent: 60e27226

Showing 20 changed files with 493 additions and 0 deletions (+493, -0)
Changed files:

layoutlmft/models/layoutxlm/__pycache__/modeling_layoutxlm.cpython-38.pyc  +0  -0
layoutlmft/models/layoutxlm/__pycache__/tokenization_layoutxlm.cpython-37.pyc  +0  -0
layoutlmft/models/layoutxlm/__pycache__/tokenization_layoutxlm.cpython-38.pyc  +0  -0
layoutlmft/models/layoutxlm/__pycache__/tokenization_layoutxlm_fast.cpython-37.pyc  +0  -0
layoutlmft/models/layoutxlm/__pycache__/tokenization_layoutxlm_fast.cpython-38.pyc  +0  -0
layoutlmft/models/layoutxlm/configuration_layoutxlm.py  +86  -0
layoutlmft/models/layoutxlm/modeling_layoutxlm.py  +139  -0
layoutlmft/models/layoutxlm/tokenization_layoutxlm.py  +33  -0
layoutlmft/models/layoutxlm/tokenization_layoutxlm_fast.py  +43  -0
layoutlmft/models/model_args.py  +34  -0
layoutlmft/modules/__init__.py  +0  -0
layoutlmft/modules/__pycache__/__init__.cpython-37.pyc  +0  -0
layoutlmft/modules/__pycache__/__init__.cpython-38.pyc  +0  -0
layoutlmft/modules/decoders/__init__.py  +0  -0
layoutlmft/modules/decoders/__pycache__/__init__.cpython-37.pyc  +0  -0
layoutlmft/modules/decoders/__pycache__/__init__.cpython-38.pyc  +0  -0
layoutlmft/modules/decoders/__pycache__/re.cpython-37.pyc  +0  -0
layoutlmft/modules/decoders/__pycache__/re.cpython-38.pyc  +0  -0
layoutlmft/modules/decoders/re.py  +154  -0
layoutlmft/trainers/__init__.py  +4  -0
layoutlmft/models/layoutxlm/__pycache__/modeling_layoutxlm.cpython-38.pyc (new file, mode 100644): binary file added
layoutlmft/models/layoutxlm/__pycache__/tokenization_layoutxlm.cpython-37.pyc (new file, mode 100644): binary file added
layoutlmft/models/layoutxlm/__pycache__/tokenization_layoutxlm.cpython-38.pyc (new file, mode 100644): binary file added
layoutlmft/models/layoutxlm/__pycache__/tokenization_layoutxlm_fast.cpython-37.pyc (new file, mode 100644): binary file added
layoutlmft/models/layoutxlm/__pycache__/tokenization_layoutxlm_fast.cpython-38.pyc (new file, mode 100644): binary file added
layoutlmft/models/layoutxlm/configuration_layoutxlm.py (new file, mode 100644):

# coding=utf-8
from transformers.utils import logging

from ..layoutlmv2 import LayoutLMv2Config

logger = logging.get_logger(__name__)

LAYOUTXLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "layoutxlm-base": "https://huggingface.co/layoutxlm-base/resolve/main/config.json",
    "layoutxlm-large": "https://huggingface.co/layoutxlm-large/resolve/main/config.json",
}


class LayoutXLMConfig(LayoutLMv2Config):
    model_type = "layoutxlm"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False,
        max_2d_position_embeddings=1024,
        max_rel_pos=128,
        rel_pos_bins=32,
        fast_qkv=True,
        max_rel_2d_pos=256,
        rel_2d_pos_bins=64,
        convert_sync_batchnorm=True,
        image_feature_pool_shape=[7, 7, 256],
        coordinate_size=128,
        shape_size=128,
        has_relative_attention_bias=True,
        has_spatial_attention_bias=True,
        has_visual_segment_embedding=False,
        num_tokens=2,
        mvlm_alpha=4,
        tia_alpha=3,
        tim_alpha=3,
        **kwargs
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=type_vocab_size,
            initializer_range=initializer_range,
            layer_norm_eps=layer_norm_eps,
            pad_token_id=pad_token_id,
            gradient_checkpointing=gradient_checkpointing,
            **kwargs,
        )
        self.max_2d_position_embeddings = max_2d_position_embeddings
        self.max_rel_pos = max_rel_pos
        self.rel_pos_bins = rel_pos_bins
        self.fast_qkv = fast_qkv
        self.max_rel_2d_pos = max_rel_2d_pos
        self.rel_2d_pos_bins = rel_2d_pos_bins
        self.convert_sync_batchnorm = convert_sync_batchnorm
        self.image_feature_pool_shape = image_feature_pool_shape
        self.coordinate_size = coordinate_size
        self.shape_size = shape_size
        self.has_relative_attention_bias = has_relative_attention_bias
        self.has_spatial_attention_bias = has_spatial_attention_bias
        self.has_visual_segment_embedding = has_visual_segment_embedding
        self.num_tokens = num_tokens
        self.mvlm_alpha = mvlm_alpha
        self.tia_alpha = tia_alpha
        self.tim_alpha = tim_alpha
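
A minimal usage sketch, not part of the commit, showing how the pretraining-specific fields are set. It assumes the repository's layoutlmv2 dependencies are installed, and the num_tokens value below is only an illustrative placeholder (roughly the XLM-R vocabulary size); in practice it should match the vocabulary used by the MVLM head.

from layoutlmft.models.layoutxlm.configuration_layoutxlm import LayoutXLMConfig

# num_tokens sizes the MVLM classification head; 250002 is an assumed example value.
config = LayoutXLMConfig(num_tokens=250002, mvlm_alpha=4, tia_alpha=3, tim_alpha=3)
print(config.model_type)                 # "layoutxlm"
print(config.image_feature_pool_shape)   # [7, 7, 256]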
layoutlmft/models/layoutxlm/modeling_layoutxlm.py (new file, mode 100644):

# coding=utf-8
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.utils import logging

from ..layoutlmv2 import LayoutLMv2ForRelationExtraction, LayoutLMv2ForTokenClassification, LayoutLMv2Model
from .configuration_layoutxlm import LayoutXLMConfig
from transformers.modeling_outputs import TokenClassifierOutput

logger = logging.get_logger(__name__)

LAYOUTXLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "layoutxlm-base",
    "layoutxlm-large",
]


class LayoutXLMForPretrain(LayoutLMv2ForTokenClassification):
    config_class = LayoutXLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_tokens = config.num_tokens
        # Pretraining heads: masked visual-language modeling (MVLM),
        # text-image alignment (TIA) and text-image matching (TIM).
        self.mvlm_cls = nn.Linear(config.hidden_size, config.num_tokens)
        self.tia_cls = nn.Linear(config.hidden_size, 2)
        self.tim_cls = nn.Linear(config.hidden_size, 2)
        # Normalize the per-task weights so they sum to 1.
        total_alpha = config.mvlm_alpha + config.tia_alpha + config.tim_alpha
        self.mvlm_alpha = config.mvlm_alpha / total_alpha
        self.tia_alpha = config.tia_alpha / total_alpha
        self.tim_alpha = config.tim_alpha / total_alpha

    def forward(
        self,
        input_ids=None,
        bbox=None,
        image=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mvlm_labels=None,
        tia_labels=None,
        tim_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # with torch.no_grad():
        outputs = self.layoutlmv2(
            input_ids=input_ids,
            bbox=bbox,
            image=image,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Split the encoder output back into text-token and visual-token parts.
        seq_length = input_ids.size(1)
        sequence_output, image_output = outputs[0][:, :seq_length], outputs[0][:, seq_length:]
        sequence_output = self.dropout(sequence_output)

        loss = None
        mvlm_logits = None
        tia_logits = None
        tim_logits = None
        if mvlm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            mvlm_logits = self.mvlm_cls(sequence_output)
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = mvlm_logits.view(-1, self.num_tokens)[active_loss]
                active_labels = mvlm_labels.view(-1)[active_loss]
                mvlm_loss = loss_fct(active_logits, active_labels)
            else:
                mvlm_loss = loss_fct(mvlm_logits.view(-1, self.num_tokens), mvlm_labels.view(-1))
            mvlm_loss = mvlm_loss.sum() / ((mvlm_labels != -100).sum() + 1e-5)
            if loss is not None:
                loss += self.mvlm_alpha * mvlm_loss
            else:
                loss = self.mvlm_alpha * mvlm_loss
        if tia_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            tia_logits = self.tia_cls(sequence_output)
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = tia_logits.view(-1, 2)[active_loss]
                active_labels = tia_labels.view(-1)[active_loss]
                tia_loss = loss_fct(active_logits, active_labels)
            else:
                tia_loss = loss_fct(tia_logits.view(-1, 2), tia_labels.view(-1))
            tia_loss = tia_loss.sum() / ((tia_labels != -100).sum() + 1e-5)
            if loss is not None:
                loss += self.tia_alpha * tia_loss
            else:
                loss = self.tia_alpha * tia_loss
        if tim_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            # TIM is a sequence-level prediction made from the first token's representation.
            tim_logits = self.tim_cls(sequence_output[:, 0])
            tim_loss = loss_fct(tim_logits.view(-1, 2), tim_labels.view(-1))
            tim_loss = tim_loss.sum() / ((tim_labels != -100).sum() + 1e-5)
            if loss is not None:
                loss += self.tim_alpha * tim_loss
            else:
                loss = self.tim_alpha * tim_loss

        if not return_dict:
            # Note: this tuple assumes all three label sets were provided, since the
            # corresponding logits are only computed above when labels are given.
            output = (mvlm_logits.argmax(-1), tia_logits.argmax(-1), tim_logits.argmax(-1)) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class LayoutXLMModel(LayoutLMv2Model):
    config_class = LayoutXLMConfig


class LayoutXLMForTokenClassification(LayoutLMv2ForTokenClassification):
    config_class = LayoutXLMConfig


class LayoutXLMForRelationExtraction(LayoutLMv2ForRelationExtraction):
    config_class = LayoutXLMConfig
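
As a small illustration of how the three pretraining losses are combined (a sketch, not part of the commit): with the config defaults mvlm_alpha=4, tia_alpha=3, tim_alpha=3, the constructor normalizes the weights to 0.4, 0.3 and 0.3, and the total loss is the correspondingly weighted sum of the per-task mean losses.

import torch

mvlm_alpha, tia_alpha, tim_alpha = 4, 3, 3
total_alpha = mvlm_alpha + tia_alpha + tim_alpha              # 10
weights = [a / total_alpha for a in (mvlm_alpha, tia_alpha, tim_alpha)]
print(weights)                                                # [0.4, 0.3, 0.3]

# Dummy per-task losses standing in for mvlm_loss, tia_loss and tim_loss.
mvlm_loss, tia_loss, tim_loss = torch.tensor(2.0), torch.tensor(1.0), torch.tensor(0.5)
loss = weights[0] * mvlm_loss + weights[1] * tia_loss + weights[2] * tim_loss
print(loss)                                                   # tensor(1.2500)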
layoutlmft/models/layoutxlm/tokenization_layoutxlm.py (new file, mode 100644):

# coding=utf-8
from transformers import XLMRobertaTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

SPIECE_UNDERLINE = "▁"

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "layoutxlm-base": "https://huggingface.co/layoutxlm-base/resolve/main/sentencepiece.bpe.model",
        "layoutxlm-large": "https://huggingface.co/layoutxlm-large/resolve/main/sentencepiece.bpe.model",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "layoutxlm-base": 512,
    "layoutxlm-large": 512,
}


class LayoutXLMTokenizer(XLMRobertaTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, model_max_length=512, **kwargs):
        super().__init__(model_max_length=model_max_length, **kwargs)
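
A hedged usage sketch, not part of the commit: since LayoutXLMTokenizer subclasses XLMRobertaTokenizer without changing the tokenization itself, it can be constructed from any XLM-R style sentencepiece.bpe.model. The file path below is hypothetical.

from layoutlmft.models.layoutxlm.tokenization_layoutxlm import LayoutXLMTokenizer

# Hypothetical path to an XLM-R style SentencePiece model file.
tokenizer = LayoutXLMTokenizer(vocab_file="/path/to/sentencepiece.bpe.model")
encoding = tokenizer("Invoice number: 12345", return_tensors="pt")
print(encoding["input_ids"].shape, tokenizer.model_max_length)  # model_max_length defaults to 512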
layoutlmft/models/layoutxlm/tokenization_layoutxlm_fast.py (new file, mode 100644):

# coding=utf-8
from transformers import XLMRobertaTokenizerFast
from transformers.file_utils import is_sentencepiece_available
from transformers.utils import logging

if is_sentencepiece_available():
    from .tokenization_layoutxlm import LayoutXLMTokenizer
else:
    LayoutXLMTokenizer = None

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "layoutxlm-base": "https://huggingface.co/layoutxlm-base/resolve/main/sentencepiece.bpe.model",
        "layoutxlm-large": "https://huggingface.co/layoutxlm-large/resolve/main/sentencepiece.bpe.model",
    },
    "tokenizer_file": {
        "layoutxlm-base": "https://huggingface.co/layoutxlm-base/resolve/main/tokenizer.json",
        "layoutxlm-large": "https://huggingface.co/layoutxlm-large/resolve/main/tokenizer.json",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "layoutxlm-base": 512,
    "layoutxlm-large": 512,
}


class LayoutXLMTokenizerFast(XLMRobertaTokenizerFast):
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = LayoutXLMTokenizer

    def __init__(self, model_max_length=512, **kwargs):
        super().__init__(model_max_length=model_max_length, **kwargs)
layoutlmft/models/model_args.py (new file, mode 100644):

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
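
A short sketch of how this dataclass is typically consumed (assumed usage, not shown in the commit): transformers' HfArgumentParser turns the field metadata into command-line options.

from transformers import HfArgumentParser
from layoutlmft.models.model_args import ModelArguments

parser = HfArgumentParser(ModelArguments)
# e.g. python train.py --model_name_or_path /path/to/checkpoint --cache_dir /tmp/hf_cache
(model_args,) = parser.parse_args_into_dataclasses()
print(model_args.model_name_or_path, model_args.model_revision)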
layoutlmft/modules/__init__.py (new file, mode 100644): empty file
layoutlmft/modules/__pycache__/__init__.cpython-37.pyc (new file, mode 100644): binary file added
layoutlmft/modules/__pycache__/__init__.cpython-38.pyc (new file, mode 100644): binary file added
layoutlmft/modules/decoders/__init__.py (new file, mode 100644): empty file
layoutlmft/modules/decoders/__pycache__/__init__.cpython-37.pyc (new file, mode 100644): binary file added
layoutlmft/modules/decoders/__pycache__/__init__.cpython-38.pyc (new file, mode 100644): binary file added
layoutlmft/modules/decoders/__pycache__/re.cpython-37.pyc (new file, mode 100644): binary file added
layoutlmft/modules/decoders/__pycache__/re.cpython-38.pyc (new file, mode 100644): binary file added
layoutlmft/modules/decoders/re.py (new file, mode 100644):

import copy

import torch
from torch import nn
from torch.nn import CrossEntropyLoss


class BiaffineAttention(torch.nn.Module):
    """Implements a biaffine attention operator for binary relation classification.

    PyTorch implementation of the biaffine attention operator from "End-to-end neural relation
    extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used
    as a classifier for binary relation classification.

    Args:
        in_features (int): The size of the feature dimension of the inputs.
        out_features (int): The size of the feature dimension of the output.

    Shape:
        - x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of
          additional dimensions.
        - x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of
          additional dimensions.
        - Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number
          of additional dimensions.

    Examples:
        >>> batch_size, in_features, out_features = 32, 100, 4
        >>> biaffine_attention = BiaffineAttention(in_features, out_features)
        >>> x_1 = torch.randn(batch_size, in_features)
        >>> x_2 = torch.randn(batch_size, in_features)
        >>> output = biaffine_attention(x_1, x_2)
        >>> print(output.size())
        torch.Size([32, 4])
    """

    def __init__(self, in_features, out_features):
        super(BiaffineAttention, self).__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
        self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)

        self.reset_parameters()

    def forward(self, x_1, x_2):
        return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))

    def reset_parameters(self):
        self.bilinear.reset_parameters()
        self.linear.reset_parameters()


class REDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.entity_emb = nn.Embedding(3, config.hidden_size, scale_grad_by_freq=True)
        projection = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
        )
        self.ffnn_head = copy.deepcopy(projection)
        self.ffnn_tail = copy.deepcopy(projection)
        self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2)
        self.loss_fct = CrossEntropyLoss()

    def build_relation(self, relations, entities):
        batch_size = len(relations)
        new_relations = []
        for b in range(batch_size):
            if len(entities[b]["start"]) <= 2:
                entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
            # Candidate pairs link every entity labeled 1 to every later entity labeled 2.
            all_possible_relations = set(
                [
                    (i, j)
                    for i in range(len(entities[b]["label"]))
                    for j in range(i + 1, len(entities[b]["label"]))
                    if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
                ]
            )
            # Fall back to a dummy pair when no candidate exists.
            if len(all_possible_relations) == 0:
                all_possible_relations = set([(0, 1)])
            positive_relations = set(list(zip(relations[b]["head"], relations[b]["tail"])))
            negative_relations = all_possible_relations - positive_relations
            positive_relations = set([i for i in positive_relations if i in all_possible_relations])
            reordered_relations = list(positive_relations) + list(negative_relations)
            relation_per_doc = {"head": [], "tail": [], "label": []}
            relation_per_doc["head"] = [i[0] for i in reordered_relations]
            relation_per_doc["tail"] = [i[1] for i in reordered_relations]
            relation_per_doc["label"] = [1] * len(positive_relations) + [0] * (
                len(reordered_relations) - len(positive_relations)
            )
            assert len(relation_per_doc["head"]) != 0
            new_relations.append(relation_per_doc)
        return new_relations, entities

    def get_predicted_relations(self, logits, relations, entities):
        pred_relations = []
        for i, pred_label in enumerate(logits.argmax(-1)):
            if pred_label != 1:
                continue
            rel = {}
            rel["head_id"] = relations["head"][i]
            rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]])
            rel["head_type"] = entities["label"][rel["head_id"]]

            rel["tail_id"] = relations["tail"][i]
            rel["tail"] = (entities["start"][rel["tail_id"]], entities["end"][rel["tail_id"]])
            rel["tail_type"] = entities["label"][rel["tail_id"]]
            rel["type"] = 1
            pred_relations.append(rel)
        return pred_relations

    def forward(self, hidden_states, entities, relations):
        batch_size, max_n_words, context_dim = hidden_states.size()
        device = hidden_states.device
        relations, entities = self.build_relation(relations, entities)
        loss = 0
        all_pred_relations = []
        for b in range(batch_size):
            head_entities = torch.tensor(relations[b]["head"], device=device)
            tail_entities = torch.tensor(relations[b]["tail"], device=device)
            relation_labels = torch.tensor(relations[b]["label"], device=device)
            entities_start_index = torch.tensor(entities[b]["start"], device=device)
            entities_labels = torch.tensor(entities[b]["label"], device=device)

            # Concatenate the token representation at each entity's start position
            # with the embedding of its entity label.
            head_index = entities_start_index[head_entities]
            head_label = entities_labels[head_entities]
            head_label_repr = self.entity_emb(head_label)

            tail_index = entities_start_index[tail_entities]
            tail_label = entities_labels[tail_entities]
            tail_label_repr = self.entity_emb(tail_label)

            head_repr = torch.cat(
                (hidden_states[b][head_index], head_label_repr),
                dim=-1,
            )
            tail_repr = torch.cat(
                (hidden_states[b][tail_index], tail_label_repr),
                dim=-1,
            )
            heads = self.ffnn_head(head_repr)
            tails = self.ffnn_tail(tail_repr)
            logits = self.rel_classifier(heads, tails)
            loss += self.loss_fct(logits, relation_labels)
            pred_relations = self.get_predicted_relations(logits, relations[b], entities[b])
            all_pred_relations.append(pred_relations)
        return loss, all_pred_relations
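
To make the candidate-pair construction concrete, here is a toy walk-through with made-up inputs (a sketch, not part of the commit): build_relation pairs each entity labeled 1 with every later entity labeled 2 (in XFUND-style data these typically correspond to question and answer entities), keeps the pairs present in the gold relations as positives, and appends the remaining candidates as negatives.

from types import SimpleNamespace
from layoutlmft.modules.decoders.re import REDecoder

# Minimal stand-in config; REDecoder only reads hidden_size and hidden_dropout_prob.
cfg = SimpleNamespace(hidden_size=8, hidden_dropout_prob=0.1)
decoder = REDecoder(cfg)

# One document with three entities: entity 0 has label 1, entities 1 and 2 have label 2.
entities = [{"start": [0, 3, 6], "end": [2, 5, 8], "label": [1, 2, 2]}]
# The gold annotation links entity 0 to entity 1.
relations = [{"head": [0], "tail": [1]}]

new_relations, entities = decoder.build_relation(relations, entities)
print(new_relations[0])
# {'head': [0, 0], 'tail': [1, 2], 'label': [1, 0]}  (positive pair first, then the negative candidate)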
layoutlmft/trainers/__init__.py (new file, mode 100644):

from .huaweikie_trainer import HuaweiKIETrainer
from .funsd_trainer import FunsdTrainer
from .xfun_trainer import XfunReTrainer, XfunSerTrainer
from .pre_trainer import PreTrainer